diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -21,11 +21,11 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v) : kind(k), val(v) { switch (kind) { - case Kind::kTensor: + case kTensor: assert(x != -1u && y == -1u && !v); tensor = x; break; - case Kind::kInvariant: + case kInvariant: assert(x == -1u && y == -1u && v); break; case kAbsF: @@ -99,10 +99,10 @@ for (unsigned p : latSets[s0]) latSets[s].push_back(p); // Map binary 0-y to unary -y. - if (kind == Kind::kSubF) - s1 = mapSet(Kind::kNegF, s1); - else if (kind == Kind::kSubI) - s1 = mapSet(Kind::kNegI, s1); + if (kind == kSubF) + s1 = mapSet(kNegF, s1); + else if (kind == kSubI) + s1 = mapSet(kNegI, s1); // Followed by all in s1. for (unsigned p : latSets[s1]) latSets[s].push_back(p); @@ -110,7 +110,7 @@ } unsigned Merger::mapSet(Kind kind, unsigned s0) { - assert(Kind::kAbsF <= kind && kind <= Kind::kNegI); + assert(kAbsF <= kind && kind <= kNegI); unsigned s = addSet(); for (unsigned p : latSets[s0]) { unsigned e = addExp(kind, latPoints[p].exp); @@ -129,8 +129,7 @@ if (p0 != p1) { // Is this a straightforward copy? unsigned e = latPoints[p1].exp; - if (tensorExps[e].kind == Kind::kTensor && - tensorExps[e].tensor == outTensor) + if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor) continue; // Conjunction already covered? for (unsigned p2 : latSets[s]) { @@ -162,9 +161,9 @@ } // Now apply the two basic rules. 
llvm::BitVector simple = latPoints[p0].bits; - bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse); + bool reset = isSingleton && hasAnyDimOf(simple, kSparse); for (unsigned b = 0, be = simple.size(); b < be; b++) { - if (simple[b] && !isDim(b, Dim::kSparse)) { + if (simple[b] && !isDim(b, kSparse)) { if (reset) simple.reset(b); reset = true; @@ -189,7 +188,7 @@ bool Merger::onlyDenseDiff(unsigned i, unsigned j) { llvm::BitVector tmp = latPoints[j].bits; tmp ^= latPoints[i].bits; - return !hasAnyDimOf(tmp, Dim::kSparse); + return !hasAnyDimOf(tmp, kSparse); } bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { @@ -201,23 +200,27 @@ bool Merger::isConjunction(unsigned t, unsigned e) const { switch (tensorExps[e].kind) { - case Kind::kTensor: + case kTensor: return tensorExps[e].tensor == t; case kAbsF: case kCeilF: case kFloorF: case kNegF: case kNegI: - case Kind::kDivF: // note: x / c only - case Kind::kDivS: - case Kind::kDivU: - case Kind::kShrS: // note: x >> inv only - case Kind::kShrU: - case Kind::kShlI: return isConjunction(t, tensorExps[e].children.e0); - case Kind::kMulF: - case Kind::kMulI: - case Kind::kAndI: + case kDivF: // note: x / c only + case kDivS: + case kDivU: + assert(!maybeZero(tensorExps[e].children.e1)); + return isConjunction(t, tensorExps[e].children.e0); + case kShrS: // note: x >> inv only + case kShrU: + case kShlI: + assert(isInvariant(tensorExps[e].children.e1)); + return isConjunction(t, tensorExps[e].children.e0); + case kMulF: + case kMulI: + case kAndI: return isConjunction(t, tensorExps[e].children.e0) || isConjunction(t, tensorExps[e].children.e1); default: @@ -231,20 +234,66 @@ // Print methods (for debugging). 
// -static const char *kOpSymbols[] = { - "", "", "abs", "ceil", "floor", "-", "-", "*", "*", "/", "/", - "+", "+", "-", "-", "&", "|", "^", "a>>", ">>", "<<"}; +static const char *kindToOpSymbol(Kind kind) { + switch (kind) { + case kTensor: + return "tensor"; + case kInvariant: + return "invariant"; + case kAbsF: + return "abs"; + case kCeilF: + return "ceil"; + case kFloorF: + return "floor"; + case kNegF: + return "-"; + case kNegI: + return "-"; + case kMulF: + return "*"; + case kMulI: + return "*"; + case kDivF: + return "/"; + case kDivS: + return "/"; + case kDivU: + return "/"; + case kAddF: + return "+"; + case kAddI: + return "+"; + case kSubF: + return "-"; + case kSubI: + return "-"; + case kAndI: + return "&"; + case kOrI: + return "|"; + case kXorI: + return "^"; + case kShrS: + return "a>>"; + case kShrU: + return ">>"; + case kShlI: + return "<<"; + } + llvm_unreachable("unexpected kind for symbol"); +} void Merger::dumpExp(unsigned e) const { switch (tensorExps[e].kind) { - case Kind::kTensor: + case kTensor: if (tensorExps[e].tensor == syntheticTensor) llvm::dbgs() << "synthetic_"; else if (tensorExps[e].tensor == outTensor) llvm::dbgs() << "output_"; llvm::dbgs() << "tensor_" << tensorExps[e].tensor; break; - case Kind::kInvariant: + case kInvariant: llvm::dbgs() << "invariant"; break; case kAbsF: @@ -252,13 +301,13 @@ case kFloorF: case kNegF: case kNegI: - llvm::dbgs() << kOpSymbols[tensorExps[e].kind] << " "; + llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " "; dumpExp(tensorExps[e].children.e0); break; default: llvm::dbgs() << "("; dumpExp(tensorExps[e].children.e0); - llvm::dbgs() << " " << kOpSymbols[tensorExps[e].kind] << " "; + llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " "; dumpExp(tensorExps[e].children.e1); llvm::dbgs() << ")"; } @@ -290,16 +339,16 @@ unsigned i = index(b); llvm::dbgs() << " i_" << t << "_" << i << "_"; switch (dims[t][i]) { - case Dim::kSparse: + case kSparse: llvm::dbgs() << "S"; break; - 
case Dim::kDense: + case kDense: llvm::dbgs() << "D"; break; - case Dim::kSingle: + case kSingle: llvm::dbgs() << "T"; break; - case Dim::kUndef: + case kUndef: llvm::dbgs() << "U"; break; } @@ -316,13 +365,13 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) { Kind kind = tensorExps[e].kind; switch (kind) { - case Kind::kTensor: - case Kind::kInvariant: { + case kTensor: + case kInvariant: { // Either the index is really used in the tensor expression, or it is // set to the undefined index in that dimension. An invariant expression // is set to a synthetic tensor with undefined indices only. unsigned s = addSet(); - unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor; + unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor; latSets[s].push_back(addLat(t, i, e)); return s; } @@ -338,9 +387,9 @@ // --+---+---+ // | 0 |-y | return mapSet(kind, buildLattices(tensorExps[e].children.e0, i)); - case Kind::kMulF: - case Kind::kMulI: - case Kind::kAndI: + case kMulF: + case kMulI: + case kAndI: // A multiplicative operation only needs to be performed // for the conjunction of sparse iteration spaces. // @@ -351,9 +400,9 @@ return takeConj(kind, // take binary conjunction buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); - case Kind::kDivF: - case Kind::kDivS: - case Kind::kDivU: + case kDivF: + case kDivS: + case kDivU: // A division is tricky, since 0/0, 0/c, c/0 all have // specific outcomes for floating-point and integers. // Thus, we need to traverse the full iteration space. @@ -367,15 +416,16 @@ // during expression building, so that the conjunction // rules applies (viz. x/c = x*(1/c) as far as lattice // construction is concerned). 
+ assert(!maybeZero(tensorExps[e].children.e1)); return takeConj(kind, // take binary conjunction buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); - case Kind::kAddF: - case Kind::kAddI: - case Kind::kSubF: - case Kind::kSubI: - case Kind::kOrI: - case Kind::kXorI: + case kAddF: + case kAddI: + case kSubF: + case kSubI: + case kOrI: + case kXorI: // An additive operation needs to be performed // for the disjunction of sparse iteration spaces. // @@ -386,12 +436,13 @@ return takeDisj(kind, // take binary disjunction buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); - case Kind::kShrS: - case Kind::kShrU: - case Kind::kShlI: + case kShrS: + case kShrU: + case kShlI: // A shift operation by an invariant amount (viz. tensor expressions // can only occur at the left-hand-side of the operator) can be handled // with the conjuction rule. + assert(isInvariant(tensorExps[e].children.e1)); return takeConj(kind, // take binary conjunction buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); @@ -405,7 +456,7 @@ } bool Merger::maybeZero(unsigned e) const { - if (tensorExps[e].kind == Kind::kInvariant) { + if (tensorExps[e].kind == kInvariant) { if (auto c = tensorExps[e].val.getDefiningOp()) return c.getValue() == 0; if (auto c = tensorExps[e].val.getDefiningOp()) @@ -415,7 +466,7 @@ } bool Merger::isInvariant(unsigned e) const { - return tensorExps[e].kind == Kind::kInvariant; + return tensorExps[e].kind == kInvariant; } Optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { @@ -427,30 +478,30 @@ if (arg.getOwner()->getParentOp() == op) { OpOperand *t = op.getInputAndOutputOperands()[argN]; if (!op.isScalar(t)) - return addExp(Kind::kTensor, argN); + return addExp(kTensor, argN); v = t->get(); // get scalar value } // Any other argument (marked as scalar argument for the generic op // or belonging to an enveloping op) is considered invariant. 
- return addExp(Kind::kInvariant, v); + return addExp(kInvariant, v); } // Something defined outside is invariant. Operation *def = v.getDefiningOp(); if (def->getBlock() != &op.region().front()) - return addExp(Kind::kInvariant, v); + return addExp(kInvariant, v); // Construct unary operations if subexpression can be built. if (def->getNumOperands() == 1) { auto x = buildTensorExp(op, def->getOperand(0)); if (x.hasValue()) { unsigned e = x.getValue(); if (isa(def)) - return addExp(Kind::kAbsF, e); + return addExp(kAbsF, e); if (isa(def)) - return addExp(Kind::kCeilF, e); + return addExp(kCeilF, e); if (isa(def)) - return addExp(Kind::kFloorF, e); + return addExp(kFloorF, e); if (isa(def)) - return addExp(Kind::kNegF, e); + return addExp(kNegF, e); // TODO: no negi in std? } } @@ -463,35 +514,35 @@ unsigned e0 = x.getValue(); unsigned e1 = y.getValue(); if (isa(def)) - return addExp(Kind::kMulF, e0, e1); + return addExp(kMulF, e0, e1); if (isa(def)) - return addExp(Kind::kMulI, e0, e1); + return addExp(kMulI, e0, e1); if (isa(def) && !maybeZero(e1)) - return addExp(Kind::kDivF, e0, e1); + return addExp(kDivF, e0, e1); if (isa(def) && !maybeZero(e1)) - return addExp(Kind::kDivS, e0, e1); + return addExp(kDivS, e0, e1); if (isa(def) && !maybeZero(e1)) - return addExp(Kind::kDivU, e0, e1); + return addExp(kDivU, e0, e1); if (isa(def)) - return addExp(Kind::kAddF, e0, e1); + return addExp(kAddF, e0, e1); if (isa(def)) - return addExp(Kind::kAddI, e0, e1); + return addExp(kAddI, e0, e1); if (isa(def)) - return addExp(Kind::kSubF, e0, e1); + return addExp(kSubF, e0, e1); if (isa(def)) - return addExp(Kind::kSubI, e0, e1); + return addExp(kSubI, e0, e1); if (isa(def)) - return addExp(Kind::kAndI, e0, e1); + return addExp(kAndI, e0, e1); if (isa(def)) - return addExp(Kind::kOrI, e0, e1); + return addExp(kOrI, e0, e1); if (isa(def)) - return addExp(Kind::kXorI, e0, e1); + return addExp(kXorI, e0, e1); if (isa(def) && isInvariant(e1)) - return addExp(Kind::kShrS, e0, e1); + 
return addExp(kShrS, e0, e1); if (isa(def) && isInvariant(e1)) - return addExp(Kind::kShrU, e0, e1); + return addExp(kShrU, e0, e1); if (isa(def) && isInvariant(e1)) - return addExp(Kind::kShlI, e0, e1); + return addExp(kShlI, e0, e1); } } // Cannot build. @@ -501,8 +552,8 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e, Value v0, Value v1) { switch (tensorExps[e].kind) { - case Kind::kTensor: - case Kind::kInvariant: + case kTensor: + case kInvariant: llvm_unreachable("unexpected non-op"); case kAbsF: return rewriter.create(loc, v0); @@ -515,35 +566,35 @@ case kNegI: assert(v1); // no negi in std return rewriter.create(loc, v0, v1); - case Kind::kMulF: + case kMulF: return rewriter.create(loc, v0, v1); - case Kind::kMulI: + case kMulI: return rewriter.create(loc, v0, v1); - case Kind::kDivF: + case kDivF: return rewriter.create(loc, v0, v1); - case Kind::kDivS: + case kDivS: return rewriter.create(loc, v0, v1); - case Kind::kDivU: + case kDivU: return rewriter.create(loc, v0, v1); - case Kind::kAddF: + case kAddF: return rewriter.create(loc, v0, v1); - case Kind::kAddI: + case kAddI: return rewriter.create(loc, v0, v1); - case Kind::kSubF: + case kSubF: return rewriter.create(loc, v0, v1); - case Kind::kSubI: + case kSubI: return rewriter.create(loc, v0, v1); - case Kind::kAndI: + case kAndI: return rewriter.create(loc, v0, v1); - case Kind::kOrI: + case kOrI: return rewriter.create(loc, v0, v1); - case Kind::kXorI: + case kXorI: return rewriter.create(loc, v0, v1); - case Kind::kShrS: + case kShrS: return rewriter.create(loc, v0, v1); - case Kind::kShrU: + case kShrU: return rewriter.create(loc, v0, v1); - case Kind::kShlI: + case kShlI: return rewriter.create(loc, v0, v1); } llvm_unreachable("unexpected expression kind in build");