diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -23,87 +23,6 @@ namespace mlir { namespace sparse_tensor { -/// Tensor expression kind. -/// -/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`. -/// That is, its argument is a `LoopId` identifying the loop-variable -/// in question, and its value will be the current iteration's value -/// of that loop-variable. See the `LoopId` documentation for more details. -// -// TODO: make this an `enum class` nested in the `TensorExp` class; -// to improve namespacing, and match the pattern used by other "Kind" -// enums in MLIR. -// -// TODO: Modify this definition so that the numeric values already encode -// the `ExpArity` (while extending the notion of "arity" to include not -// just the number of `ExprId` children the node has, but also whether the -// node has a `Value` and/or `Operation*`). Doing this will avoid needing -// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor, -// and should help clean up a few other places as well. -enum Kind { - // Leaf. - kTensor = 0, - kInvariant, - kLoopVar, - // Unary operations. - kAbsF, - kAbsC, - kAbsI, - kCeilF, - kFloorF, - kSqrtF, - kSqrtC, - kExpm1F, - kExpm1C, - kLog1pF, - kLog1pC, - kSinF, - kSinC, - kTanhF, - kTanhC, - kNegF, - kNegC, - kNegI, - kTruncF, - kExtF, - kCastFS, // signed - kCastFU, // unsigned - kCastSF, // signed - kCastUF, // unsigned - kCastS, // signed - kCastU, // unsigned - kCastIdx, - kTruncI, - kCIm, // complex.im - kCRe, // complex.re - kBitCast, - kBinaryBranch, // semiring unary branch created from a binary op - kUnary, // semiring unary op - kSelect, // custom selection criteria - // Binary operations. - kMulF, - kMulC, - kMulI, - kDivF, - kDivC, // complex - kDivS, // signed - kDivU, // unsigned - kAddF, - kAddC, - kAddI, - kSubF, - kSubC, - kSubI, - kAndI, - kOrI, - kXorI, - kShrS, // signed - kShrU, // unsigned - kShlI, - kBinary, // semiring binary op - kReduce, // semiring reduction op -}; - // TODO: These type aliases currently only serve to make the code more // self-documenting, however because they are not type-checked they can // do nothing to prevent mixups. We should really change them from mere @@ -169,6 +88,8 @@ /// Tensor expression. Represents a MLIR expression in tensor index notation. struct TensorExp { + enum class Kind; + // The `x` parameter has different types depending on the value of the // `k` parameter. The correspondences are: // * `kTensor` -> `TensorId` @@ -207,6 +128,83 @@ Operation *op; }; +/// Tensor expression kind. +/// +/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`. +/// That is, its argument is a `LoopId` identifying the loop-variable +/// in question, and its value will be the current iteration's value +/// of that loop-variable. See the `LoopId` documentation for more details. +// +// TODO: Modify this definition so that the numeric values already encode +// the `ExpArity` (while extending the notion of "arity" to include not +// just the number of `ExprId` children the node has, but also whether the +// node has a `Value` and/or `Operation*`). Doing this will avoid needing +// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor, +// and should help clean up a few other places as well. +enum class TensorExp::Kind { + // Leaf. + kTensor = 0, + kInvariant, + kLoopVar, + // Unary operations. + kAbsF, + kAbsC, + kAbsI, + kCeilF, + kFloorF, + kSqrtF, + kSqrtC, + kExpm1F, + kExpm1C, + kLog1pF, + kLog1pC, + kSinF, + kSinC, + kTanhF, + kTanhC, + kNegF, + kNegC, + kNegI, + kTruncF, + kExtF, + kCastFS, // signed + kCastFU, // unsigned + kCastSF, // signed + kCastUF, // unsigned + kCastS, // signed + kCastU, // unsigned + kCastIdx, + kTruncI, + kCIm, // complex.im + kCRe, // complex.re + kBitCast, + kBinaryBranch, // semiring unary branch created from a binary op + kUnary, // semiring unary op + kSelect, // custom selection criteria + // Binary operations. + kMulF, + kMulC, + kMulI, + kDivF, + kDivC, // complex + kDivS, // signed + kDivU, // unsigned + kAddF, + kAddC, + kAddI, + kSubF, + kSubC, + kSubI, + kAndI, + kOrI, + kXorI, + kShrS, // signed + kShrU, // unsigned + kShlI, + kBinary, // semiring binary op + kReduce, // semiring reduction op +}; + /// Lattice point. Each lattice point consists of a formal conjunction /// of `TensorLoopId`s, together with the identifier of the corresponding /// tensor expression. The formal conjunction is represented as a set of @@ -271,12 +269,12 @@ /// Constructs a new tensor expression, and returns its identifier. /// The type of the `e0` argument varies according to the value of the /// `k` argument, as described by the `TensorExp` ctor. - ExprId addExp(Kind k, unsigned e0, ExprId e1 = kInvalidId, Value v = Value(), + ExprId addExp(TensorExp::Kind k, unsigned e0, ExprId e1 = kInvalidId, Value v = Value(), Operation *op = nullptr); - ExprId addExp(Kind k, ExprId e, Value v, Operation *op = nullptr) { + ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr) { return addExp(k, e, kInvalidId, v, op); } - ExprId addExp(Kind k, Value v, Operation *op = nullptr) { + ExprId addExp(TensorExp::Kind k, Value v, Operation *op = nullptr) { return addExp(k, kInvalidId, kInvalidId, v, op); } @@ -290,30 +288,30 @@ /// of `LoopId` (effectively constructing a larger "intersection" of those /// loops) with a newly constructed tensor (sub)expression of given kind. /// Returns the identifier of the new lattice point. - LatPointId conjLat(Kind kind, LatPointId p0, LatPointId p1, + LatPointId conjLat(TensorExp::Kind kind, LatPointId p0, LatPointId p1, Operation *op = nullptr); /// Conjunctive merge of two lattice sets: `(s0 /\_op s1)`. /// Returns the identifier of the new set. - LatSetId conjSet(Kind kind, LatSetId s0, LatSetId s1, + LatSetId conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1, Operation *op = nullptr); /// Disjunctive merge of two lattice sets: `(s0 /\_op s1, s0, s1)`. /// Returns the identifier of the new set. - LatSetId disjSet(Kind kind, LatSetId s0, LatSetId s1, + LatSetId disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1, Operation *op = nullptr); /// Disjunctive merge of two lattice sets with custom handling of the /// overlap, left, and right regions. Any region may be left missing /// in the output. Returns the identifier of the new set. - LatSetId combiSet(Kind kind, LatSetId s0, LatSetId s1, Operation *orig, - bool includeLeft, Kind ltrans, Operation *opleft, - bool includeRight, Kind rtrans, Operation *opright); + LatSetId combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1, Operation *orig, + bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, + bool includeRight, TensorExp::Kind rtrans, Operation *opright); /// Maps the unary operator over the lattice set of the operand, i.e. each /// lattice point on an expression E is simply copied over, but with OP E /// as new expression. Returns the identifier of the new set. - LatSetId mapSet(Kind kind, LatSetId s, Value v = Value(), + LatSetId mapSet(TensorExp::Kind kind, LatSetId s, Value v = Value(), Operation *op = nullptr); /// Optimizes the iteration lattice points in the given set. This @@ -377,7 +375,7 @@ /// Returns true if the expression is `(kTensor t)`. bool expIsTensor(ExprId e, TensorId t) const { - return tensorExps[e].kind == kTensor && tensorExps[e].tensor == t; + return tensorExps[e].kind == TensorExp::Kind::kTensor && tensorExps[e].tensor == t; } /// Returns true if the expression contains the tensor as an operand. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -853,8 +853,8 @@ if (!rhs) { // Only unary and binary are allowed to return uninitialized rhs // to indicate missing output. - assert(env.exp(exp).kind == kUnary || env.exp(exp).kind == kBinary); - } else if (env.exp(exp).kind == kSelect) { + assert(env.exp(exp).kind == TensorExp::Kind::kUnary || env.exp(exp).kind == TensorExp::Kind::kBinary); + } else if (env.exp(exp).kind == TensorExp::Kind::kSelect) { // Select operation insertion. Value chain = env.getInsertionChain(); scf::IfOp ifOp = @@ -922,28 +922,28 @@ return Value(); const TensorExp &exp = env.exp(e); const auto kind = exp.kind; - if (kind == Kind::kTensor) + if (kind == TensorExp::Kind::kTensor) return genTensorLoad(env, rewriter, e); - if (kind == Kind::kInvariant) + if (kind == TensorExp::Kind::kInvariant) return genInvariantValue(env, e); - if (kind == Kind::kLoopVar) + if (kind == TensorExp::Kind::kLoopVar) return env.getLoopVar(exp.loop); - if (kind == Kind::kReduce) + if (kind == TensorExp::Kind::kReduce) env.startCustomReduc(e); // enter custom Value v0 = genExp(env, rewriter, exp.children.e0, ldx); Value v1 = genExp(env, rewriter, exp.children.e1, ldx); Value ee = env.merger().buildExp(rewriter, loc, e, v0, v1); - if (ee && (kind == Kind::kUnary || kind == Kind::kBinary || - kind == Kind::kBinaryBranch || kind == Kind::kReduce || - kind == Kind::kSelect)) + if (ee && (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary || + kind == TensorExp::Kind::kBinaryBranch || kind == TensorExp::Kind::kReduce || + kind == TensorExp::Kind::kSelect)) ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx); - if (kind == Kind::kReduce) + if (kind == TensorExp::Kind::kReduce) env.endCustomReduc(); // exit custom - if (kind == kSelect) { + if (kind == TensorExp::Kind::kSelect) { assert(!exp.val); env.exp(e).val = v0; // Preserve value for later use. } @@ -956,7 +956,7 @@ LoopId ldx, bool atStart) { if (exp == kInvalidId) return; - if (env.exp(exp).kind == Kind::kTensor) { + if (env.exp(exp).kind == TensorExp::Kind::kTensor) { // Inspect tensor indices. bool isAtLoop = ldx == kInvalidId; linalg::GenericOp op = env.op(); @@ -1000,18 +1000,18 @@ // Start or end loop invariant hoisting of a tensor load. env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value(); } - } else if (env.exp(exp).kind != Kind::kInvariant && - env.exp(exp).kind != Kind::kLoopVar) { + } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant && + env.exp(exp).kind != TensorExp::Kind::kLoopVar) { // Traverse into the binary operations. Note that we only hoist // tensor loads, since subsequent MLIR/LLVM passes know how to // deal with all other kinds of derived loop invariants. - if (env.exp(exp).kind == Kind::kReduce) + if (env.exp(exp).kind == TensorExp::Kind::kReduce) env.startCustomReduc(exp); // enter custom const ExprId e0 = env.exp(exp).children.e0; const ExprId e1 = env.exp(exp).children.e1; genInvariants(env, builder, e0, ldx, atStart); genInvariants(env, builder, e1, ldx, atStart); - if (env.exp(exp).kind == Kind::kReduce) + if (env.exp(exp).kind == TensorExp::Kind::kReduce) env.endCustomReduc(); // exit custom } } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -25,70 +25,70 @@ kBinary, }; -static ExpArity getExpArity(Kind k) { +static ExpArity getExpArity(TensorExp::Kind k) { switch (k) { // Leaf. - case kTensor: - case kInvariant: - case kLoopVar: + case TensorExp::Kind::kTensor: + case TensorExp::Kind::kInvariant: + case TensorExp::Kind::kLoopVar: return ExpArity::kNullary; - case kAbsF: - case kAbsC: - case kAbsI: - case kCeilF: - case kFloorF: - case kSqrtF: - case kSqrtC: - case kExpm1F: - case kExpm1C: - case kLog1pF: - case kLog1pC: - case kSinF: - case kSinC: - case kTanhF: - case kTanhC: - case kTruncF: - case kExtF: - case kCastFS: - case kCastFU: - case kCastSF: - case kCastUF: - case kCastS: - case kCastU: - case kCastIdx: - case kTruncI: - case kCIm: - case kCRe: - case kBitCast: - case kBinaryBranch: - case kUnary: - case kSelect: - case kNegF: - case kNegC: - case kNegI: + case TensorExp::Kind::kAbsF: + case TensorExp::Kind::kAbsC: + case TensorExp::Kind::kAbsI: + case TensorExp::Kind::kCeilF: + case TensorExp::Kind::kFloorF: + case TensorExp::Kind::kSqrtF: + case TensorExp::Kind::kSqrtC: + case TensorExp::Kind::kExpm1F: + case TensorExp::Kind::kExpm1C: + case TensorExp::Kind::kLog1pF: + case TensorExp::Kind::kLog1pC: + case TensorExp::Kind::kSinF: + case TensorExp::Kind::kSinC: + case TensorExp::Kind::kTanhF: + case TensorExp::Kind::kTanhC: + case TensorExp::Kind::kTruncF: + case TensorExp::Kind::kExtF: + case TensorExp::Kind::kCastFS: + case TensorExp::Kind::kCastFU: + case TensorExp::Kind::kCastSF: + case TensorExp::Kind::kCastUF: + case TensorExp::Kind::kCastS: + case TensorExp::Kind::kCastU: + case TensorExp::Kind::kCastIdx: + case TensorExp::Kind::kTruncI: + case TensorExp::Kind::kCIm: + case TensorExp::Kind::kCRe: + case TensorExp::Kind::kBitCast: + case TensorExp::Kind::kBinaryBranch: + case TensorExp::Kind::kUnary: + case TensorExp::Kind::kSelect: + case TensorExp::Kind::kNegF: + case TensorExp::Kind::kNegC: + case TensorExp::Kind::kNegI: return ExpArity::kUnary; // Binary operations. - case kDivF: - case kDivC: - case kDivS: - case kDivU: - case kShrS: - case kShrU: - case kShlI: - case kMulF: - case kMulC: - case kMulI: - case kAndI: - case kAddF: - case kAddC: - case kAddI: - case kOrI: - case kXorI: - case kBinary: - case kReduce: - case kSubF: - case kSubC: - case kSubI: + case TensorExp::Kind::kDivF: + case TensorExp::Kind::kDivC: + case TensorExp::Kind::kDivS: + case TensorExp::Kind::kDivU: + case TensorExp::Kind::kShrS: + case TensorExp::Kind::kShrU: + case TensorExp::Kind::kShlI: + case TensorExp::Kind::kMulF: + case TensorExp::Kind::kMulC: + case TensorExp::Kind::kMulI: + case TensorExp::Kind::kAndI: + case TensorExp::Kind::kAddF: + case TensorExp::Kind::kAddC: + case TensorExp::Kind::kAddI: + case TensorExp::Kind::kOrI: + case TensorExp::Kind::kXorI: + case TensorExp::Kind::kBinary: + case TensorExp::Kind::kReduce: + case TensorExp::Kind::kSubF: + case TensorExp::Kind::kSubC: + case TensorExp::Kind::kSubI: return ExpArity::kBinary; } llvm_unreachable("unexpected kind"); @@ -102,64 +102,64 @@ : kind(k), val(v), op(o) { switch (kind) { // Leaf. - case kTensor: + case TensorExp::Kind::kTensor: assert(x != kInvalidId && y == kInvalidId && !v && !o); tensor = x; break; - case kInvariant: + case TensorExp::Kind::kInvariant: assert(x == kInvalidId && y == kInvalidId && v && !o); break; - case kLoopVar: + case TensorExp::Kind::kLoopVar: assert(x != kInvalidId && y == kInvalidId && !v && !o); loop = x; break; // Unary operations. - case kAbsF: - case kAbsC: - case kAbsI: - case kCeilF: - case kFloorF: - case kSqrtF: - case kSqrtC: - case kExpm1F: - case kExpm1C: - case kLog1pF: - case kLog1pC: - case kSinF: - case kSinC: - case kTanhF: - case kTanhC: - case kNegF: - case kNegC: - case kNegI: - case kCIm: - case kCRe: + case TensorExp::Kind::kAbsF: + case TensorExp::Kind::kAbsC: + case TensorExp::Kind::kAbsI: + case TensorExp::Kind::kCeilF: + case TensorExp::Kind::kFloorF: + case TensorExp::Kind::kSqrtF: + case TensorExp::Kind::kSqrtC: + case TensorExp::Kind::kExpm1F: + case TensorExp::Kind::kExpm1C: + case TensorExp::Kind::kLog1pF: + case TensorExp::Kind::kLog1pC: + case TensorExp::Kind::kSinF: + case TensorExp::Kind::kSinC: + case TensorExp::Kind::kTanhF: + case TensorExp::Kind::kTanhC: + case TensorExp::Kind::kNegF: + case TensorExp::Kind::kNegC: + case TensorExp::Kind::kNegI: + case TensorExp::Kind::kCIm: + case TensorExp::Kind::kCRe: assert(x != kInvalidId && y == kInvalidId && !v && !o); children.e0 = x; children.e1 = y; break; - case kTruncF: - case kExtF: - case kCastFS: - case kCastFU: - case kCastSF: - case kCastUF: - case kCastS: - case kCastU: - case kCastIdx: - case kTruncI: - case kBitCast: + case TensorExp::Kind::kTruncF: + case TensorExp::Kind::kExtF: + case TensorExp::Kind::kCastFS: + case TensorExp::Kind::kCastFU: + case TensorExp::Kind::kCastSF: + case TensorExp::Kind::kCastUF: + case TensorExp::Kind::kCastS: + case TensorExp::Kind::kCastU: + case TensorExp::Kind::kCastIdx: + case TensorExp::Kind::kTruncI: + case TensorExp::Kind::kBitCast: assert(x != kInvalidId && y == kInvalidId && v && !o); children.e0 = x; children.e1 = y; break; - case kBinaryBranch: - case kSelect: + case TensorExp::Kind::kBinaryBranch: + case TensorExp::Kind::kSelect: assert(x != kInvalidId && y == kInvalidId && !v && o); children.e0 = x; children.e1 = y; break; - case kUnary: + case TensorExp::Kind::kUnary: // No assertion on y can be made, as the branching paths involve both // a unary (`mapSet`) and binary (`disjSet`) pathway. assert(x != kInvalidId && !v && o); @@ -167,31 +167,31 @@ children.e1 = y; break; // Binary operations. - case kMulF: - case kMulC: - case kMulI: - case kDivF: - case kDivC: - case kDivS: - case kDivU: - case kAddF: - case kAddC: - case kAddI: - case kSubF: - case kSubC: - case kSubI: - case kAndI: - case kOrI: - case kXorI: - case kShrS: - case kShrU: - case kShlI: + case TensorExp::Kind::kMulF: + case TensorExp::Kind::kMulC: + case TensorExp::Kind::kMulI: + case TensorExp::Kind::kDivF: + case TensorExp::Kind::kDivC: + case TensorExp::Kind::kDivS: + case TensorExp::Kind::kDivU: + case TensorExp::Kind::kAddF: + case TensorExp::Kind::kAddC: + case TensorExp::Kind::kAddI: + case TensorExp::Kind::kSubF: + case TensorExp::Kind::kSubC: + case TensorExp::Kind::kSubI: + case TensorExp::Kind::kAndI: + case TensorExp::Kind::kOrI: + case TensorExp::Kind::kXorI: + case TensorExp::Kind::kShrS: + case TensorExp::Kind::kShrU: + case TensorExp::Kind::kShlI: assert(x != kInvalidId && y != kInvalidId && !v && !o); children.e0 = x; children.e1 = y; break; - case kBinary: - case kReduce: + case TensorExp::Kind::kBinary: + case TensorExp::Kind::kReduce: assert(x != kInvalidId && y != kInvalidId && !v && o); children.e0 = x; children.e1 = y; @@ -226,9 +226,9 @@ // Lattice methods. //===----------------------------------------------------------------------===// -ExprId Merger::addExp(Kind k, unsigned x, ExprId y, Value v, Operation *op) { +ExprId Merger::addExp(TensorExp::Kind k, unsigned x, ExprId y, Value v, Operation *op) { const ExprId e = tensorExps.size(); - assert((k != kTensor || x < numTensors) && (k != kLoopVar || x < numLoops)); + assert((k != TensorExp::Kind::kTensor || x < numTensors) && (k != TensorExp::Kind::kLoopVar || x < numLoops)); tensorExps.emplace_back(k, x, y, v, op); return e; } @@ -246,7 +246,7 @@ return s; } -LatPointId Merger::conjLat(Kind kind, LatPointId p0, LatPointId p1, +LatPointId Merger::conjLat(TensorExp::Kind kind, LatPointId p0, LatPointId p1, Operation *op) { const LatPointId p = latPoints.size(); BitVector bits(latPoints[p0].bits); @@ -257,7 +257,7 @@ return p; } -LatSetId Merger::conjSet(Kind kind, LatSetId s0, LatSetId s1, Operation *op) { +LatSetId Merger::conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1, Operation *op) { const LatSetId s = addSet(); for (const LatPointId p0 : latSets[s0]) for (const LatPointId p1 : latSets[s1]) @@ -265,28 +265,28 @@ return s; } -LatSetId Merger::disjSet(Kind kind, LatSetId s0, LatSetId s1, Operation *op) { +LatSetId Merger::disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1, Operation *op) { const LatSetId s = conjSet(kind, s0, s1, op); // Followed by all in s0. for (const LatPointId p : latSets[s0]) latSets[s].push_back(p); // Map binary 0-y to unary -y. // TODO: move this if-else logic into buildLattices - if (kind == kSubF) - s1 = mapSet(kNegF, s1); - else if (kind == kSubC) - s1 = mapSet(kNegC, s1); - else if (kind == kSubI) - s1 = mapSet(kNegI, s1); + if (kind == TensorExp::Kind::kSubF) + s1 = mapSet(TensorExp::Kind::kNegF, s1); + else if (kind == TensorExp::Kind::kSubC) + s1 = mapSet(TensorExp::Kind::kNegC, s1); + else if (kind == TensorExp::Kind::kSubI) + s1 = mapSet(TensorExp::Kind::kNegI, s1); // Followed by all in s1. for (const LatPointId p : latSets[s1]) latSets[s].push_back(p); return s; } -LatSetId Merger::combiSet(Kind kind, LatSetId s0, LatSetId s1, Operation *orig, - bool includeLeft, Kind ltrans, Operation *opleft, - bool includeRight, Kind rtrans, Operation *opright) { +LatSetId Merger::combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1, Operation *orig, + bool includeLeft, TensorExp::Kind ltrans, Operation *opleft, + bool includeRight, TensorExp::Kind rtrans, Operation *opright) { const LatSetId s = conjSet(kind, s0, s1, orig); // Left Region. if (includeLeft) { @@ -305,8 +305,8 @@ return s; } -LatSetId Merger::mapSet(Kind kind, LatSetId s0, Value v, Operation *op) { - assert(kAbsF <= kind && kind <= kSelect); +LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v, Operation *op) { + assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect); const LatSetId s = addSet(); for (const LatPointId p : latSets[s0]) { const ExprId e = addExp(kind, latPoints[p].exp, v, op); @@ -406,7 +406,7 @@ } bool Merger::expContainsTensor(ExprId e, TensorId t) const { - if (tensorExps[e].kind == kTensor) + if (tensorExps[e].kind == TensorExp::Kind::kTensor) return tensorExps[e].tensor == t; switch (getExpArity(tensorExps[e].kind)) { @@ -431,13 +431,13 @@ bool Merger::hasNegateOnOut(ExprId e) const { switch (tensorExps[e].kind) { - case kNegF: - case kNegC: - case kNegI: + case TensorExp::Kind::kNegF: + case TensorExp::Kind::kNegC: + case TensorExp::Kind::kNegI: return expContainsTensor(tensorExps[e].children.e0, outTensor); - case kSubF: - case kSubC: - case kSubI: + case TensorExp::Kind::kSubF: + case TensorExp::Kind::kSubC: + case TensorExp::Kind::kSubI: return expContainsTensor(tensorExps[e].children.e1, outTensor) || hasNegateOnOut(tensorExps[e].children.e0); default: { @@ -459,82 +459,82 @@ assert(t < numTensors && e < tensorExps.size()); switch (tensorExps[e].kind) { // Leaf. - case kTensor: + case TensorExp::Kind::kTensor: return tensorExps[e].tensor == t; - case kInvariant: - case kLoopVar: + case TensorExp::Kind::kInvariant: + case TensorExp::Kind::kLoopVar: return false; // Unary operations. - case kAbsF: - case kAbsC: - case kAbsI: - case kCeilF: - case kFloorF: - case kSqrtF: - case kSqrtC: - case kExpm1F: - case kExpm1C: - case kLog1pF: - case kLog1pC: - case kSinF: - case kSinC: - case kTanhF: - case kTanhC: - case kNegF: - case kNegC: - case kNegI: - case kTruncF: - case kExtF: - case kCastFS: - case kCastFU: - case kCastSF: - case kCastUF: - case kCastS: - case kCastU: - case kCastIdx: - case kTruncI: - case kCIm: - case kCRe: - case kBitCast: + case TensorExp::Kind::kAbsF: + case TensorExp::Kind::kAbsC: + case TensorExp::Kind::kAbsI: + case TensorExp::Kind::kCeilF: + case TensorExp::Kind::kFloorF: + case TensorExp::Kind::kSqrtF: + case TensorExp::Kind::kSqrtC: + case TensorExp::Kind::kExpm1F: + case TensorExp::Kind::kExpm1C: + case TensorExp::Kind::kLog1pF: + case TensorExp::Kind::kLog1pC: + case TensorExp::Kind::kSinF: + case TensorExp::Kind::kSinC: + case TensorExp::Kind::kTanhF: + case TensorExp::Kind::kTanhC: + case TensorExp::Kind::kNegF: + case TensorExp::Kind::kNegC: + case TensorExp::Kind::kNegI: + case TensorExp::Kind::kTruncF: + case TensorExp::Kind::kExtF: + case TensorExp::Kind::kCastFS: + case TensorExp::Kind::kCastFU: + case TensorExp::Kind::kCastSF: + case TensorExp::Kind::kCastUF: + case TensorExp::Kind::kCastS: + case TensorExp::Kind::kCastU: + case TensorExp::Kind::kCastIdx: + case TensorExp::Kind::kTruncI: + case TensorExp::Kind::kCIm: + case TensorExp::Kind::kCRe: + case TensorExp::Kind::kBitCast: return isSingleCondition(t, tensorExps[e].children.e0); - case kBinaryBranch: - case kUnary: - case kSelect: + case TensorExp::Kind::kBinaryBranch: + case TensorExp::Kind::kUnary: + case TensorExp::Kind::kSelect: return false; // Binary operations. - case kDivF: // note: x / c only - case kDivC: - case kDivS: - case kDivU: + case TensorExp::Kind::kDivF: // note: x / c only + case TensorExp::Kind::kDivC: + case TensorExp::Kind::kDivS: + case TensorExp::Kind::kDivU: assert(!maybeZero(tensorExps[e].children.e1)); return isSingleCondition(t, tensorExps[e].children.e0); - case kShrS: // note: x >> inv only - case kShrU: - case kShlI: + case TensorExp::Kind::kShrS: // note: x >> inv only + case TensorExp::Kind::kShrU: + case TensorExp::Kind::kShlI: assert(isInvariant(tensorExps[e].children.e1)); return isSingleCondition(t, tensorExps[e].children.e0); - case kMulF: - case kMulC: - case kMulI: - case kAndI: + case TensorExp::Kind::kMulF: + case TensorExp::Kind::kMulC: + case TensorExp::Kind::kMulI: + case TensorExp::Kind::kAndI: if (isSingleCondition(t, tensorExps[e].children.e0)) return isSingleCondition(t, tensorExps[e].children.e1) || isInvariant(tensorExps[e].children.e1); if (isSingleCondition(t, tensorExps[e].children.e1)) return isInvariant(tensorExps[e].children.e0); return false; - case kAddF: - case kAddC: - case kAddI: + case TensorExp::Kind::kAddF: + case TensorExp::Kind::kAddC: + case TensorExp::Kind::kAddI: return isSingleCondition(t, tensorExps[e].children.e0) && isSingleCondition(t, tensorExps[e].children.e1); - case kSubF: - case kSubC: - case kSubI: - case kOrI: - case kXorI: - case kBinary: - case kReduce: + case TensorExp::Kind::kSubF: + case TensorExp::Kind::kSubC: + case TensorExp::Kind::kSubI: + case TensorExp::Kind::kOrI: + case TensorExp::Kind::kXorI: + case TensorExp::Kind::kBinary: + case TensorExp::Kind::kReduce: return false; } llvm_unreachable("unexpected kind"); @@ -556,98 +556,98 @@ // Print methods (for debugging). //===----------------------------------------------------------------------===// -static const char *kindToOpSymbol(Kind kind) { +static const char *kindToOpSymbol(TensorExp::Kind kind) { switch (kind) { // Leaf. - case kTensor: + case TensorExp::Kind::kTensor: return "tensor"; - case kInvariant: + case TensorExp::Kind::kInvariant: return "invariant"; - case kLoopVar: + case TensorExp::Kind::kLoopVar: return "index"; // Unary operations. - case kAbsF: - case kAbsC: - case kAbsI: + case TensorExp::Kind::kAbsF: + case TensorExp::Kind::kAbsC: + case TensorExp::Kind::kAbsI: return "abs"; - case kCeilF: + case TensorExp::Kind::kCeilF: return "ceil"; - case kFloorF: + case TensorExp::Kind::kFloorF: return "floor"; - case kSqrtF: - case kSqrtC: + case TensorExp::Kind::kSqrtF: + case TensorExp::Kind::kSqrtC: return "sqrt"; - case kExpm1F: - case kExpm1C: + case TensorExp::Kind::kExpm1F: + case TensorExp::Kind::kExpm1C: return "expm1"; - case kLog1pF: - case kLog1pC: + case TensorExp::Kind::kLog1pF: + case TensorExp::Kind::kLog1pC: return "log1p"; - case kSinF: - case kSinC: + case TensorExp::Kind::kSinF: + case TensorExp::Kind::kSinC: return "sin"; - case kTanhF: - case kTanhC: + case TensorExp::Kind::kTanhF: + case TensorExp::Kind::kTanhC: return "tanh"; - case kNegF: - case kNegC: - case kNegI: + case TensorExp::Kind::kNegF: + case TensorExp::Kind::kNegC: + case TensorExp::Kind::kNegI: return "-"; - case kTruncF: - case kExtF: - case kCastFS: - case kCastFU: - case kCastSF: - case kCastUF: - case kCastS: - case kCastU: - case kCastIdx: - case kTruncI: - case kCIm: + case TensorExp::Kind::kTruncF: + case TensorExp::Kind::kExtF: + case TensorExp::Kind::kCastFS: + case TensorExp::Kind::kCastFU: + case TensorExp::Kind::kCastSF: + case TensorExp::Kind::kCastUF: + case TensorExp::Kind::kCastS: + case TensorExp::Kind::kCastU: + case TensorExp::Kind::kCastIdx: + case TensorExp::Kind::kTruncI: + case TensorExp::Kind::kCIm: return "complex.im"; - case kCRe: + case TensorExp::Kind::kCRe: return "complex.re"; - case kBitCast: + case TensorExp::Kind::kBitCast: return "cast"; - case kBinaryBranch: + case TensorExp::Kind::kBinaryBranch: return "binary_branch"; - case kUnary: + case TensorExp::Kind::kUnary: return "unary"; - case kSelect: + case TensorExp::Kind::kSelect: return "select"; // Binary operations. - case kMulF: - case kMulC: - case kMulI: + case TensorExp::Kind::kMulF: + case TensorExp::Kind::kMulC: + case TensorExp::Kind::kMulI: return "*"; - case kDivF: - case kDivC: - case kDivS: - case kDivU: + case TensorExp::Kind::kDivF: + case TensorExp::Kind::kDivC: + case TensorExp::Kind::kDivS: + case TensorExp::Kind::kDivU: return "/"; - case kAddF: - case kAddC: - case kAddI: + case TensorExp::Kind::kAddF: + case TensorExp::Kind::kAddC: + case TensorExp::Kind::kAddI: return "+"; - case kSubF: - case kSubC: - case kSubI: + case TensorExp::Kind::kSubF: + case TensorExp::Kind::kSubC: + case TensorExp::Kind::kSubI: return "-"; - case kAndI: + case TensorExp::Kind::kAndI: return "&"; - case kOrI: + case TensorExp::Kind::kOrI: return "|"; - case kXorI: + case TensorExp::Kind::kXorI: return "^"; - case kShrS: + case TensorExp::Kind::kShrS: return "a>>"; - case kShrU: + case TensorExp::Kind::kShrU: return ">>"; - case kShlI: + case TensorExp::Kind::kShlI: return "<<"; - case kBinary: + case TensorExp::Kind::kBinary: return "binary"; - case kReduce: + case TensorExp::Kind::kReduce: return "reduce"; } llvm_unreachable("unexpected kind for symbol"); @@ -656,79 +656,79 @@ void Merger::dumpExp(ExprId e) const { switch (tensorExps[e].kind) { // Leaf. - case kTensor: + case TensorExp::Kind::kTensor: if (tensorExps[e].tensor == syntheticTensor) llvm::dbgs() << "synthetic_"; else if (tensorExps[e].tensor == outTensor) llvm::dbgs() << "output_"; llvm::dbgs() << "tensor_" << tensorExps[e].tensor; break; - case kInvariant: + case TensorExp::Kind::kInvariant: llvm::dbgs() << "invariant"; break; - case kLoopVar: + case TensorExp::Kind::kLoopVar: llvm::dbgs() << "loopvar_" << tensorExps[e].loop; break; // Unary operations. - case kAbsF: - case kAbsC: - case kAbsI: - case kCeilF: - case kFloorF: - case kSqrtF: - case kSqrtC: - case kExpm1F: - case kExpm1C: - case kLog1pF: - case kLog1pC: - case kSinF: - case kSinC: - case kTanhF: - case kTanhC: - case kNegF: - case kNegC: - case kNegI: - case kTruncF: - case kExtF: - case kCastFS: - case kCastFU: - case kCastSF: - case kCastUF: - case kCastS: - case kCastU: - case kCastIdx: - case kTruncI: - case kCIm: - case kCRe: - case kBitCast: - case kBinaryBranch: - case kUnary: - case kSelect: + case TensorExp::Kind::kAbsF: + case TensorExp::Kind::kAbsC: + case TensorExp::Kind::kAbsI: + case TensorExp::Kind::kCeilF: + case TensorExp::Kind::kFloorF: + case TensorExp::Kind::kSqrtF: + case TensorExp::Kind::kSqrtC: + case TensorExp::Kind::kExpm1F: + case TensorExp::Kind::kExpm1C: + case TensorExp::Kind::kLog1pF: + case TensorExp::Kind::kLog1pC: + case TensorExp::Kind::kSinF: + case TensorExp::Kind::kSinC: + case TensorExp::Kind::kTanhF: + case TensorExp::Kind::kTanhC: + case TensorExp::Kind::kNegF: + case TensorExp::Kind::kNegC: + case TensorExp::Kind::kNegI: + case TensorExp::Kind::kTruncF: + case TensorExp::Kind::kExtF: + case TensorExp::Kind::kCastFS: + case TensorExp::Kind::kCastFU: + case TensorExp::Kind::kCastSF: + case TensorExp::Kind::kCastUF: + case TensorExp::Kind::kCastS: + case TensorExp::Kind::kCastU: + case TensorExp::Kind::kCastIdx: + case TensorExp::Kind::kTruncI: + case TensorExp::Kind::kCIm: + case TensorExp::Kind::kCRe: + case TensorExp::Kind::kBitCast: + case TensorExp::Kind::kBinaryBranch: + case TensorExp::Kind::kUnary: + case TensorExp::Kind::kSelect: llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " "; dumpExp(tensorExps[e].children.e0); break; // Binary operations. - case kMulF: - case kMulC: - case kMulI: - case kDivF: - case kDivC: - case kDivS: - case kDivU: - case kAddF: - case kAddC: - case kAddI: - case kSubF: - case kSubC: - case kSubI: - case kAndI: - case kOrI: - case kXorI: - case kShrS: - case kShrU: - case kShlI: - case kBinary: - case kReduce: + case TensorExp::Kind::kMulF: + case TensorExp::Kind::kMulC: + case TensorExp::Kind::kMulI: + case TensorExp::Kind::kDivF: + case TensorExp::Kind::kDivC: + case TensorExp::Kind::kDivS: + case TensorExp::Kind::kDivU: + case TensorExp::Kind::kAddF: + case TensorExp::Kind::kAddC: + case TensorExp::Kind::kAddI: + case TensorExp::Kind::kSubF: + case TensorExp::Kind::kSubC: + case TensorExp::Kind::kSubI: + case TensorExp::Kind::kAndI: + case TensorExp::Kind::kOrI: + case TensorExp::Kind::kXorI: + case TensorExp::Kind::kShrS: + case TensorExp::Kind::kShrU: + case TensorExp::Kind::kShlI: + case TensorExp::Kind::kBinary: + case TensorExp::Kind::kReduce: llvm::dbgs() << "("; dumpExp(tensorExps[e].children.e0); llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " "; @@ -774,12 +774,12 @@ //===----------------------------------------------------------------------===// LatSetId Merger::buildLattices(ExprId e, LoopId i) { - const Kind kind = tensorExps[e].kind; + const TensorExp::Kind kind = tensorExps[e].kind; switch (kind) { // Leaf. - case kTensor: - case kInvariant: - case kLoopVar: { + case TensorExp::Kind::kTensor: + case TensorExp::Kind::kInvariant: + case TensorExp::Kind::kLoopVar: { // Either the loop-var is really used in the tensor expression, or it is // set to the undefined loop-var in that level. An invariant expression, // a proper index value, and a truly dynamic sparse output tensor are set @@ -787,7 +787,7 @@ // iteration space is not skipped as a result of their contents. const LatSetId s = addSet(); TensorId t = syntheticTensor; - if (kind == kTensor) { + if (kind == TensorExp::Kind::kTensor) { t = tensorExps[e].tensor; if (hasSparseOut && t == outTensor) t = syntheticTensor; @@ -796,37 +796,37 @@ return s; } // Unary operations. - case kAbsF: - case kAbsC: - case kAbsI: - case kCeilF: - case kFloorF: - case kSqrtF: - case kSqrtC: - case kExpm1F: - case kExpm1C: - case kLog1pF: - case kLog1pC: - case kSinF: - case kSinC: - case kTanhF: - case kTanhC: - case kNegF: - case kNegC: - case kNegI: - case kTruncF: - case kExtF: - case kCastFS: - case kCastFU: - case kCastSF: - case kCastUF: - case kCastS: - case kCastU: - case kCastIdx: - case kTruncI: - case kCIm: - case kCRe: - case kBitCast: + case TensorExp::Kind::kAbsF: + case TensorExp::Kind::kAbsC: + case TensorExp::Kind::kAbsI: + case TensorExp::Kind::kCeilF: + case TensorExp::Kind::kFloorF: + case TensorExp::Kind::kSqrtF: + case TensorExp::Kind::kSqrtC: + case TensorExp::Kind::kExpm1F: + case TensorExp::Kind::kExpm1C: + case TensorExp::Kind::kLog1pF: + case TensorExp::Kind::kLog1pC: + case TensorExp::Kind::kSinF: + case TensorExp::Kind::kSinC: + case TensorExp::Kind::kTanhF: + case TensorExp::Kind::kTanhC: + case TensorExp::Kind::kNegF: + case TensorExp::Kind::kNegC: + case TensorExp::Kind::kNegI: + case TensorExp::Kind::kTruncF: + case TensorExp::Kind::kExtF: + case TensorExp::Kind::kCastFS: + case TensorExp::Kind::kCastFU: + case TensorExp::Kind::kCastSF: + case TensorExp::Kind::kCastUF: + case TensorExp::Kind::kCastS: + case TensorExp::Kind::kCastU: + case TensorExp::Kind::kCastIdx: + case TensorExp::Kind::kTruncI: + case TensorExp::Kind::kCIm: + case TensorExp::Kind::kCRe: + case TensorExp::Kind::kBitCast: // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the // lattice set of the operand through the operator into a new set. // @@ -835,13 +835,13 @@ // | 0 |-y | return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), tensorExps[e].val); - case kBinaryBranch: - case kSelect: + case TensorExp::Kind::kBinaryBranch: + case TensorExp::Kind::kSelect: // The left or right half of a binary operation which has already // been split into separate operations for each region. return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(), tensorExps[e].op); - case kUnary: + case TensorExp::Kind::kUnary: // A custom unary operation. // // op y| !y | y | @@ -860,14 +860,14 @@ Block &absentBlock = absentRegion.front(); YieldOp absentYield = cast(absentBlock.getTerminator()); Value absentVal = absentYield.getResult(); - const ExprId rhs = addExp(kInvariant, absentVal); + const ExprId rhs = addExp(TensorExp::Kind::kInvariant, absentVal); return disjSet(kind, child0, buildLattices(rhs, i), unop); } // Binary operations. - case kMulF: - case kMulC: - case kMulI: - case kAndI: + case TensorExp::Kind::kMulF: + case TensorExp::Kind::kMulC: + case TensorExp::Kind::kMulI: + case TensorExp::Kind::kAndI: // A multiplicative operation only needs to be performed // for the conjunction of sparse iteration spaces. // @@ -879,10 +879,10 @@ // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored. return conjSet(kind, buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); - case kDivF: - case kDivC: - case kDivS: - case kDivU: + case TensorExp::Kind::kDivF: + case TensorExp::Kind::kDivC: + case TensorExp::Kind::kDivS: + case TensorExp::Kind::kDivU: // A division is tricky, since 0/0, 0/c, c/0 all have // specific outcomes for floating-point and integers. // Thus, we need to traverse the full iteration space. @@ -899,14 +899,14 @@ assert(!maybeZero(tensorExps[e].children.e1)); return conjSet(kind, buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); - case kAddF: - case kAddC: - case kAddI: - case kSubF: - case kSubC: - case kSubI: - case kOrI: - case kXorI: + case TensorExp::Kind::kAddF: + case TensorExp::Kind::kAddC: + case TensorExp::Kind::kAddI: + case TensorExp::Kind::kSubF: + case TensorExp::Kind::kSubC: + case TensorExp::Kind::kSubI: + case TensorExp::Kind::kOrI: + case TensorExp::Kind::kXorI: // An additive operation needs to be performed // for the disjunction of sparse iteration spaces. // @@ -916,16 +916,16 @@ // x | x |x+y| x | x |x-y| return disjSet(kind, buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); - case kShrS: - case kShrU: - case kShlI: + case TensorExp::Kind::kShrS: + case TensorExp::Kind::kShrU: + case TensorExp::Kind::kShlI: // A shift operation by an invariant amount (viz. tensor expressions // can only occur at the left-hand-side of the operator) can be handled // with the conjuction rule. assert(isInvariant(tensorExps[e].children.e1)); return conjSet(kind, buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); - case kBinary: + case TensorExp::Kind::kBinary: // A custom binary operation. // // x op y| !y | y | @@ -952,11 +952,11 @@ } bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty(); bool includeRight = binop.getRightIdentity() || !rightRegion.empty(); - return combiSet(kBinary, child0, child1, binop, includeLeft, - kBinaryBranch, leftYield, includeRight, kBinaryBranch, + return combiSet(TensorExp::Kind::kBinary, child0, child1, binop, includeLeft, + TensorExp::Kind::kBinaryBranch, leftYield, includeRight, TensorExp::Kind::kBinaryBranch, rightYield); } - case kReduce: + case TensorExp::Kind::kReduce: // A custom reduce operation. return conjSet(kind, buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i), @@ -974,7 +974,7 @@ /// Only returns false if we are certain this is a nonzero. bool Merger::maybeZero(ExprId e) const { - if (tensorExps[e].kind == kInvariant) { + if (tensorExps[e].kind == TensorExp::Kind::kInvariant) { if (auto c = tensorExps[e].val.getDefiningOp()) { ArrayAttr arrayAttr = c.getValue(); return arrayAttr[0].cast().getValue().isZero() && @@ -989,7 +989,7 @@ } bool Merger::isInvariant(ExprId e) const { - return tensorExps[e].kind == kInvariant; + return tensorExps[e].kind == TensorExp::Kind::kInvariant; } Type Merger::inferType(ExprId e, Value src) const { @@ -1041,21 +1041,21 @@ if (arg.getOwner()->getParentOp() == op) { OpOperand &t = op->getOpOperand(argN); if (!op.isScalar(&t)) - return addExp(kTensor, argN); + return addExp(TensorExp::Kind::kTensor, argN); v = t.get(); // get scalar value } // Any other argument (marked as scalar argument for the generic op // or belonging to an enveloping op) is considered invariant. - return addExp(kInvariant, v); + return addExp(TensorExp::Kind::kInvariant, v); } // Something defined outside is invariant. Operation *def = v.getDefiningOp(); if (def->getBlock() != &op.getRegion().front()) - return addExp(kInvariant, v); + return addExp(TensorExp::Kind::kInvariant, v); // Construct index operations. if (def->getNumOperands() == 0) { if (auto indexOp = dyn_cast(def)) - return addExp(kLoopVar, indexOp.getDim()); + return addExp(TensorExp::Kind::kLoopVar, indexOp.getDim()); } // Construct unary operations if subexpression can be built. if (def->getNumOperands() == 1) { @@ -1063,73 +1063,73 @@ if (x.has_value()) { const ExprId e = *x; if (isa(def)) - return addExp(kAbsF, e); + return addExp(TensorExp::Kind::kAbsF, e); if (isa(def)) - return addExp(kAbsC, e); + return addExp(TensorExp::Kind::kAbsC, e); if (isa(def)) - return addExp(kAbsI, e); + return addExp(TensorExp::Kind::kAbsI, e); if (isa(def)) - return addExp(kCeilF, e); + return addExp(TensorExp::Kind::kCeilF, e); if (isa(def)) - return addExp(kFloorF, e); + return addExp(TensorExp::Kind::kFloorF, e); if (isa(def)) - return addExp(kSqrtF, e); + return addExp(TensorExp::Kind::kSqrtF, e); if (isa(def)) - return addExp(kSqrtC, e); + return addExp(TensorExp::Kind::kSqrtC, e); if (isa(def)) - return addExp(kExpm1F, e); + return addExp(TensorExp::Kind::kExpm1F, e); if (isa(def)) - return addExp(kExpm1C, e); + return addExp(TensorExp::Kind::kExpm1C, e); if (isa(def)) - return addExp(kLog1pF, e); + return addExp(TensorExp::Kind::kLog1pF, e); if (isa(def)) - return addExp(kLog1pC, e); + return addExp(TensorExp::Kind::kLog1pC, e); if (isa(def)) - return addExp(kSinF, e); + return addExp(TensorExp::Kind::kSinF, e); if (isa(def)) - return addExp(kSinC, e); + return addExp(TensorExp::Kind::kSinC, e); if (isa(def)) - return addExp(kTanhF, e); + return addExp(TensorExp::Kind::kTanhF, e); if (isa(def)) - return addExp(kTanhC, e); + return addExp(TensorExp::Kind::kTanhC, e); if (isa(def)) - return addExp(kNegF, e); // no negi in std + return addExp(TensorExp::Kind::kNegF, e); // no negi in std if (isa(def)) - return addExp(kNegC, e); + return addExp(TensorExp::Kind::kNegC, e); if (isa(def)) - return addExp(kTruncF, e, v); + return addExp(TensorExp::Kind::kTruncF, e, v); if (isa(def)) - return addExp(kExtF, e, v); + return addExp(TensorExp::Kind::kExtF, e, v); if (isa(def)) - return addExp(kCastFS, e, v); + return addExp(TensorExp::Kind::kCastFS, e, v); if (isa(def)) - return addExp(kCastFU, e, v); + return addExp(TensorExp::Kind::kCastFU, e, v); if (isa(def)) - return addExp(kCastSF, e, v); + return addExp(TensorExp::Kind::kCastSF, e, v); if (isa(def)) - return addExp(kCastUF, e, v); + return addExp(TensorExp::Kind::kCastUF, e, v); if (isa(def)) - return addExp(kCastS, e, v); + return addExp(TensorExp::Kind::kCastS, e, v); if (isa(def)) - return addExp(kCastU, e, v); + return addExp(TensorExp::Kind::kCastU, e, v); if (isa(def)) - return addExp(kCastIdx, e, v); + return addExp(TensorExp::Kind::kCastIdx, e, v); if (isa(def)) - return addExp(kTruncI, e, v); + return addExp(TensorExp::Kind::kTruncI, e, v); if (isa(def)) - return addExp(kCIm, e); + return addExp(TensorExp::Kind::kCIm, e); if (isa(def)) - return addExp(kCRe, e); + return addExp(TensorExp::Kind::kCRe, e); if (isa(def)) - return addExp(kBitCast, e, v); + return addExp(TensorExp::Kind::kBitCast, e, v); if (auto unop = dyn_cast(def)) { if (isAdmissibleBranch(unop, unop.getPresentRegion()) && isAdmissibleBranch(unop, unop.getAbsentRegion())) - return addExp(kUnary, e, Value(), def); + return addExp(TensorExp::Kind::kUnary, e, Value(), def); } if (auto selop = dyn_cast(def)) { if (isAdmissibleBranch(selop, selop.getRegion())) - return addExp(kSelect, e, Value(), def); + return addExp(TensorExp::Kind::kSelect, e, Value(), def); } } } @@ -1143,50 +1143,50 @@ const ExprId e0 = *x; const ExprId e1 = *y; if (isa(def)) - return addExp(kMulF, e0, e1); + return addExp(TensorExp::Kind::kMulF, e0, e1); if (isa(def)) - return addExp(kMulC, e0, e1); + return addExp(TensorExp::Kind::kMulC, e0, e1); if (isa(def)) - return addExp(kMulI, e0, e1); + return addExp(TensorExp::Kind::kMulI, e0, e1); if (isa(def) && !maybeZero(e1)) - return addExp(kDivF, e0, e1); + return addExp(TensorExp::Kind::kDivF, e0, e1); if (isa(def) && !maybeZero(e1)) - return addExp(kDivC, e0, e1); + return addExp(TensorExp::Kind::kDivC, e0, e1); if (isa(def) && !maybeZero(e1)) - return addExp(kDivS, e0, e1); + return addExp(TensorExp::Kind::kDivS, e0, e1); if (isa(def) && !maybeZero(e1)) - return addExp(kDivU, e0, e1); + return addExp(TensorExp::Kind::kDivU, e0, e1); if (isa(def)) - return addExp(kAddF, e0, e1); + return addExp(TensorExp::Kind::kAddF, e0, e1); if (isa(def)) - return addExp(kAddC, e0, e1); + return addExp(TensorExp::Kind::kAddC, e0, e1); if (isa(def)) - return addExp(kAddI, e0, e1); + return addExp(TensorExp::Kind::kAddI, e0, e1); if (isa(def)) - return addExp(kSubF, e0, e1); + return addExp(TensorExp::Kind::kSubF, e0, e1); if (isa(def)) - return addExp(kSubC, e0, e1); + return addExp(TensorExp::Kind::kSubC, e0, e1); if (isa(def)) - return addExp(kSubI, e0, e1); + return addExp(TensorExp::Kind::kSubI, e0, e1); if (isa(def)) - return addExp(kAndI, e0, e1); + return addExp(TensorExp::Kind::kAndI, e0, e1); if (isa(def)) - return addExp(kOrI, e0, e1); + return addExp(TensorExp::Kind::kOrI, e0, e1); if (isa(def)) - return addExp(kXorI, e0, e1); + return addExp(TensorExp::Kind::kXorI, e0, e1); if (isa(def) && isInvariant(e1)) - return addExp(kShrS, e0, e1); + return addExp(TensorExp::Kind::kShrS, e0, e1); if (isa(def) && isInvariant(e1)) - return addExp(kShrU, e0, e1); + return addExp(TensorExp::Kind::kShrU, e0, e1); if (isa(def) && isInvariant(e1)) - return addExp(kShlI, e0, e1); + return addExp(TensorExp::Kind::kShlI, e0, e1); if (auto binop = dyn_cast(def)) { if (isAdmissibleBranch(binop, binop.getOverlapRegion()) && (binop.getLeftIdentity() || isAdmissibleBranch(binop, binop.getLeftRegion())) && (binop.getRightIdentity() || isAdmissibleBranch(binop, binop.getRightRegion()))) - return addExp(kBinary, e0, e1, Value(), def); + return addExp(TensorExp::Kind::kBinary, e0, e1, Value(), def); } } } @@ -1200,7 +1200,7 @@ const ExprId e1 = *y; if (auto redop = dyn_cast(def)) { if (isAdmissibleBranch(redop, redop.getRegion())) - return addExp(kReduce, e0, e1, Value(), def); + return addExp(TensorExp::Kind::kReduce, e0, e1, Value(), def); } } } @@ -1257,136 +1257,136 @@ Value v1) const { switch (tensorExps[e].kind) { // Leaf. - case kTensor: - case kInvariant: - case kLoopVar: + case TensorExp::Kind::kTensor: + case TensorExp::Kind::kInvariant: + case TensorExp::Kind::kLoopVar: llvm_unreachable("unexpected non-op"); // Unary operations. - case kAbsF: + case TensorExp::Kind::kAbsF: return rewriter.create(loc, v0); - case kAbsC: { + case TensorExp::Kind::kAbsC: { auto type = v0.getType().cast(); auto eltType = type.getElementType().cast(); return rewriter.create(loc, eltType, v0); } - case kAbsI: + case TensorExp::Kind::kAbsI: return rewriter.create(loc, v0); - case kCeilF: + case TensorExp::Kind::kCeilF: return rewriter.create(loc, v0); - case kFloorF: + case TensorExp::Kind::kFloorF: return rewriter.create(loc, v0); - case kSqrtF: + case TensorExp::Kind::kSqrtF: return rewriter.create(loc, v0); - case kSqrtC: + case TensorExp::Kind::kSqrtC: return rewriter.create(loc, v0); - case kExpm1F: + case TensorExp::Kind::kExpm1F: return rewriter.create(loc, v0); - case kExpm1C: + case TensorExp::Kind::kExpm1C: return rewriter.create(loc, v0); - case kLog1pF: + case TensorExp::Kind::kLog1pF: return rewriter.create(loc, v0); - case kLog1pC: + case TensorExp::Kind::kLog1pC: return rewriter.create(loc, v0); - case kSinF: + case TensorExp::Kind::kSinF: return rewriter.create(loc, v0); - case kSinC: + case TensorExp::Kind::kSinC: return rewriter.create(loc, v0); - case kTanhF: + case TensorExp::Kind::kTanhF: return rewriter.create(loc, v0); - case kTanhC: + case TensorExp::Kind::kTanhC: return rewriter.create(loc, v0); - case kNegF: + case TensorExp::Kind::kNegF: return rewriter.create(loc, v0); - case kNegC: + case TensorExp::Kind::kNegC: return rewriter.create(loc, v0); - case kNegI: // no negi in std + case TensorExp::Kind::kNegI: // no negi in std return rewriter.create( loc, rewriter.create(loc, v0.getType(), rewriter.getZeroAttr(v0.getType())), v0); - case kTruncF: + case TensorExp::Kind::kTruncF: return rewriter.create(loc, inferType(e, v0), v0); - case kExtF: + case TensorExp::Kind::kExtF: return rewriter.create(loc, inferType(e, v0), v0); - case kCastFS: + case TensorExp::Kind::kCastFS: return rewriter.create(loc, inferType(e, v0), v0); - case kCastFU: + case TensorExp::Kind::kCastFU: return rewriter.create(loc, inferType(e, v0), v0); - case kCastSF: + case TensorExp::Kind::kCastSF: return rewriter.create(loc, inferType(e, v0), v0); - case kCastUF: + case TensorExp::Kind::kCastUF: return rewriter.create(loc, inferType(e, v0), v0); - case kCastS: + case TensorExp::Kind::kCastS: return rewriter.create(loc, inferType(e, v0), v0); - case kCastU: + case TensorExp::Kind::kCastU: return rewriter.create(loc, inferType(e, v0), v0); - case kCastIdx: + case TensorExp::Kind::kCastIdx: return rewriter.create(loc, inferType(e, v0), v0); - case kTruncI: + case TensorExp::Kind::kTruncI: return rewriter.create(loc, inferType(e, v0), v0); - case kCIm: { + case TensorExp::Kind::kCIm: { auto type = v0.getType().cast(); auto eltType = type.getElementType().cast(); return rewriter.create(loc, eltType, v0); } - case kCRe: { + case TensorExp::Kind::kCRe: { auto type = v0.getType().cast(); auto eltType = type.getElementType().cast(); return rewriter.create(loc, eltType, v0); } - case kBitCast: + case TensorExp::Kind::kBitCast: return rewriter.create(loc, inferType(e, v0), v0); // Binary operations. - case kMulF: + case TensorExp::Kind::kMulF: return rewriter.create(loc, v0, v1); - case kMulC: + case TensorExp::Kind::kMulC: return rewriter.create(loc, v0, v1); - case kMulI: + case TensorExp::Kind::kMulI: return rewriter.create(loc, v0, v1); - case kDivF: + case TensorExp::Kind::kDivF: return rewriter.create(loc, v0, v1); - case kDivC: + case TensorExp::Kind::kDivC: return rewriter.create(loc, v0, v1); - case kDivS: + case TensorExp::Kind::kDivS: return rewriter.create(loc, v0, v1); - case kDivU: + case TensorExp::Kind::kDivU: return rewriter.create(loc, v0, v1); - case kAddF: + case TensorExp::Kind::kAddF: return rewriter.create(loc, v0, v1); - case kAddC: + case TensorExp::Kind::kAddC: return rewriter.create(loc, v0, v1); - case kAddI: + case TensorExp::Kind::kAddI: return rewriter.create(loc, v0, v1); - case kSubF: + case TensorExp::Kind::kSubF: return rewriter.create(loc, v0, v1); - case kSubC: + case TensorExp::Kind::kSubC: return rewriter.create(loc, v0, v1); - case kSubI: + case TensorExp::Kind::kSubI: return rewriter.create(loc, v0, v1); - case kAndI: + case TensorExp::Kind::kAndI: return rewriter.create(loc, v0, v1); - case kOrI: + case TensorExp::Kind::kOrI: return rewriter.create(loc, v0, v1); - case kXorI: + case TensorExp::Kind::kXorI: return rewriter.create(loc, v0, v1); - case kShrS: + case TensorExp::Kind::kShrS: return rewriter.create(loc, v0, v1); - case kShrU: + case TensorExp::Kind::kShrU: return rewriter.create(loc, v0, v1); - case kShlI: + case TensorExp::Kind::kShlI: return rewriter.create(loc, v0, v1); - case kBinaryBranch: // semi-ring ops with custom logic. + case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic. return insertYieldOp(rewriter, loc, *tensorExps[e].op->getBlock()->getParent(), {v0}); - case kUnary: + case TensorExp::Kind::kUnary: return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0); - case kSelect: + case TensorExp::Kind::kSelect: return insertYieldOp(rewriter, loc, cast(tensorExps[e].op).getRegion(), {v0}); - case kBinary: + case TensorExp::Kind::kBinary: return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1); - case kReduce: { + case TensorExp::Kind::kReduce: { ReduceOp redOp = cast(tensorExps[e].op); return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1}); } diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -23,18 +23,18 @@ /// #define FOREVERY_BINOP(DO) \ - DO(mulf, Kind::kMulF) \ - DO(mulc, Kind::kMulC) \ - DO(muli, Kind::kMulI) \ - DO(addf, Kind::kAddF) \ - DO(addc, Kind::kAddC) \ - DO(addi, Kind::kAddI) \ - DO(subf, Kind::kSubF) \ - DO(subc, Kind::kSubC) \ - DO(subi, Kind::kSubI) \ - DO(andi, Kind::kAndI) \ - DO(xori, Kind::kXorI) \ - DO(ori, Kind::kOrI) + DO(mulf, TensorExp::Kind::kMulF) \ + DO(mulc, TensorExp::Kind::kMulC) \ + DO(muli, TensorExp::Kind::kMulI) \ + DO(addf, TensorExp::Kind::kAddF) \ + DO(addc, TensorExp::Kind::kAddC) \ + DO(addi, TensorExp::Kind::kAddI) \ + DO(subf, TensorExp::Kind::kSubF) \ + DO(subc, TensorExp::Kind::kSubC) \ + DO(subi, TensorExp::Kind::kSubI) \ + DO(andi, TensorExp::Kind::kAndI) \ + DO(xori, TensorExp::Kind::kXorI) \ + DO(ori, TensorExp::Kind::kOrI) // TODO: Disjunctive binary operations that need special handling are not // included, e.g., Division are not tested (for now) as it need a constant @@ -82,7 +82,7 @@ /// Simple recursive data structure used to match expressions in Mergers. struct Pattern { - Kind kind; + TensorExp::Kind kind; /// Expressions representing tensors simply have a tensor number. unsigned tensorNum; @@ -94,11 +94,11 @@ /// Constructors. /// Rather than using these, please use the readable helper constructor /// functions below to make tests more readable. - Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {} - Pattern(Kind kind, const std::shared_ptr &e0, + Pattern(unsigned tensorNum) : kind(TensorExp::Kind::kTensor), tensorNum(tensorNum) {} + Pattern(TensorExp::Kind kind, const std::shared_ptr &e0, const std::shared_ptr &e1) : kind(kind), e0(e0), e1(e1) { - assert(kind >= Kind::kMulF); + assert(kind >= TensorExp::Kind::kMulF); assert(e0 && e1); } }; @@ -134,7 +134,7 @@ /// unsigned tensor(unsigned tensor) { - return merger.addExp(Kind::kTensor, tensor); + return merger.addExp(TensorExp::Kind::kTensor, tensor); } #define IMPL_BINOP_EXPR(OP, KIND) \ @@ -222,69 +222,69 @@ return false; switch (tensorExp.kind) { // Leaf. - case kTensor: + case TensorExp::Kind::kTensor: return tensorExp.tensor == pattern->tensorNum; - case kInvariant: - case kLoopVar: + case TensorExp::Kind::kInvariant: + case TensorExp::Kind::kLoopVar: llvm_unreachable("invariant not handled yet"); // Unary operations. - case kAbsF: - case kAbsC: - case kAbsI: - case kCeilF: - case kFloorF: - case kSqrtF: - case kSqrtC: - case kExpm1F: - case kExpm1C: - case kLog1pF: - case kLog1pC: - case kSinF: - case kSinC: - case kTanhF: - case kTanhC: - case kNegF: - case kNegC: - case kNegI: - case kTruncF: - case kExtF: - case kCastFS: - case kCastFU: - case kCastSF: - case kCastUF: - case kCastS: - case kCastU: - case kCastIdx: - case kTruncI: - case kCIm: - case kCRe: - case kBitCast: - case kSelect: - case kBinaryBranch: - case kUnary: + case TensorExp::Kind::kAbsF: + case TensorExp::Kind::kAbsC: + case TensorExp::Kind::kAbsI: + case TensorExp::Kind::kCeilF: + case TensorExp::Kind::kFloorF: + case TensorExp::Kind::kSqrtF: + case TensorExp::Kind::kSqrtC: + case TensorExp::Kind::kExpm1F: + case TensorExp::Kind::kExpm1C: + case TensorExp::Kind::kLog1pF: + case TensorExp::Kind::kLog1pC: + case TensorExp::Kind::kSinF: + case TensorExp::Kind::kSinC: + case TensorExp::Kind::kTanhF: + case TensorExp::Kind::kTanhC: + case TensorExp::Kind::kNegF: + case TensorExp::Kind::kNegC: + case TensorExp::Kind::kNegI: + case TensorExp::Kind::kTruncF: + case TensorExp::Kind::kExtF: + case TensorExp::Kind::kCastFS: + case TensorExp::Kind::kCastFU: + case TensorExp::Kind::kCastSF: + case TensorExp::Kind::kCastUF: + case TensorExp::Kind::kCastS: + case TensorExp::Kind::kCastU: + case TensorExp::Kind::kCastIdx: + case TensorExp::Kind::kTruncI: + case TensorExp::Kind::kCIm: + case TensorExp::Kind::kCRe: + case TensorExp::Kind::kBitCast: + case TensorExp::Kind::kSelect: + case TensorExp::Kind::kBinaryBranch: + case TensorExp::Kind::kUnary: return compareExpression(tensorExp.children.e0, pattern->e0); // Binary operations. - case kMulF: - case kMulC: - case kMulI: - case kDivF: - case kDivC: - case kDivS: - case kDivU: - case kAddF: - case kAddC: - case kAddI: - case kSubF: - case kSubC: - case kSubI: - case kAndI: - case kOrI: - case kXorI: - case kShrS: - case kShrU: - case kShlI: - case kBinary: - case kReduce: + case TensorExp::Kind::kMulF: + case TensorExp::Kind::kMulC: + case TensorExp::Kind::kMulI: + case TensorExp::Kind::kDivF: + case TensorExp::Kind::kDivC: + case TensorExp::Kind::kDivS: + case TensorExp::Kind::kDivU: + case TensorExp::Kind::kAddF: + case TensorExp::Kind::kAddC: + case TensorExp::Kind::kAddI: + case TensorExp::Kind::kSubF: + case TensorExp::Kind::kSubC: + case TensorExp::Kind::kSubI: + case TensorExp::Kind::kAndI: + case TensorExp::Kind::kOrI: + case TensorExp::Kind::kXorI: + case TensorExp::Kind::kShrS: + case TensorExp::Kind::kShrU: + case TensorExp::Kind::kShlI: + case TensorExp::Kind::kBinary: + case TensorExp::Kind::kReduce: return compareExpression(tensorExp.children.e0, pattern->e0) && compareExpression(tensorExp.children.e1, pattern->e1); } @@ -312,15 +312,15 @@ EXPECT_TRUE(merger.getOutTensorID() == t2); // Tensor 0: sparse input vector. - merger.addExp(Kind::kTensor, t0, -1u); + merger.addExp(TensorExp::Kind::kTensor, t0, -1u); merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed); // Tensor 1: sparse input vector. - merger.addExp(Kind::kTensor, t1, -1u); + merger.addExp(TensorExp::Kind::kTensor, t1, -1u); merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed); // Tensor 2: dense output vector. - merger.addExp(Kind::kTensor, t2, -1u); + merger.addExp(TensorExp::Kind::kTensor, t2, -1u); merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense); } }; @@ -337,19 +337,19 @@ EXPECT_TRUE(merger.getOutTensorID() == t3); // Tensor 0: sparse input vector. - merger.addExp(Kind::kTensor, t0, -1u); + merger.addExp(TensorExp::Kind::kTensor, t0, -1u); merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed); // Tensor 1: sparse input vector. - merger.addExp(Kind::kTensor, t1, -1u); + merger.addExp(TensorExp::Kind::kTensor, t1, -1u); merger.setLevelAndType(t1, l0, 0, DimLevelType::Compressed); // Tensor 2: sparse input vector - merger.addExp(Kind::kTensor, t2, -1u); + merger.addExp(TensorExp::Kind::kTensor, t2, -1u); merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed); // Tensor 3: dense output vector - merger.addExp(Kind::kTensor, t3, -1u); + merger.addExp(TensorExp::Kind::kTensor, t3, -1u); merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense); } }; @@ -370,15 +370,15 @@ EXPECT_TRUE(merger.getOutTensorID() == t2); // Tensor 0: sparse input vector. - merger.addExp(Kind::kTensor, t0, -1u); + merger.addExp(TensorExp::Kind::kTensor, t0, -1u); merger.setLevelAndType(t0, l0, 0, DimLevelType::Compressed); // Tensor 1: dense input vector. - merger.addExp(Kind::kTensor, t1, -1u); + merger.addExp(TensorExp::Kind::kTensor, t1, -1u); merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense); // Tensor 2: dense output vector. - merger.addExp(Kind::kTensor, t2, -1u); + merger.addExp(TensorExp::Kind::kTensor, t2, -1u); merger.setLevelAndType(t2, l0, 0, DimLevelType::Dense); } }; @@ -399,19 +399,19 @@ EXPECT_TRUE(merger.getOutTensorID() == t3); // Tensor 0: undef input vector. - merger.addExp(Kind::kTensor, t0, -1u); + merger.addExp(TensorExp::Kind::kTensor, t0, -1u); merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef); // Tensor 1: dense input vector. - merger.addExp(Kind::kTensor, t1, -1u); + merger.addExp(TensorExp::Kind::kTensor, t1, -1u); merger.setLevelAndType(t1, l0, 0, DimLevelType::Dense); // Tensor 2: undef input vector. - merger.addExp(Kind::kTensor, t2, -1u); + merger.addExp(TensorExp::Kind::kTensor, t2, -1u); merger.setLevelAndType(t2, l0, 0, DimLevelType::Undef); // Tensor 3: dense output vector. - merger.addExp(Kind::kTensor, t3, -1u); + merger.addExp(TensorExp::Kind::kTensor, t3, -1u); merger.setLevelAndType(t3, l0, 0, DimLevelType::Dense); } }; @@ -435,15 +435,15 @@ merger.setHasSparseOut(true); // Tensor 0: undef input vector. - merger.addExp(Kind::kTensor, t0, -1u); + merger.addExp(TensorExp::Kind::kTensor, t0, -1u); merger.setLevelAndType(t0, l0, 0, DimLevelType::Undef); // Tensor 1: undef input vector. - merger.addExp(Kind::kTensor, t1, -1u); + merger.addExp(TensorExp::Kind::kTensor, t1, -1u); merger.setLevelAndType(t1, l0, 0, DimLevelType::Undef); // Tensor 2: sparse output vector. - merger.addExp(Kind::kTensor, t2, -1u); + merger.addExp(TensorExp::Kind::kTensor, t2, -1u); merger.setLevelAndType(t2, l0, 0, DimLevelType::Compressed); } };