diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -61,6 +61,7 @@ // Copy-assignment would be implicitly deleted (because our fields // are const), so we explicitly delete it for clarity. SparseTensorType &operator=(const SparseTensorType &) = delete; + // So we must explicitly define the copy-ctor to silence -Wdeprecated-copy. SparseTensorType(const SparseTensorType &) = default; /// Constructs a new `SparseTensorType` with the same dimension-shape diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -15,6 +15,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/Value.h" #include "llvm/ADT/BitVector.h" #include @@ -23,11 +24,27 @@ namespace sparse_tensor { /// Tensor expression kind. +/// +/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`. +/// That is, its argument is a `LoopId` identifying the loop-variable +/// in question, and its value will be the current iteration's value +/// of that loop-variable. See the `LoopId` documentation for more details. +// +// TODO: either make this `enum class TensorExpKind`, or else make +// it nested in the `TensorExp` class; to improve namespacing and to +// disambiguate vs other things also called "kinds". +// +// TODO: Modify this definition so that the numeric values already encode +// the `ExpArity` (while extending the notion of "arity" to include not +// just the number of `ExprId` children the node has, but also whether the +// node has a `Value` and/or `Operation*`). Doing this will avoid needing +// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor, +// and should help clean up a few other places as well. enum Kind { // Leaf. kTensor = 0, kInvariant, - kIndex, + kLoopVar, // Unary operations. kAbsF, kAbsC, @@ -87,27 +104,94 @@ kReduce, // semiring reduction op }; +// TODO: These type aliases currently only serve to make the code more +// self-documenting, however because they are not type-checked they can +// do nothing to prevent mixups. We should really change them from mere +// aliases to actual struct definitions, so that we can type-check them. + +/// Tensor identifiers. The valid set of identifiers is defined by the +/// first argument passed to the `Merger` ctor. +using TensorId = unsigned; + +/// Loop identifiers. The valid set of identifiers is defined by the +/// second two arguments to the `Merger` ctor. +/// +/// These identifiers serve as proxies for the `$dim` argument to +/// `linalg::IndexOp`, however the numerical value of a `LoopId` should +/// not necessarily be equated with the numerical value of the corresponding +/// `$dim` argument. The `$dim` arguments are De Bruijn indices: that +/// is, they identify the loop which binds the loop-variable by counting +/// the enclosing loops from innermost to outermost, starting from zero. 
+/// Whereas `LoopId` are considered to be arbitrary names for identifying +/// loops; since the `Merger` does not care about the actual ordering of +/// loops, and leaves it up to the `LoopEmitter` to specify the actual +/// loop ordering (`LoopOrd`). +/// +/// TODO: Despite the above claim that `$dim` and `LoopId` need not be +/// numerically equal, some code in the `Merger` class does equate them +/// (e.g., `buildTensorExp`). So we need to explicate the exact relationship +/// between `$dim`, `LoopId`, and `LoopOrd`; especially with regards to their +/// providence. If `LoopId` really is supposed to be equated with `$dim`, +/// then we should change the name to `LoopIdx` or similar, to capture the +/// fact that its numerical value is not invariant when entering/exiting +/// loops (unlike `TensorId`, `ExprId`, `LatPointId`, and `LatSetId` which +/// are invariant identifiers). +using LoopId = unsigned; + +/// A compressed representation of `std::pair`. +/// The compression scheme is such that this also serves as an index +/// into the bitvector stored in `LatPoint` (since that bitvector is +/// just the implementation for a set of `TensorLoopId` values). +using TensorLoopId = unsigned; + +/// `TensorExp` identifiers. These are allocated by `Merger::addExp`, +/// and serve as unique identifiers for the corresponding `TensorExp` object. +using ExprId = unsigned; + +/// `LatPoint` identifiers. These are allocated by `Merger::addLat`, +/// and serve as unique identifiers for the corresponding `LatPoint` object. +using LatPointId = unsigned; + +/// `LatSet` identifiers. These are allocated by `Merger::addSet` (and +/// by other methods calling that one), and serve as unique identifiers +/// for the corresponding `SmallVector` object. +using LatSetId = unsigned; + +/// A constant serving as the canonically invalid identifier, regardless +/// of the identifier type. +static constexpr unsigned kInvalidId = -1u; + /// Children subexpressions of tensor operations. struct Children { - unsigned e0; - unsigned e1; + ExprId e0; + ExprId e1; }; /// Tensor expression. Represents a MLIR expression in tensor index notation. struct TensorExp { - TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *operation); + // The `x` parameter has different types depending on the value of the + // `k` parameter. The correspondences are: + // * `kTensor` -> `TensorId` + // * `kInvariant` -> `kInvalidId` + // * `kLoopVar` -> `LoopId` + // * else -> `ExprId` + // + // The `y`, `v`, and `op` parameters either must or must not be + // `kInvalidId`/`nullptr`, depending on the value of the `k` parameter; + // however, they have uniform C++ types regardless of the value of `k`. + TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op); /// Tensor expression kind. Kind kind; union { - /// Expressions representing tensors simply have a tensor number. - unsigned tensor; + /// `kTensor` expressions simply have a tensor identifier. + TensorId tensor; - /// Indices hold the index number. - unsigned index; + /// `kLoopVar` expressions simply have a loop identifier. + LoopId loop; - /// Tensor operations hold the indices of their children. + /// All other expressions hold the `ExprId`s of their children. Children children; }; @@ -123,24 +207,29 @@ Operation *op; }; -/// Lattice point. Each lattice point consists of a conjunction of tensor -/// loop indices (encoded in a bitvector) and the index of the corresponding -/// tensor expression. +/// Lattice point. 
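// As a hedged illustration of the compression scheme mentioned above (the
// helper names here are hypothetical; the real accessors are `Merger::tensor`
// and `Merger::loop` below): a `TensorLoopId` packs the pair `(t, i)` as
// `b = i * numTensors + t`, which is what allows `b` to double as a bit
// position in the `LatPoint` bitvector.
//
//   constexpr TensorLoopId makeTensorLoopId(unsigned numTensors, TensorId t,
//                                           LoopId i) {
//     return numTensors * i + t;
//   }
//   constexpr TensorId tensorOf(unsigned numTensors, TensorLoopId b) {
//     return b % numTensors; // cf. `Merger::tensor`
//   }
//   constexpr LoopId loopOf(unsigned numTensors, TensorLoopId b) {
//     return b / numTensors; // cf. `Merger::loop`
//   }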
Each lattice point consists of a formal conjunction +/// of `TensorLoopId`s, together with the identifier of the corresponding +/// tensor expression. The formal conjunction is represented as a set of +/// `TensorLoopId`, where that set is implemented as a `BitVector`. struct LatPoint { - LatPoint(unsigned n, unsigned e, unsigned b); - LatPoint(const BitVector &b, unsigned e); + /// Construct the lattice point from a given set of `TensorLoopId`s. + LatPoint(const BitVector &bits, ExprId e); - /// Conjunction of tensor loop indices as bitvector. This represents - /// all indices involved in the tensor expression + /// Construct a lattice point with `(t,i)` as the only `TensorLoopId`, + /// where `(t,i) < (numTensors,numLoops)`. + LatPoint(unsigned numTensors, unsigned numLoops, TensorId t, LoopId i, + ExprId e); + + /// Conjunction of all `TensorLoopId`s involved in the tensor expression. BitVector bits; - /// Simplified conjunction of tensor loop indices as bitvector. This + /// Simplified conjunction of `TensorLoopId` as bitvector. This /// represents a simplified condition under which this tensor expression /// must execute. Pre-computed during codegen to avoid repeated eval. BitVector simple; - /// Index of the tensor expression. - unsigned exp; + /// Identifier of the tensor expression. + ExprId exp; }; /// A class to handle all iteration lattice operations. This class abstracts @@ -157,186 +246,203 @@ /// /// In addition to natives loops (which are specified by the GenericOp), /// extra filter loops are needed in order to handle affine expressions on - /// sparse dimensions. E.g., (d0, d1, d2) => (d0 + d1, d2), a naive + /// sparse levels. E.g., (d0, d1, d2) => (d0 + d1, d2), a naive /// implementation of the filter loop could be generated as /// - /// for (coord : sparse_dim[0]) - /// if (coord == d0 + d1) { + /// for (const auto c0 : coordinates[0]) { + /// if (c0 == d0 + d1) { /// generated_code; /// } /// } /// /// to filter out coordinates that are not equal to the affine expression. - /// - /// TODO: we want to make the filter loop more efficient in the future, e.g., - /// by avoiding scanning the full stored index sparse (keeping the last - /// position in ordered list) or even apply binary search to find the index. - /// - Merger(unsigned t, unsigned l, unsigned fl); + // + // TODO: we want to make the filter loop more efficient in the future, + // e.g., by avoiding scanning the full list of stored coordinates (keeping + // the last position in ordered list) or even apply binary search to find + // the coordinate. + // + // TODO: would be cleaner to understand/document if the first argument + // gave the number of input tensors, instead of the current number of + // input+output tensors. + Merger(unsigned numInputOutputTensors, unsigned numNativeLoops, + unsigned numFilterLoops); - /// Adds a tensor expression. Returns its index. - unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(), - Operation *op = nullptr); - unsigned addExp(Kind k, unsigned e, Value v, Operation *op = nullptr) { - return addExp(k, e, -1u, v, op); + /// Constructs a new tensor expression, and returns its identifier. + /// The type of the `e0` argument varies according to the value of the + /// `k` argument, as described by the `TensorExp` ctor. 
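  // Hedged usage sketch of the overloads declared below; the operand numbers
  // and `someConstantValue` are made up, and only kinds shown in the `Kind`
  // enum above are used:
  //
  //   const ExprId t0 = merger.addExp(kTensor, /*TensorId=*/0);
  //   const ExprId i1 = merger.addExp(kLoopVar, /*LoopId=*/1);
  //   const ExprId c0 = merger.addExp(kInvariant, /*v=*/someConstantValue);
  //   const ExprId e0 = merger.addExp(kAbsF, t0);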
+ ExprId addExp(Kind k, unsigned e0, ExprId e1 = kInvalidId, Value v = Value(), + Operation *op = nullptr); + ExprId addExp(Kind k, ExprId e, Value v, Operation *op = nullptr) { + return addExp(k, e, kInvalidId, v, op); } - unsigned addExp(Kind k, Value v, Operation *op = nullptr) { - return addExp(k, -1u, -1u, v, op); + ExprId addExp(Kind k, Value v, Operation *op = nullptr) { + return addExp(k, kInvalidId, kInvalidId, v, op); } - /// Adds an iteration lattice point. Returns its index. - unsigned addLat(unsigned t, unsigned i, unsigned e); + /// Constructs a new iteration lattice point, and returns its identifier. + LatPointId addLat(TensorId t, LoopId i, ExprId e); - /// Adds a new, initially empty, set. Returns its index. - unsigned addSet(); + /// Constructs a new (initially empty) set, and returns its identifier. + LatSetId addSet(); /// Computes a single conjunction of two lattice points by taking the "union" - /// of loop indices (effectively constructing a larger "intersection" of those - /// indices) with a newly constructed tensor (sub)expression of given kind. - /// Returns the index of the new lattice point. - unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1, - Operation *op = nullptr); + /// of `LoopId` (effectively constructing a larger "intersection" of those + /// loops) with a newly constructed tensor (sub)expression of given kind. + /// Returns the identifier of the new lattice point. + LatPointId conjLatPoint(Kind kind, LatPointId p0, LatPointId p1, + Operation *op = nullptr); - /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of - /// cartesian product. Returns the index of the new set. - unsigned takeConj(Kind kind, unsigned s0, unsigned s1, + /// Conjunctive merge of two lattice sets: `(s0 /\_op s1)`. + /// Returns the identifier of the new set. + LatSetId takeConj(Kind kind, LatSetId s0, LatSetId s1, Operation *op = nullptr); - /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1). - /// Returns the index of the new set. - unsigned takeDisj(Kind kind, unsigned s0, unsigned s1, + /// Disjunctive merge of two lattice sets: `(s0 /\_op s1, s0, s1)`. + /// Returns the identifier of the new set. + LatSetId takeDisj(Kind kind, LatSetId s0, LatSetId s1, Operation *op = nullptr); - /// Disjunctive merge of two lattice sets L0 and L1 with custom handling of - /// the overlap, left, and right regions. Any region may be left missing in - /// the output. Returns the index of the new set. - unsigned takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig, + /// Disjunctive merge of two lattice sets with custom handling of the + /// overlap, left, and right regions. Any region may be left missing + /// in the output. Returns the identifier of the new set. + LatSetId takeCombi(Kind kind, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, Kind ltrans, Operation *opleft, bool includeRight, Kind rtrans, Operation *opright); /// Maps the unary operator over the lattice set of the operand, i.e. each /// lattice point on an expression E is simply copied over, but with OP E - /// as new expression. Returns the index of the new set. - unsigned mapSet(Kind kind, unsigned s0, Value v = Value(), + /// as new expression. Returns the identifier of the new set. + LatSetId mapSet(Kind kind, LatSetId s, Value v = Value(), Operation *op = nullptr); /// Optimizes the iteration lattice points in the given set. This /// method should be called right before code generation to avoid /// generating redundant loops and conditions. 
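  // A hedged sketch of what the conjunctive merge declared above (`takeConj`)
  // amounts to: one `conjLatPoint` per element of the cartesian product of
  // the two sets. This is illustrative only; the real implementation builds
  // the result set inside the `Merger` itself.
  //
  //   // Copy the sets first, since adding points may reallocate the
  //   // underlying storage and invalidate references into it.
  //   const SmallVector<LatPointId> set0 = merger.set(s0);
  //   const SmallVector<LatPointId> set1 = merger.set(s1);
  //   SmallVector<LatPointId> result;
  //   for (const LatPointId p0 : set0)
  //     for (const LatPointId p1 : set1)
  //       result.push_back(merger.conjLatPoint(kind, p0, p1));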
- unsigned optimizeSet(unsigned s0); + LatSetId optimizeSet(LatSetId s); /// Simplifies the conditions in a conjunction of a given lattice point /// within the given set using just two basic rules: /// (1) multiple dense conditions are reduced to single dense, and /// (2) a *singleton* sparse/dense is reduced to sparse/random access. - BitVector simplifyCond(unsigned s0, unsigned p0); + BitVector simplifyCond(LatSetId s, LatPointId p); - /// Returns true if Li > Lj. - bool latGT(unsigned i, unsigned j) const; + /// Returns true if p0 > p1. + bool latGT(LatPointId p0, LatPointId p1) const; - /// Returns true if Li and Lj only differ in dense. - bool onlyDenseDiff(unsigned i, unsigned j); + /// Returns true if p0 and p1 only differ in dense. + bool onlyDenseDiff(LatPointId p0, LatPointId p1) const; - /// Bit translation (get tensor ID). - unsigned tensor(unsigned b) const { return b % numTensors; } - /// Bit translation (get loop index). - unsigned index(unsigned b) const { return b / numTensors; } + /// Gets the tensor-identifier of the `TensorLoopId`. + TensorId tensor(TensorLoopId b) const { return b % numTensors; } + /// Gets the loop-identifier of the `TensorLoopId`. + LoopId loop(TensorLoopId b) const { return b / numTensors; } - /// Get the number of total loops (native loops + filter loops). - unsigned getNumLoops() const { return numLoops; } - /// Get the number of native loops. - unsigned getNumNativeLoops() const { return numNativeLoops; } - /// Get the number of filter loops. - unsigned getNumFilterLoops() const { return numLoops - numNativeLoops; } - /// Get the starting filter loop index. - unsigned getFilterLoopStartingIdx() const { return getNumNativeLoops(); } + /// Get the total number of tensors (including the output-tensor and + /// synthetic-tensor). The result is given the type `TensorId` since + /// the result is primarily used as an upper bound for `TensorId`s. + TensorId getNumTensors() const { return numTensors; } - /// Returns true if bit corresponds to index of output tensor. - bool isOutTensor(unsigned b, unsigned i) const { - return tensor(b) == outTensor && index(b) == i; + /// Get the total number of loops (native loops + filter loops). + /// The result is given the type `LoopId` since the result will + /// generally be used as a for-loop upper bound. + LoopId getNumLoops() const { return numLoops; } + /// Get the number of native loops. The result is given the type + /// `LoopId` since the result will generally be used as a for-loop + /// upper bound. + LoopId getNumNativeLoops() const { return numNativeLoops; } + /// Get the number of filter loops. The result is given the type + /// `LoopId` since the result will generally be used as a for-loop + /// upper bound. + LoopId getNumFilterLoops() const { return numLoops - numNativeLoops; } + /// Get the identifier of the first filter-loop. + LoopId getStartingFilterLoopId() const { return getNumNativeLoops(); } + + /// Returns true if `b` is the `i`th loop of the output tensor. + bool isOutTensor(TensorLoopId b, LoopId i) const { + assert(i < numLoops); + return b == numTensors * i + outTensor; } - /// Gets tensor ID for the output tensor. - unsigned getOutTensorID() const { return outTensor; } - /// Gets tensor ID for the synthetic tensor (used for all invariant tensor - /// expressions). - unsigned getSynTensorID() const { return syntheticTensor; } + /// Get the output tensor's identifier. 
+ TensorId getOutTensorID() const { return outTensor; } + /// Get the synthetic tensor's identifier (used for all invariant + /// tensor expressions). + TensorId getSynTensorID() const { return syntheticTensor; } - bool isFilterLoop(unsigned ldx) const { - assert(ldx < numLoops); - return ldx >= numNativeLoops; + bool isFilterLoop(LoopId i) const { + assert(i < numLoops); + return i >= numNativeLoops; } /// Returns true if the expression is `(kTensor t)`. - bool expIsTensor(unsigned e, unsigned t) const { + bool expIsTensor(ExprId e, TensorId t) const { return tensorExps[e].kind == kTensor && tensorExps[e].tensor == t; } - /// Returns true if the expression contains the `t` as an operand. - bool expContainsTensor(unsigned e, unsigned t) const; + /// Returns true if the expression contains the tensor as an operand. + bool expContainsTensor(ExprId e, TensorId t) const; /// Returns true if the expression contains a negation on output tensor. /// I.e., `- outTensor` or `exp - outputTensor` /// NOTE: this is an trivial tests in that it does not handle recursive /// negation, i.e., it returns true when the expression is `-(-tensor)`. - bool hasNegateOnOut(unsigned e) const; + bool hasNegateOnOut(ExprId e) const; /// Returns true if given tensor iterates *only* in the given tensor /// expression. For the output tensor, this defines a "simply dynamic" /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for /// sparse vector a. - bool isSingleCondition(unsigned t, unsigned e) const; + bool isSingleCondition(TensorId t, ExprId e) const; - /// Returns true if any set bit corresponds to sparse dimension level type. + /// Returns true if any `TensorLoopId` in the bitvector corresponds + /// to sparse level-type. bool hasAnySparse(const BitVector &bits) const; - /// Gets the dimension level type of the `t`th tensor on `i`th loop. - DimLevelType getDimLevelType(unsigned t, unsigned i) const { + /// Gets the level-type of the `t`th tensor on `i`th loop. + DimLevelType getDimLevelType(TensorId t, LoopId i) const { assert(t < numTensors && i < numLoops); - return dimTypes[t][i]; + return lvlTypes[t][i]; } - - /// Gets the dimension level type of `b`. - DimLevelType getDimLevelType(unsigned b) const { - return getDimLevelType(tensor(b), index(b)); + DimLevelType getDimLevelType(TensorLoopId b) const { + return getDimLevelType(tensor(b), loop(b)); } - std::optional getLoopIdx(unsigned t, unsigned dim) const { - assert(t < numTensors && dim < numLoops); - return dimToLoopIdx[t][dim]; + /// Gets the loop identifier for the `lvl`th level of the `t`th tensor. + std::optional getLoopId(TensorId t, Level lvl) const { + assert(t < numTensors && lvl < lvlToLoop[t].size()); + return lvlToLoop[t][lvl]; } - /// Gets the dimension number of the the `t`th tensor on `i`th loop. - std::optional getDimNum(unsigned t, unsigned i) const { + /// Gets the level number of the the `t`th tensor on `i`th loop. + std::optional getLvl(TensorId t, LoopId i) const { assert(t < numTensors && i < numLoops); - return loopIdxToDim[t][i]; + return loopToLvl[t][i]; } - - /// Gets the dimension number of `b`. - std::optional getDimNum(unsigned b) const { - return getDimNum(tensor(b), index(b)); + std::optional getLvl(TensorLoopId b) const { + return getLvl(tensor(b), loop(b)); } - /// Sets the dimension and dimension level type of the `t`th tensor on `i`th - /// loop. 
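  // A hedged illustration (the loop and level numbers are made up) of the
  // invariants maintained by `setLevelAndType` below, as observed through the
  // getters above: after a call such as
  //
  //   merger.setLevelAndType(t, /*i=*/1, /*lvl=*/0, dlt);
  //
  // the following hold:
  //
  //   merger.getLvl(t, /*i=*/1) == std::optional<Level>(0)
  //   merger.getLoopId(t, /*lvl=*/0) == std::optional<LoopId>(1)
  //   merger.getDimLevelType(t, /*i=*/1) == dlt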
- void setDimAndDimLevelType(unsigned t, unsigned i, unsigned dim, - DimLevelType dlt) { - assert(isValidDLT(dlt)); - dimTypes[t][i] = dlt; - loopIdxToDim[t][i] = dim; - assert(dim < numLoops); - dimToLoopIdx[t][dim] = i; + /// Sets the level number and level-type of the `t`th tensor on + /// `i`th loop. + void setLevelAndType(TensorId t, LoopId i, Level lvl, DimLevelType dlt) { + assert(t < numTensors && i < numLoops && lvl < lvlToLoop[t].size() && + isValidDLT(dlt)); + lvlTypes[t][i] = dlt; + loopToLvl[t][i] = lvl; + lvlToLoop[t][lvl] = i; } - // Iterates the bits of a lattice, for each set bit, converts it into the - // corresponding tensor dimension and invokes the callback. - void foreachTidDimPairInBits( - const BitVector &bits, - function_ref dim, - DimLevelType dlt)> - cb) { - for (unsigned b : bits.set_bits()) - cb(b, tensor(b), getDimNum(b), getDimLevelType(b)); + /// Iterates over a set of `TensorLoopId`s, invoking the callback + /// for each `TensorLoopId` and passing it the corresponding tensor + /// identifier, level, and level-type. + void + foreachTensorLoopId(const BitVector &bits, + function_ref, DimLevelType)> + callback) const { + for (const TensorLoopId b : bits.set_bits()) + callback(b, tensor(b), getLvl(b), getDimLevelType(b)); } // Has sparse output tensor setter. @@ -344,64 +450,70 @@ /// Convenience getters to immediately access the stored nodes. /// Typically it is inadvisible to keep the reference around, as in - /// "TensorExpr &te = merger.exp(e))", since insertions into the merger + /// `TensorExpr &te = merger.exp(e)`, since insertions into the merger /// may cause data movement and invalidate the underlying memory address. - TensorExp &exp(unsigned e) { return tensorExps[e]; } - LatPoint &lat(unsigned l) { return latPoints[l]; } - SmallVector &set(unsigned s) { return latSets[s]; } + TensorExp &exp(ExprId e) { return tensorExps[e]; } + LatPoint &lat(LatPointId p) { return latPoints[p]; } + SmallVector &set(LatSetId s) { return latSets[s]; } #ifndef NDEBUG /// Print methods (for debugging). - void dumpExp(unsigned e) const; - void dumpLat(unsigned p) const; - void dumpSet(unsigned s) const; + void dumpExp(ExprId e) const; + void dumpLat(LatPointId p) const; + void dumpSet(LatSetId s) const; void dumpBits(const BitVector &bits) const; #endif /// Builds the iteration lattices in a bottom-up traversal given the - /// remaining tensor (sub)expression and the next loop index in the - /// iteration graph. Returns index of the root expression. - unsigned buildLattices(unsigned e, unsigned i); + /// remaining tensor (sub)expression and the next loop in the iteration + /// graph. Returns the identifier of the root set. + LatSetId buildLattices(ExprId e, LoopId i); /// Builds a tensor expression from the given Linalg operation. - /// Returns index of the root expression on success. - std::optional buildTensorExpFromLinalg(linalg::GenericOp op); + /// On success, returns the identifier of the root expression. + std::optional buildTensorExpFromLinalg(linalg::GenericOp op); /// Rebuilds SSA format from a tensor expression. - Value buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0, + Value buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1); private: /// Private helpers. 
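  // Hedged sketch of how the public entry points declared just above are
  // typically chained by the sparsifier (cf. `CodegenEnv::initTensorExp`
  // further below; `op` and `loopId` are assumed to be given):
  //
  //   const std::optional<ExprId> root = merger.buildTensorExpFromLinalg(op);
  //   if (root) {
  //     const LatSetId lats = merger.buildLattices(*root, loopId);
  //     const LatSetId opt = merger.optimizeSet(lats);
  //     // ... generate loops/conditions from the optimized set ...
  //   }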
- bool maybeZero(unsigned e) const; - bool isInvariant(unsigned e) const; - Type inferType(unsigned e, Value src); + bool maybeZero(ExprId e) const; + bool isInvariant(ExprId e) const; + Type inferType(ExprId e, Value src) const; /// Traverses the SSA tree (possibly a DAG) to build a tensor expression. - std::optional buildTensorExp(linalg::GenericOp op, Value v); + std::optional buildTensorExp(linalg::GenericOp op, Value v); /// Merger data structures. - const unsigned outTensor; - const unsigned syntheticTensor; + const TensorId outTensor; + const TensorId syntheticTensor; const unsigned numTensors; const unsigned numNativeLoops; const unsigned numLoops; bool hasSparseOut; - // Map that converts pair to the corresponding dimension - // level type. - std::vector> dimTypes; + // TODO: Why do we use `std::vector` for `lvlTypes`, `loopToLvl`, and + // `lvlToLoop`, but use `llvm::SmallVector` for `tensorExps`, `latPoints`, + // and `latSets`? That causes a lot of confusion about whether we need + // to assert against OOB or not (since `std::vector::operator[]` + // doesn't check, but `llvm::SmallVector::operator[]` does). - // Map that converts pair to the corresponding - // dimension. - std::vector>> loopIdxToDim; + // Map that converts pair to the corresponding + // level-type. + std::vector> lvlTypes; - // Map that converts pair to the corresponding loop id. - std::vector>> dimToLoopIdx; + // Map that converts pair to the corresponding + // level. + std::vector>> loopToLvl; + + // Map that converts pair to the corresponding LoopId. + std::vector>> lvlToLoop; llvm::SmallVector tensorExps; llvm::SmallVector latPoints; - llvm::SmallVector> latSets; + llvm::SmallVector> latSets; }; } // namespace sparse_tensor diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -45,7 +45,7 @@ // LogicalResult initTensorExp(); - unsigned getTensorExp() const { return tensorExp; } + ExprId getTensorExp() const { return tensorExp; } linalg::GenericOp op() const { return linalgOp; } const SparsificationOptions &options() const { return sparseOptions; } @@ -65,13 +65,13 @@ // Merger delegates. // - TensorExp &exp(unsigned e) { return latticeMerger.exp(e); } - LatPoint &lat(unsigned l) { return latticeMerger.lat(l); } - SmallVector &set(unsigned s) { return latticeMerger.set(s); } - DimLevelType dlt(unsigned t, unsigned i) const { + TensorExp &exp(ExprId e) { return latticeMerger.exp(e); } + LatPoint &lat(LatPointId l) { return latticeMerger.lat(l); } + SmallVector &set(LatSetId s) { return latticeMerger.set(s); } + DimLevelType dlt(TensorId t, LoopId i) const { return latticeMerger.getDimLevelType(t, i); } - DimLevelType dlt(unsigned b) const { + DimLevelType dlt(TensorLoopId b) const { return latticeMerger.getDimLevelType(b); } @@ -81,7 +81,7 @@ /// Whether the tensor expression is admissible for codegen. /// It also sets the sparseOut if the output tensor is sparse. - bool isAdmissibleTensorExp(unsigned exp); + bool isAdmissibleTensorExp(ExprId e); /// Whether the iteration graph is sorted in admissible topoOrder. /// Sets outerParNest on success with sparse output @@ -91,17 +91,21 @@ // Topological delegate and sort methods. 
// - size_t topSortSize() const { return topSort.size(); } - unsigned topSortAt(unsigned i) const { return topSort.at(i); } - void topSortPushBack(unsigned i) { topSort.push_back(i); } - void topSortClear(unsigned capacity = 0) { + LoopOrd topSortSize() const { return topSort.size(); } + LoopId topSortAt(LoopOrd n) const { return topSort.at(n); } + void topSortPushBack(LoopId i) { topSort.push_back(i); } + void topSortClear(size_t capacity = 0) { topSort.clear(); topSort.reserve(capacity); } - ArrayRef getTopSortSlice(size_t n, size_t m) const; - ArrayRef getLoopCurStack() const; - Value getLoopIdxValue(size_t loopIdx) const; + ArrayRef getTopSortSlice(LoopOrd n, LoopOrd m) const; + ArrayRef getLoopStackUpTo(LoopOrd n) const; + ArrayRef getCurrentLoopStack() const; + /// Returns the induction-variable for the loop identified by the given + /// `LoopId`. This method handles application of the topological sort + /// in order to convert the `LoopId` into the corresponding `LoopOrd`. + Value getLoopVar(LoopId i) const; // // Sparse tensor output and expansion methods. @@ -113,7 +117,8 @@ Value getInsertionChain() const { return insChain; } void updateInsertionChain(Value chain); - bool atExpandLevel(OpOperand *o, unsigned rank, unsigned lv) const; + // FIXME: clarify what this "rank" is really supposed to mean/be. + bool atExpandLevel(OpOperand *o, unsigned rank, LoopOrd n) const; void startExpand(Value values, Value filled, Value added, Value count); bool isExpand() const { return expValues != nullptr; } void updateExpandCount(Value count); @@ -127,8 +132,8 @@ // Reduction methods. // - void startReduc(unsigned exp, Value val); - bool isReduc() const { return redExp != -1u; } + void startReduc(ExprId exp, Value val); + bool isReduc() const { return redExp != kInvalidId; } void updateReduc(Value val); Value getReduc() const { return redVal; } Value endReduc(); @@ -136,8 +141,8 @@ void clearValidLexInsert(); Value getValidLexInsert() const { return redValidLexInsert; } - void startCustomReduc(unsigned exp); - bool isCustomReduc() const { return redCustom != -1u; } + void startCustomReduc(ExprId exp); + bool isCustomReduc() const { return redCustom != kInvalidId; } Value getCustomRedId(); void endCustomReduc(); @@ -154,14 +159,16 @@ // Loop emitter helper class. LoopEmitter loopEmitter; - // Topological sort. - std::vector topSort; + // Topological sort. This serves as a mapping from `LoopOrd` to `LoopId` + // (cf., `getLoopVar` and `topSortAt`). + std::vector topSort; // Sparse tensor as output. Implemented either through direct injective // insertion in lexicographic index order or through access pattern // expansion in the innermost loop nest (`expValues` through `expCount`). OpOperand *sparseOut; - unsigned outerParNest; + // The count of outer non-filter loops, as defined by `isAdmissibleTopoOrder`. + LoopOrd outerParNest; Value insChain; Value expValues; Value expFilled; @@ -172,8 +179,8 @@ // into the merger's expression tree. When the indices of a tensor reduction // expression are exhausted, all inner loops can use a scalarized reduction. Value redVal; - unsigned redExp; - unsigned redCustom; + ExprId redExp; + ExprId redCustom; // Bookkeeping for lex insertion during reductions. Holds the runtime boolean // value of whether any reduction occurred. This is only set during a @@ -181,7 +188,7 @@ Value redValidLexInsert; // The root tensor expression of the kernel. 
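  // A hedged worked example of the `topSort` mapping documented above (the
  // ordering is made up): with `topSort = {2, 0, 1}`, the loop identified by
  // `LoopId` 2 is emitted first, so
  //
  //   topSortAt(/*LoopOrd=*/0) == /*LoopId=*/2
  //   getLoopVar(/*LoopId=*/2) == loopEmitter.getLoopIV(/*LoopOrd=*/0)
  //
  // i.e. `topSort[n]` maps a `LoopOrd` to its `LoopId`, and `getLoopVar`
  // applies the inverse mapping.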
- unsigned tensorExp; + ExprId tensorExp; }; } // namespace sparse_tensor diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -38,12 +38,12 @@ : linalgOp(linop), sparseOptions(opts), latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(), - expFilled(), expAdded(), expCount(), redVal(), redExp(-1u), - redCustom(-1u), redValidLexInsert() {} + expFilled(), expAdded(), expCount(), redVal(), redExp(kInvalidId), + redCustom(kInvalidId), redValidLexInsert() {} LogicalResult CodegenEnv::initTensorExp() { // Builds the tensor expression for the Linalg operation in SSA form. - std::optional optExp = latticeMerger.buildTensorExpFromLinalg(op()); + std::optional optExp = latticeMerger.buildTensorExpFromLinalg(op()); if (!optExp || !isAdmissibleTensorExp(*optExp)) return failure(); @@ -101,7 +101,7 @@ // Code generation environment verify functions. //===----------------------------------------------------------------------===// -bool CodegenEnv::isAdmissibleTensorExp(unsigned exp) { +bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) { // We reject any expression that makes a reduction from `-outTensor`, as those // expressions create a dependency between the current iteration (i) and the // previous iteration (i-1). It would require iterating over the whole @@ -115,7 +115,10 @@ } OpOperand *lhs = linalgOp.getDpsInitOperand(0); - unsigned tensor = lhs->getOperandNumber(); + // That the operand number is a valid `TensorId` will be verified + // by the call to `isSingleCondition` below; though we may want to add + // assertions to check it here, in order to give better error messages. + const TensorId tensor = lhs->getOperandNumber(); // An non-annotated output tensor is assumed dense, and becomes a random // access n-dim memref. Admissible since insertions cannot occur. if (getSparseTensorType(lhs->get()).isAllDense()) @@ -140,13 +143,14 @@ OpOperand *lhs = linalgOp.getDpsInitOperand(0); // Accept "truly dynamic" if the output tensor materializes uninitialized // into the computation and insertions occur in lexicographic index order. - unsigned nest = 0; - auto iteratorTypes = linalgOp.getIteratorTypesArray(); - for (unsigned i = 0, e = latticeMerger.getNumLoops(); i < e; i++) { - if (!latticeMerger.isFilterLoop(topSortAt(i))) { + LoopOrd nest = 0; + const auto iteratorTypes = linalgOp.getIteratorTypesArray(); + assert(topSortSize() == latticeMerger.getNumLoops()); + for (const LoopId i : topSort) { + if (!latticeMerger.isFilterLoop(i)) { // We only count non-filter loops as filter loops should be considered - // as a special type of parallel loops. - if (linalg::isReductionIterator(iteratorTypes[topSortAt(i)])) + // a special type of parallel loops. 
+ if (linalg::isReductionIterator(iteratorTypes[i])) break; // terminate at first reduction nest++; } @@ -165,19 +169,26 @@ // Code generation environment topological sort methods //===----------------------------------------------------------------------===// -ArrayRef CodegenEnv::getTopSortSlice(size_t n, size_t m) const { - return ArrayRef(topSort).slice(n, m); +ArrayRef CodegenEnv::getTopSortSlice(LoopOrd n, LoopOrd m) const { + return ArrayRef(topSort).slice(n, m); } -ArrayRef CodegenEnv::getLoopCurStack() const { - return getTopSortSlice(0, loopEmitter.getCurrentDepth()); +ArrayRef CodegenEnv::getLoopStackUpTo(LoopOrd n) const { + return ArrayRef(topSort).take_front(n); } -Value CodegenEnv::getLoopIdxValue(size_t loopIdx) const { - for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++) - if (topSort[lv] == loopIdx) - return loopEmitter.getLoopIV(lv); - llvm_unreachable("invalid loop index"); +ArrayRef CodegenEnv::getCurrentLoopStack() const { + return getLoopStackUpTo(loopEmitter.getCurrentDepth()); +} + +Value CodegenEnv::getLoopVar(LoopId i) const { + // TODO: this class should store the inverse of `topSort` so that + // it can do this conversion directly, instead of searching through + // `topSort` every time. (Or else, `LoopEmitter` should handle this.) + for (LoopOrd n = 0, numLoops = topSortSize(); n < numLoops; n++) + if (topSort[n] == i) + return loopEmitter.getLoopIV(n); + llvm_unreachable("invalid loop identifier"); } //===----------------------------------------------------------------------===// @@ -189,8 +200,10 @@ insChain = chain; } -bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, unsigned lv) const { - return sparseOut == o && outerParNest == rank - 1 && outerParNest == lv; +// FIXME: clarify what this "rank" is really supposed to mean/be. 
+bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopOrd n) const { + return sparseOut == o && outerParNest == static_cast(rank - 1) && + outerParNest == n; } void CodegenEnv::startExpand(Value values, Value filled, Value added, @@ -216,21 +229,21 @@ // Code generation environment reduction methods //===----------------------------------------------------------------------===// -void CodegenEnv::startReduc(unsigned exp, Value val) { - assert(redExp == -1u && exp != -1u); +void CodegenEnv::startReduc(ExprId exp, Value val) { + assert(!isReduc() && exp != kInvalidId); redExp = exp; updateReduc(val); } void CodegenEnv::updateReduc(Value val) { - assert(redExp != -1u); + assert(isReduc()); redVal = exp(redExp).val = val; } Value CodegenEnv::endReduc() { Value val = redVal; updateReduc(Value()); - redExp = -1u; + redExp = kInvalidId; return val; } @@ -244,17 +257,17 @@ redValidLexInsert = Value(); } -void CodegenEnv::startCustomReduc(unsigned exp) { - assert(redCustom == -1u && exp != -1u); +void CodegenEnv::startCustomReduc(ExprId exp) { + assert(!isCustomReduc() && exp != kInvalidId); redCustom = exp; } Value CodegenEnv::getCustomRedId() { - assert(redCustom != -1u); + assert(isCustomReduc()); return dyn_cast(exp(redCustom).op).getIdentity(); } void CodegenEnv::endCustomReduc() { - assert(redCustom != -1u); - redCustom = -1u; + assert(isCustomReduc()); + redCustom = kInvalidId; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h @@ -12,11 +12,36 @@ #include #include "mlir/Dialect/SparseTensor/IR/Enums.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/Utils/Merger.h" #include "mlir/IR/PatternMatch.h" namespace mlir { namespace sparse_tensor { +//===----------------------------------------------------------------------===// +/// The position of a loop in the loop-stack, or the position of a +/// `LoopId` in a topologically-sorted list of `LoopId`s. +/// +/// Although this type may have the same cardinality as `LoopId`, it must +/// not be confused with that type. The `LoopId` type is used by the `Merger` +/// as a unique identifier for loop-variables, regardless of the ordering +/// of those loops. Whereas the `LoopOrd` type is used by the `LoopEmitter` +/// (and `CodegenEnv`) to refer to the actual order in which loops are +/// generated. +/// +/// TODO: further explicate the correspondences between these various +/// types. In particular, since the `$dim` argument to `linalg::IndexOp` +/// is a De Bruijn index, it seems like that should correspond to `LoopOrd`, +/// and yet the `Merger` has that correspond with `LoopId` instead. +/// In addition `LoopEmitter::genAffine` has `AffineDimExpr::position` +/// correspond to `LoopId`, however it is unclear what the providence +/// of those `AffineDimExpr` is. +// +// TODO: use a struct/class rather than a typedef, so that we can actually +// typecheck this to avoid mixups in the code. +using LoopOrd = size_t; + //===----------------------------------------------------------------------===// // SparseTensorLoopEmiter class, manages sparse tensors and helps to // generate loop structure to (co)-iterate sparse tensors. 
@@ -33,13 +58,13 @@ // // One can use // -// SparseTensorLoopEmiter loopEmiter({T1, T1}); +// LoopEmiter loopEmiter({T1, T1}); // loopEmiter.initializeLoopEmit(); -// loopEmiter.enterLoopOverTensorAtDim(T1, 0); -// loopEmiter.enterLoopOverTensorAtDim(T2, 0); -// loopEmiter.enterLoopOverTensorAtDim(T1, 1); +// loopEmiter.enterLoopOverTensorAtLvl(T1, 0); +// loopEmiter.enterLoopOverTensorAtLvl(T2, 0); +// loopEmiter.enterLoopOverTensorAtLvl(T1, 1); // loopEmiter.exitCurrentLoop(); -// loopEmiter.enterLoopOverTensorAtDim(T2, 1); +// loopEmiter.enterLoopOverTensorAtLvl(T2, 1); // loopEmiter.exitCurrentLoop(); // exit k // loopEmiter.exitCurrentLoop(); // exit j // loopEmiter.exitCurrentLoop(); // exit i @@ -54,30 +79,31 @@ LoopEmitter() = default; - /// Takes an array of tensors inputs, on which the generated loops will - /// iterate on. The index of the tensor in the array is also the tensor id - /// (tid) used in related functions. If isSparseOut is set, loop emitter - /// assume that the sparse output tensor is empty, and will always generate - /// loops on it based on the dim sizes. An optional array could be provided - /// (by sparsification) to indicate the loop id sequence that will be - /// generated. It is used to establish the mapping between affineDimExpr to - /// the corresponding loop index in the loop stack that are maintained by the - /// loop emitter. + /// Takes an array of input tensors, which the generated loops will + /// iterate over. Each tensor is given a `TensorId` (numerically equal + /// to the position of that tensor `Value` in the array). Setting + /// `isSparseOut` indicates that the sparse output tensor is empty, + /// so the loop emitter will generate loops over it according to the + /// level-sizes. The `topSort` array specifies the actual order in + /// which loops are generated, thus providing a mapping from `LoopOrd` + /// to `LoopId`. void initialize(ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false, bool isSparseOut = false, - ArrayRef topSort = {}); + ArrayRef topSort = {}); explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false, bool isSparseOut = false, - ArrayRef topSort = {}); + ArrayRef topSort = {}); - /// Starts a loop emitting session by generating all the buffers needed to - /// iterate tensors. + /// Starts a loop emitting session by generating all the buffers needed + /// for iterating over the tensors. void initializeLoopEmit(OpBuilder &builder, Location loc, OutputUpdater updater = nullptr); - /// Generates a list of operations to compute the affine expression. - Value genAffine(OpBuilder &builder, AffineExpr a, Location loc); + /// Generates code to compute an affine expression whose variables are + /// `LoopId`s (i.e., `a.cast().getPosition()` is a valid + /// `LoopId`). + Value genAffine(OpBuilder &builder, Location loc, AffineExpr a); /// Enters a new loop sequence, the loops within the same sequence starts /// from the break points of previous loop instead of starting over from 0. @@ -93,73 +119,77 @@ /// ... /// // loop sequence end. /// } - void enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef tids, - ArrayRef dims); + void enterNewLoopSeq(OpBuilder &builder, Location loc, + ArrayRef tids, ArrayRef lvls); - // exit the current loop sequence, this will reset universal index to 0. + /// Exits the current loop sequence, this will reset universal index to 0. 
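  // A hedged illustration of `genAffine` above, reusing the filter-loop
  // example from Merger.h (the `builder`, `loc`, and `ctx` names are made
  // up): for the affine expression `d0 + d1`, where `d0` and `d1` are
  // interpreted as `LoopId`s,
  //
  //   Value v = loopEmitter.genAffine(
  //       builder, loc, getAffineDimExpr(0, ctx) + getAffineDimExpr(1, ctx));
  //
  // yields roughly `%v = arith.addi %iv_d0, %iv_d1 : index`, where `%iv_d0`
  // and `%iv_d1` are the induction variables of the corresponding loops.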
void exitCurrentLoopSeq() { assert(loopSeqStack.size() == loopStack.size() + 1); loopSeqStack.pop_back(); } - // TODO: Gets rid of `dim` in the argument list? Track the dimension we - // are currently at internally. Then it would be enterNextDimForTensor. - // Still need a way to specify the dim for non annoated dense tensor though, - // as it can be accessed out of order. - /// Emits loop over tensor_tid_dim, it assumes that loops between - /// tensor_tid_[0, dim - 1] have already been generated. + // TODO: Get rid of `lvls` in the argument list? Track the level we + // are currently at internally. Then it would be enterNextLvlForTensor. + // Still need a way to specify the lvl for non-annotated tensors though, + // as those can be accessed out of order. + // + /// Emits loop over tensor_tid_lvl, it assumes that loops between + /// tensor_tid_[0, lvl - 1] have already been generated. /// The function will also perform in-place update on the `reduc` vector to /// return the reduction variable used inside the generated loop. - Operation *enterLoopOverTensorAtDim(OpBuilder &builder, Location loc, - ArrayRef tids, - ArrayRef dims, + Operation *enterLoopOverTensorAtLvl(OpBuilder &builder, Location loc, + ArrayRef tids, + ArrayRef lvls, MutableArrayRef reduc = {}, bool isParallel = false); - Operation *enterFilterLoopOverTensorAtDim(OpBuilder &builder, Location loc, - size_t tid, size_t dim, + Operation *enterFilterLoopOverTensorAtLvl(OpBuilder &builder, Location loc, + TensorId tid, Level lvl, AffineExpr affine, MutableArrayRef reduc = {}); - void genDenseAffineAddressAtCurLevel(OpBuilder &builder, Location loc, - size_t tid, size_t dim, - AffineExpr affine); + void genDenseAffineAddress(OpBuilder &builder, Location loc, TensorId tid, + Level lvl, AffineExpr lvlExpr); /// Emits a co-iteration loop over a set of tensors. - Operation *enterCoIterationOverTensorsAtDims( - OpBuilder &builder, Location loc, ArrayRef tids, - ArrayRef dims, bool needsUniv, MutableArrayRef reduc = {}); + Operation *enterCoIterationOverTensorsAtLvls( + OpBuilder &builder, Location loc, ArrayRef tids, + ArrayRef lvls, bool needsUniv, MutableArrayRef reduc = {}); void exitCurrentLoop(RewriterBase &rewriter, Location loc, MutableArrayRef reduc = {}); - /// Returns the array of coordinate for all the loop generated till now. - void getCoordinateArray(SmallVectorImpl &coords) const { + /// Fills the out-parameter with the loop induction variables for all + /// loops in the current loop-stack. The variables are given in the + /// same order as the loop-stack, hence `ivs` should be indexed into + /// by `LoopOrd` (not `LoopId`). + void getLoopIVs(SmallVectorImpl &ivs) const { + ivs.clear(); + ivs.reserve(getCurrentDepth()); for (auto &l : loopStack) - coords.push_back(l.iv); + ivs.push_back(l.iv); } - /// Gets loop induction variable at the given level. - unsigned getCurrentDepth() const { return loopStack.size(); } + /// Gets the current depth of the loop-stack. The result is given + /// the type `LoopOrd` for the same reason as one-past-the-end iterators. + LoopOrd getCurrentDepth() const { return loopStack.size(); } - /// Gets loop induction variable at the given level. - Value getLoopIV(size_t level) const { - if (level < loopStack.size()) - return loopStack[level].iv; - return nullptr; + /// Gets loop induction variable for the given `LoopOrd`. + Value getLoopIV(LoopOrd n) const { + return n < getCurrentDepth() ? loopStack[n].iv : Value(); } /// /// Getters. 
/// - const std::vector> &getPidxs() const { return pidxs; }; - const std::vector> &getCoord() const { return coord; }; + const std::vector> &getPosits() const { return posits; }; + const std::vector> &getCoords() const { return coords; }; const std::vector> &getHighs() const { return highs; }; - const std::vector> &getPosBuffer() const { - return posBuffer; + const std::vector> &getPositionBuffers() const { + return positionsBuffers; }; - const std::vector> &getCrdBuffer() const { - return crdBuffer; + const std::vector> &getCoordinateBuffers() const { + return coordinatesBuffers; }; const std::vector &getValBuffer() const { return valBuffer; }; @@ -168,57 +198,67 @@ } private: - struct LoopLevelInfo { - LoopLevelInfo(ArrayRef tids, ArrayRef dims, Operation *loop, - Block *userBlock, Value iv, StringAttr loopTag) - : tids(tids), dims(dims), loop(loop), userCodeBlock(userBlock), iv(iv) { + struct LoopInfo { + LoopInfo(ArrayRef tids, ArrayRef lvls, Operation *loop, + Block *userBlock, Value iv, StringAttr loopTag) + : tids(tids), lvls(lvls), loop(loop), userCodeBlock(userBlock), iv(iv) { // Attached a special tag to loop emitter generated loop. if (loopTag) loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag); } - // TODO: maybe use a vector for tid and dim? + // TODO: maybe use a vector for tid and lvl? + // (Better yet, compress them together a la `TensorLoopId`.) // The set of tensors that the loop is operating on - const llvm::SmallVector tids; - // The corresponding dims for the tensors - const llvm::SmallVector dims; + const llvm::SmallVector tids; + // The corresponding levels for the tensors + const llvm::SmallVector lvls; const Operation *loop; // the loop operation Block *const userCodeBlock; // the block holding users' generated code. const Value iv; // the induction variable for the loop }; - /// Linearizes address for dense dimension (i.e., p = (i * d0) + j). - Value genAddress(OpBuilder &builder, Location loc, size_t tid, size_t dim, + /// Linearizes address for dense level (i.e., p = (i * d0) + j). + Value genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl, Value iv); /// Generates the segment high for a non-unique level (to fast forward - /// duplicated coordinates). - Value genSegmentHigh(OpBuilder &builder, Location loc, size_t tid, size_t lvl, - Value pos, Value pHi); + /// duplicated coordinates). That is, it generates the code: + /// + /// crd = coordinates_tid_lvl[pos] + /// while (pos < pHi && coordinates_tid_lvl[pos] == crd) + /// pos++; + /// ; + Value genSegmentHigh(OpBuilder &builder, Location loc, TensorId tid, + Level lvl, Value pos, Value pHi); /// Generates instructions to compute the coordinate of tensors[tid][lvl] /// under the current loop context. 
The final argument is the /// collapsed-output level, whereas this function handles converting /// that to the uncollapsed-input level - Value genSparseCrd(OpBuilder &builder, Location loc, size_t tid, - size_t dstLvl); + Value genSparseCrd(OpBuilder &builder, Location loc, TensorId tid, + Level dstLvl); - bool isOutputTensor(size_t tid) { - return hasOutput && tid == tensors.size() - 1; + TensorId getNumTensors() const { return tensors.size(); } + + bool isOutputTensor(TensorId tid) const { + return hasOutput && tid == static_cast(getNumTensors() - 1); } - bool isSparseOutput(size_t tid) { return isOutputTensor(tid) && isSparseOut; } + bool isSparseOutput(TensorId tid) const { + return isOutputTensor(tid) && isSparseOut; + } - /// Setups [lo, hi] for iterating tensor[dim], it assumes that tensor[0 - /// ...dims-1] has already been setup. - void prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc, size_t tid, - size_t dim); + /// Prepares loop for iterating over `tensor[lvl]`, under the assumption + /// that `tensor[0...lvl-1]` loops have already been set up. + void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc, + TensorId tid, Level lvl); /// Emits extra locals, since the locals might not be in simplified lattices - /// point used to generate the loops, but are still required to generates + /// point used to generate the loops, but are still required to generate /// expressions. - void emitExtraLocalsForTensorsAtDenseDims(OpBuilder &builder, Location loc, - ArrayRef tids, - ArrayRef dims); + void emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder, Location loc, + ArrayRef tids, + ArrayRef lvls); /// Exits a for loop, returns the reduction results, e.g., /// For sequential for loops: @@ -251,6 +291,38 @@ void exitCoIterationLoop(OpBuilder &builder, Location loc, MutableArrayRef reduc); + // + // View-based-reshape methods. + // + + /// Get the collapse reassociation for `tensors[tid][dstLvl]`. + /// For unreshaped operands, the reassociation is simply an identity + /// transformation. + /// + /// NOTE: the result uses `Level` rather than the `int64_t` of + /// `ReassociationIndices`, since the former gives clarity to what + /// the values actually mean. + /// + /// TODO: why not do this computation when we first store the reassoc, + /// instead of doing it every time we look it up? + SmallVector getCollapseReassociation(TensorId tid, Level dstLvl) { + assert(tid < getNumTensors() && "Invalid TensorId"); + assert(collapseReassoc.size() == getNumTensors()); + if (const auto reassoc = collapseReassoc[tid]) { + // TODO: store the dstLvlRank in the LoopEmitter so that we can + // check `dstLvl < dstLvlRank` at the top; and only here need to + // assert that `reassoc.size() == dstLvlRank`. + assert(dstLvl < reassoc.size() && "Level is out-of-bounds"); + const auto srcLvls = reassoc[dstLvl].cast(); + return llvm::to_vector<2>( + llvm::map_range(srcLvls, [&](Attribute srcLvl) -> Level { + // TODO: replace this with the converter for `LevelAttr`. + return srcLvl.cast().getValue().getZExtValue(); + })); + } + return {dstLvl}; + } + /// A optional string attribute that should be attached to the loop /// generated by loop emitter, it might help following passes to identify /// loops that operates on sparse tensors more easily. @@ -259,64 +331,69 @@ /// tensor. bool hasOutput; bool isSparseOut; + + // + // Fields which have `numTensor` many entries. + // + // TODO: switch to an AOS style to avoid any possible mismatches. + // + /// Input and (optional) output tensors. 
std::vector tensors; - /// The dim type array for each tensor. - std::vector> dimTypes; - /// Sparse iteration information (by tensor and dim). These arrays - /// are updated to remain current within the current loop. - // TODO: we may want to rename "pidx(s)" to `posCursor(s)` or similar. - std::vector> pidxs; + /// Level-types for each `(TensorId, Level)` pair. + std::vector> lvlTypes; + // Sparse iteration information for each `(TensorId, Level)` pair. + // These arrays are updated to remain current within the current loop. + // TODO: Clarify which of these are indexed by dstLvl vs srcLvl. + // + /// The collection of positions for a given element (one such collection + /// for each tensor). This is the position analogue of the "coords" + /// naming convention. + /// + /// FIXME: [CLARIFY_POSITS_LVL] It's unclear which levels are used + /// to index the `posits` array. On the one hand `genSparseCrd` + /// uses dstLvl; on the other hand `enterLoopOverTensorAtLvl`, + /// `prepareLoopOverTensorAtLvl`, and `enterCoIterationOverTensorsAtLvls` + /// uses srcLvl. So which is it? + std::vector> posits; + /// The collection of coordinates for a given element (one such + /// collection for each tensor). + std::vector> coords; // The segment upper bound for non-uniques level after de-duplication. std::vector> segHi; - std::vector> coord; std::vector> highs; std::vector> lvlSizes; - std::vector> posBuffer; // to_positions - std::vector> crdBuffer; // to_coordinates - std::vector valBuffer; // to_value + std::vector> positionsBuffers; // to_positions + std::vector> coordinatesBuffers; // to_coordinates + std::vector valBuffer; // to_value /// Whether the sparse input is a slice. std::vector isSparseSlices; + /// Collapse Reassociations related to a specific tensor + // TODO: support expand. + std::vector collapseReassoc; + + /// TODO: not yet used, it should track the current level for each tensor + /// to help eliminate `lvls` paramters from above APIs. + /// std::vector curLvl; + + // + // Fields which have at most `numLoops` many entries. + // + /// Loop Stack, stores the information of all the nested loops that are /// alive. - std::vector loopStack; + std::vector loopStack; - /// Loop Sequence Stack, stores the unversial index for the current loop + /// Loop Sequence Stack, stores the universal index for the current loop /// sequence. std::vector loopSeqStack; - /// Maps AffineDimExpr to the index of the loop in loopStack. + /// Maps `LoopId` (used by `AffineDimExpr`) to `LoopOrd` (in the `loopStack`). /// TODO: We should probably use a callback function here to make it more /// general. - std::vector sparsiferLoopLvlMap; - - // - // View based reshape related-fields and methods - // - - /// Collapse Reassociations related to a specific tensor - // TODO: support expand. - std::vector collapseReassoc; - - /// Get the collapse reassociation for tensors[tid] on l. For unreshaped - /// operands, the reassociation is simply an identity transformation. - SmallVector getCollapseReassociation(unsigned tid, unsigned l) { - // Returns for SmallVector just like `ReassociaionIndices` - if (auto reass = collapseReassoc[tid]) { - auto attr = reass[l]; - return llvm::to_vector<2>( - llvm::map_range(attr.cast(), [&](Attribute indexAttr) { - return indexAttr.cast().getInt(); - })); - } - return {l}; - } - - /// TODO: not yet used, it should track the current level for each tensor - /// to help eliminate `dim` paramters from above APIs. 
- /// std::vector curLv; + std::vector loopIdToOrd; }; } // namespace sparse_tensor diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" using namespace mlir; @@ -45,51 +46,50 @@ // TODO: Support dynamic sized slice. static Value getSliceOffset(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, unsigned lvl) { + SparseTensorEncodingAttr enc, Level lvl) { return constantIndex(builder, loc, *enc.getStaticLvlSliceOffset(lvl)); } static Value getSliceSize(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, unsigned lvl) { + SparseTensorEncodingAttr enc, Level lvl) { return constantIndex(builder, loc, *enc.getStaticLvlSliceSize(lvl)); } static Value getSliceStride(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, unsigned lvl) { + SparseTensorEncodingAttr enc, Level lvl) { return constantIndex(builder, loc, *enc.getStaticLvlSliceStride(lvl)); } // Converts a coordinate relative to the slice to the coordinate relative // to the underlying tensor. -static Value toSliceCoord(OpBuilder &builder, Location loc, Value v, - SparseTensorEncodingAttr enc, unsigned lvl) { - +static Value toSliceCrd(OpBuilder &builder, Location loc, Value crd, + SparseTensorEncodingAttr enc, Level lvl) { Value stride = getSliceStride(builder, loc, enc, lvl); Value offset = getSliceOffset(builder, loc, enc, lvl); - // iv = iv * stride + offset - v = builder.create(loc, v, stride); - v = builder.create(loc, v, offset); - return v; + // tensorCrd = sliceCrd * stride + offset + crd = builder.create(loc, crd, stride); + crd = builder.create(loc, crd, offset); + return crd; } // Converts a coordinate relative to the underlying tensor to the coordinate // relative to the slice, returns a extra reminder value static std::pair fromSliceCrd(OpBuilder &builder, Location loc, - Value v, + Value crd, SparseTensorEncodingAttr enc, - unsigned lvl) { + Level lvl) { Value stride = getSliceStride(builder, loc, enc, lvl); Value offset = getSliceOffset(builder, loc, enc, lvl); - // iv = (iv - offset) / stride - v = builder.create(loc, v, offset); - Value rem = builder.create(loc, v, stride); - v = builder.create(loc, v, stride); - return std::make_pair(v, rem); + // sliceCrd = (tensorCrd - offset) / stride + crd = builder.create(loc, crd, offset); + Value rem = builder.create(loc, crd, stride); + crd = builder.create(loc, crd, stride); + return std::make_pair(crd, rem); } static std::pair genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd, - SparseTensorEncodingAttr enc, unsigned lvl) { + SparseTensorEncodingAttr enc, Level lvl) { std::pair trans = fromSliceCrd(builder, loc, crd, enc, lvl); // First, crd >= offset (TODO: seems unsigned >= 0 won't be folded, skip // the check if the offset is zero). @@ -115,75 +115,75 @@ // Sparse tensor loop emitter class implementations //===----------------------------------------------------------------------===// -Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid, - size_t dim, Value iv) { - Value p = dim == 0 ? 
constantIndex(builder, loc, 0) : pidxs[tid][dim - 1]; - Value mul = builder.create(loc, highs[tid][dim], p); +Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid, + Level lvl, Value crd) { + Value pos = lvl == 0 ? constantIndex(builder, loc, 0) : posits[tid][lvl - 1]; + Value mul = builder.create(loc, highs[tid][lvl], pos); if (isSparseSlices[tid]) { auto enc = getSparseTensorEncoding(tensors[tid].getType()); - iv = toSliceCoord(builder, loc, iv, enc, dim); + crd = toSliceCrd(builder, loc, crd, enc, lvl); } - Value add = builder.create(loc, mul, iv); + Value add = builder.create(loc, mul, crd); return add; } -Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc, size_t tid, - size_t lvl, Value pos, Value pHi) { - Value prevCrd = genIndexLoad(builder, loc, crdBuffer[tid][lvl], pos); - // De-duplicates repeated elements. - // - // while (pos < pHi && coord[pos] == prev_coord) - // pos++; - // return pos; +Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc, + TensorId tid, Level lvl, Value pLo, + Value pHi) { + const auto coordinates = coordinatesBuffers[tid][lvl]; + const auto sameCrd = genIndexLoad(builder, loc, coordinates, pLo); auto whileOp = builder.create( - loc, builder.getIndexType(), pos, + loc, builder.getIndexType(), pLo, /*beforeBuilder=*/ - [this, tid, lvl, pHi, prevCrd](OpBuilder &builder, Location loc, - ValueRange ivs) { + [this, tid, lvl, pHi, coordinates, + sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) { + const auto pos = ivs[0]; Value inBound = builder.create( - loc, arith::CmpIPredicate::ult, ivs[0], pHi); - auto ifOp = + loc, arith::CmpIPredicate::ult, pos, pHi); + auto ifInBound = builder.create(loc, builder.getI1Type(), inBound, true); { OpBuilder::InsertionGuard guard(builder); // Load the next coordinates only when inbound (to avoid OOB // acccesses). - builder.setInsertionPointToStart(ifOp.thenBlock()); - Value nxCrd = genIndexLoad(builder, loc, crdBuffer[tid][lvl], ivs[0]); - Value cont = builder.create( - loc, arith::CmpIPredicate::eq, nxCrd, prevCrd); - builder.create(loc, cont); + builder.setInsertionPointToStart(ifInBound.thenBlock()); + Value crd = genIndexLoad(builder, loc, coordinates, pos); + Value isSameCrd = builder.create( + loc, arith::CmpIPredicate::eq, crd, sameCrd); + builder.create(loc, isSameCrd); // Else, the position is out of bound, yield false to terminate the // loop. - builder.setInsertionPointToStart(ifOp.elseBlock()); + builder.setInsertionPointToStart(ifInBound.elseBlock()); builder.create(loc, constantI1(builder, loc, false)); } - builder.create(loc, ifOp.getResults()[0], ivs); + builder.create(loc, ifInBound.getResults()[0], ivs); }, /*afterBuilder=*/ [](OpBuilder &builder, Location loc, ValueRange ivs) { // pos ++ - Value nxPos = builder.create( + Value nextPos = builder.create( loc, ivs[0], constantIndex(builder, loc, 1)); - builder.create(loc, nxPos); + builder.create(loc, nextPos); }); // Return the segment high. 
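The `scf.while` emitted by `genSegmentHigh` above amounts to the following forward scan, shown as a minimal standalone C++ sketch (the name `segmentHigh` is illustrative; in the generated IR the array accesses are memref loads and the cursor is the while-loop's iteration argument): starting at `pLo`, skip every position that repeats the coordinate at `pLo`, and return the first position past the segment (or `pHi`).

#include <cassert>
#include <cstdint>
#include <vector>

using Index = uint64_t;

Index segmentHigh(const std::vector<Index> &coordinates, Index pLo, Index pHi) {
  const Index sameCrd = coordinates[pLo];
  Index pos = pLo;
  // De-duplicate repeated coordinates of a non-unique level.
  while (pos < pHi && coordinates[pos] == sameCrd)
    ++pos;
  return pos;
}

int main() {
  // Non-unique level: coordinate 3 is stored three times starting at position 1.
  std::vector<Index> crd = {1, 3, 3, 3, 7};
  assert(segmentHigh(crd, 1, crd.size()) == 4);
  assert(segmentHigh(crd, 0, crd.size()) == 1);
  return 0;
}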
return whileOp.getResult(0); } -Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, size_t tid, - size_t dstLvl) { +Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid, + Level dstLvl) { Value crd = constantIndex(builder, loc, 0); const auto reassoc = getCollapseReassociation(tid, dstLvl); - for (unsigned i = 0; i < reassoc.size(); i++) { - const auto srcLvl = reassoc[i]; + const unsigned reassocSize = reassoc.size(); + for (unsigned i = 0; i < reassocSize; i++) { + const Level srcLvl = reassoc[i]; // A load on the coordinates array yields the coordinate. - const Value mem = crdBuffer[tid][srcLvl]; - const Value pos = pidxs[tid][dstLvl]; + const Value mem = coordinatesBuffers[tid][srcLvl]; + /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. + const Value pos = posits[tid][dstLvl]; const Value off = genIndexLoad(builder, loc, mem, pos); // Linearized the coordinates within the same collapse reassociation. crd = builder.create(loc, crd, off); - if (i != reassoc.size() - 1) { + if (i != reassocSize - 1) { crd = builder.create(loc, crd, this->lvlSizes[tid][reassoc[i + 1]]); } @@ -192,33 +192,39 @@ } LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput, - bool isSparseOut, ArrayRef topSort) { + bool isSparseOut, ArrayRef topSort) { initialize(tensors, loopTag, hasOutput, isSparseOut, topSort); } void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput, - bool isSparseOut, ArrayRef topSort) { - // First initializes fields. + bool isSparseOut, ArrayRef topSort) { + // First initialize the top-level type of the fields. this->loopTag = loopTag; this->hasOutput = hasOutput; this->isSparseOut = isSparseOut; + + const TensorId numTensors = ts.size(); this->tensors.assign(ts.begin(), ts.end()); - this->isSparseSlices.assign(tensors.size(), false); - this->dimTypes.assign(tensors.size(), std::vector()); - this->pidxs.assign(tensors.size(), std::vector()); - this->segHi.assign(tensors.size(), std::vector()); - this->coord.assign(tensors.size(), std::vector()); - this->highs.assign(tensors.size(), std::vector()); - this->lvlSizes.assign(tensors.size(), std::vector()); - this->posBuffer.assign(tensors.size(), std::vector()); - this->crdBuffer.assign(tensors.size(), std::vector()); - this->valBuffer.assign(tensors.size(), nullptr); - this->loopStack.reserve(topSort.size()); - this->sparsiferLoopLvlMap.assign(topSort.size(), 0); - this->collapseReassoc.assign(tensors.size(), nullptr); + this->isSparseSlices.assign(numTensors, false); + this->lvlTypes.assign(numTensors, std::vector()); + this->posits.assign(numTensors, std::vector()); + this->segHi.assign(numTensors, std::vector()); + this->coords.assign(numTensors, std::vector()); + this->highs.assign(numTensors, std::vector()); + this->lvlSizes.assign(numTensors, std::vector()); + this->positionsBuffers.assign(numTensors, std::vector()); + this->coordinatesBuffers.assign(numTensors, std::vector()); + this->valBuffer.assign(numTensors, nullptr); + this->collapseReassoc.assign(numTensors, nullptr); - for (size_t tid = 0, e = tensors.size(); tid < e; tid++) { - auto t = tensors[tid]; + const LoopOrd numLoops = topSort.size(); + this->loopIdToOrd.assign(numLoops, 0); + this->loopStack.reserve(numLoops); + // TODO: Shouldn't we `this->loopSeqStack.reserve(numLoops);` too? + + // Initialize nested types of `TensorId`-indexed fields. 
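The linearization loop in `genSparseCrd` above folds the coordinates of all source levels in one collapse group into a single coordinate, scaling by the size of each subsequent source level. A minimal standalone sketch of that arithmetic (illustrative names; the real code emits `arith.addi`/`arith.muli` on index values loaded from the coordinates buffers):

#include <cassert>
#include <cstdint>
#include <vector>

using Index = uint64_t;

Index linearizeCrd(const std::vector<Index> &srcCrds,
                   const std::vector<Index> &srcLvlSizes) {
  Index crd = 0;
  for (size_t i = 0, e = srcCrds.size(); i < e; ++i) {
    crd += srcCrds[i];
    if (i + 1 < e)
      crd *= srcLvlSizes[i + 1]; // scale by the next source level's size
  }
  return crd;
}

int main() {
  // Collapsing levels of sizes (4, 5, 6): coordinate (1, 2, 3) maps to
  // (1 * 5 + 2) * 6 + 3 = 45.  The first level's size is never needed.
  assert((linearizeCrd({1, 2, 3}, {4, 5, 6}) == 45));
  return 0;
}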
+ for (TensorId tid = 0; tid < numTensors; tid++) { + const Value t = tensors[tid]; // a scalar or 0-dimension tensors if (isZeroRankedTensorOrScalar(t.getType())) continue; @@ -232,44 +238,49 @@ collapseReassoc[tid] = reshape.getReassociation(); rtp = reshape.getSrcType(); // Overwrites the tensor to the source tensor of reshape operations. - tensors[tid] = t = reshape.getSrc(); + tensors[tid] = reshape.getSrc(); } - auto rank = static_cast(rtp.getRank()); - auto enc = getSparseTensorEncoding(rtp); + const SparseTensorType stt(rtp); + const Level lvlRank = stt.getLvlRank(); // We always treat sparse output tensor as dense so that we always iterate - // it based on dim size. - if (enc && !(isOutputTensor(tid) && isSparseOut)) { + // it based on lvl size. + if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) { + const auto enc = stt.getEncoding(); isSparseSlices[tid] = enc.isSlice(); - for (auto dimTp : enc.getDimLevelType()) - dimTypes[tid].push_back(dimTp); - } else - dimTypes[tid].assign(rank, DimLevelType::Dense); + for (auto dlt : enc.getDimLevelType()) + lvlTypes[tid].push_back(dlt); + } else { + lvlTypes[tid].assign(lvlRank, DimLevelType::Dense); + } // Initialize using empty value. - pidxs[tid].assign(rank, Value()); - segHi[tid].assign(rank, Value()); - coord[tid].assign(rank, Value()); - highs[tid].assign(rank, Value()); - lvlSizes[tid].assign(rank, Value()); - posBuffer[tid].assign(rank, Value()); - crdBuffer[tid].assign(rank, Value()); + posits[tid].assign(lvlRank, Value()); + segHi[tid].assign(lvlRank, Value()); + coords[tid].assign(lvlRank, Value()); + highs[tid].assign(lvlRank, Value()); + lvlSizes[tid].assign(lvlRank, Value()); + positionsBuffers[tid].assign(lvlRank, Value()); + coordinatesBuffers[tid].assign(lvlRank, Value()); } + // Construct the inverse of the `topSort` from the sparsifier. + // This is needed to map `AffineDimExpr`s back to the `LoopOrd` + // used in loop emitter. // FIXME: This map should be maintained outside loop emitter. - for (unsigned i = 0, e = topSort.size(); i < e; i++) { - // This is an inverse map of the topologically sorted loop index from - // sparsifier. This is needed to map the AffineDimExpr back to the loopStack - // index used in loop emitter. - sparsiferLoopLvlMap[topSort[i]] = i; - } + for (LoopOrd n = 0; n < numLoops; n++) + loopIdToOrd[topSort[n]] = n; } void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater) { - // For every tensor, find lower and upper bound on dimensions, set the - // same bounds on loop indices, and obtain dense or sparse buffer(s). - for (size_t t = 0, e = tensors.size(); t < e; t++) { - const auto tensor = tensors[t]; + // For every tensor: + // * get the values buffer. + // * For every level: + // * get the positions and coordinates buffers + // * get/compute the level-size, which is also used as the upper-bound + // on positions. + for (TensorId t = 0, numTensors = getNumTensors(); t < numTensors; t++) { + const Value tensor = tensors[t]; const auto rtp = tensor.getType().dyn_cast(); if (!rtp) // Skips only scalar, zero ranked tensor still need to be bufferized and @@ -285,22 +296,25 @@ // Scan all levels of current tensor. for (Level l = 0; l < lvlRank; l++) { // This should be called only once at beginning. 
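The `loopIdToOrd[topSort[n]] = n` loop above simply inverts the topological order handed over by the sparsifier, so that a `LoopId` (as used by `AffineDimExpr`) can be mapped back to its depth in the emitted loop nest. A standalone sketch of that inversion, with illustrative names:

#include <cassert>
#include <vector>

using LoopId = unsigned;
using LoopOrd = unsigned;

std::vector<LoopOrd> invertTopSort(const std::vector<LoopId> &topSort) {
  std::vector<LoopOrd> loopIdToOrd(topSort.size(), 0);
  // topSort maps loop depth -> LoopId; the inverse answers "at which depth
  // was loop i placed?".
  for (LoopOrd n = 0; n < topSort.size(); ++n)
    loopIdToOrd[topSort[n]] = n;
  return loopIdToOrd;
}

int main() {
  // Loop 2 is outermost, then loop 0, then loop 1.
  const std::vector<LoopId> topSort = {2, 0, 1};
  const auto loopIdToOrd = invertTopSort(topSort);
  assert(loopIdToOrd[2] == 0 && loopIdToOrd[0] == 1 && loopIdToOrd[1] == 2);
  return 0;
}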
- assert(!posBuffer[t][l] && !crdBuffer[t][l] && !highs[t][l]); - const auto dlt = dimTypes[t][l]; + assert(!positionsBuffers[t][l] && !coordinatesBuffers[t][l] && + !highs[t][l]); + const auto dlt = lvlTypes[t][l]; // Handle sparse storage schemes. if (isCompressedDLT(dlt)) { - // Generate sparse primitives to obtains positions and coordinates. - posBuffer[t][l] = genToPositions(builder, loc, tensor, l); - crdBuffer[t][l] = genToCoordinates(builder, loc, tensor, l, cooStart); + // Generate sparse primitives to obtain positions and coordinates. + positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l); + coordinatesBuffers[t][l] = + genToCoordinates(builder, loc, tensor, l, cooStart); } else if (isSingletonDLT(dlt)) { // Singleton level, fetch coordinates. - crdBuffer[t][l] = genToCoordinates(builder, loc, tensor, l, cooStart); + coordinatesBuffers[t][l] = + genToCoordinates(builder, loc, tensor, l, cooStart); } else { // Dense level, nothing to fetch. assert(isDenseDLT(dlt)); } - // Find upper bound in current dimension. + // Find upper bound in current level. // FIXME: `toOrigDim` is deprecated const Dimension d = toOrigDim(enc, l); lvlSizes[t][l] = highs[t][l] = @@ -332,44 +346,49 @@ valBuffer[t] = denseVal; } else { // Annotated sparse tensors. - // We also need the value buffer for annotated all dense `sparse` tensor. + // We also need the value buffer for all-dense annotated "sparse" tensors. valBuffer[t] = genToValues(builder, loc, tensor); } - // NOTE: we can also prepare for 0 dim here in advance, this will hosit + // NOTE: we can also prepare for 0 lvl here in advance, this will hoist // some loop preparation from tensor iteration, but will also (undesirably) - // hosit the code ouside if conditions. + // hoist the code ouside if-conditions. } } void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc, - ArrayRef tids, - ArrayRef dims) { + ArrayRef tids, + ArrayRef lvls) { // TODO: sort assert(loopSeqStack.size() == loopStack.size()); // Universal Index starts from 0. loopSeqStack.emplace_back(constantIndex(builder, loc, 0)); // Prepares for all the tensors used in the current loop sequence. - for (auto [tid, dim] : llvm::zip(tids, dims)) - prepareLoopOverTensorAtDim(builder, loc, tid, dim); + assert(tids.size() == lvls.size()); + for (auto [tid, lvl] : llvm::zip(tids, lvls)) + prepareLoopOverTensorAtLvl(builder, loc, tid, lvl); } -Value LoopEmitter::genAffine(OpBuilder &builder, AffineExpr a, Location loc) { +Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) { switch (a.getKind()) { case AffineExprKind::DimId: { - unsigned idx = a.cast().getPosition(); - return loopStack[sparsiferLoopLvlMap[idx]].iv; + // FIXME: since the one callsite in Sparsification passes in a + // level-expression, the `getPosition` must in fact be a `Dimension`. + // However, elsewhere we have been lead to expect that `loopIdToOrd` + // should be indexed by `LoopId`... 
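For readers unfamiliar with `genAffine` (whose rewrite continues below), the recursion evaluates an affine expression by switching on its kind: a dim resolves, via `loopIdToOrd`, to the induction variable of the corresponding loop; constants are materialized directly; add and mul recurse into both operands. A minimal standalone model of that recursion, using a toy `Expr` type as a stand-in for `AffineExpr` (not MLIR code):

#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

struct Expr {
  enum Kind { Dim, Cst, Add, Mul } kind;
  int64_t value;                  // dim position (Dim) or constant (Cst)
  std::shared_ptr<Expr> lhs, rhs; // children (Add/Mul)
};

int64_t evalAffine(const Expr &a, const std::vector<unsigned> &loopIdToOrd,
                   const std::vector<int64_t> &loopIVs) {
  switch (a.kind) {
  case Expr::Dim:
    return loopIVs[loopIdToOrd[a.value]]; // IV of the loop binding this dim
  case Expr::Cst:
    return a.value;
  case Expr::Add:
    return evalAffine(*a.lhs, loopIdToOrd, loopIVs) +
           evalAffine(*a.rhs, loopIdToOrd, loopIVs);
  case Expr::Mul:
    return evalAffine(*a.lhs, loopIdToOrd, loopIVs) *
           evalAffine(*a.rhs, loopIdToOrd, loopIVs);
  }
  return 0;
}

int main() {
  // d0 * 2 + d1, where loop d0 sits at depth 1 (IV = 3) and d1 at depth 0 (IV = 5).
  auto d0 = std::make_shared<Expr>(Expr{Expr::Dim, 0, nullptr, nullptr});
  auto d1 = std::make_shared<Expr>(Expr{Expr::Dim, 1, nullptr, nullptr});
  auto c2 = std::make_shared<Expr>(Expr{Expr::Cst, 2, nullptr, nullptr});
  auto mul = std::make_shared<Expr>(Expr{Expr::Mul, 0, d0, c2});
  Expr add{Expr::Add, 0, mul, d1};
  assert((evalAffine(add, /*loopIdToOrd=*/{1, 0}, /*loopIVs=*/{5, 3}) == 11));
  return 0;
}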
+ const LoopId i = a.cast().getPosition(); + return loopStack[loopIdToOrd[i]].iv; } case AffineExprKind::Add: { auto binOp = a.cast(); return builder.create( - loc, genAffine(builder, binOp.getLHS(), loc), - genAffine(builder, binOp.getRHS(), loc)); + loc, genAffine(builder, loc, binOp.getLHS()), + genAffine(builder, loc, binOp.getRHS())); } case AffineExprKind::Mul: { auto binOp = a.cast(); return builder.create( - loc, genAffine(builder, binOp.getLHS(), loc), - genAffine(builder, binOp.getRHS(), loc)); + loc, genAffine(builder, loc, binOp.getLHS()), + genAffine(builder, loc, binOp.getRHS())); } case AffineExprKind::Constant: { int64_t c = a.cast().getValue(); @@ -380,39 +399,49 @@ } } -Operation *LoopEmitter::enterLoopOverTensorAtDim( - OpBuilder &builder, Location loc, ArrayRef tids, - ArrayRef dims, MutableArrayRef reduc, bool isParallel) { +Operation *LoopEmitter::enterLoopOverTensorAtLvl( + OpBuilder &builder, Location loc, ArrayRef tids, + ArrayRef lvls, MutableArrayRef reduc, bool isParallel) { // TODO: support multiple return on parallel for? assert(!isParallel || reduc.size() <= 1); bool isSparseInput = false; - size_t tid = tids.front(), dim = dims.front(); - for (auto [t, d] : llvm::zip(tids, dims)) { - assert(dimTypes[t].size() > d); // Must be a valid tid, dim pair - assert(!coord[t][d]); // We cannot re-enter the same level - auto dimType = dimTypes[t][d]; + TensorId tid = tids.front(); + Level dstLvl = lvls.front(); + assert(tids.size() == lvls.size()); + for (auto [t, l] : llvm::zip(tids, lvls)) { + // TODO: this check for validity of the (t,l) pairs should be + // checked/enforced at the callsites, if possible. + assert(t < lvlTypes.size() && l < lvlTypes[t].size()); + assert(!coords[t][l]); // We cannot re-enter the same level + const auto dlt = lvlTypes[t][l]; + const bool isSparse = isCompressedDLT(dlt) || isSingletonDLT(dlt); // Must be a recognizable DLT. - assert(isDenseDLT(dimType) || isCompressedDLT(dimType) || - isSingletonDLT(dimType)); - bool isSparse = isCompressedDLT(dimType) || isSingletonDLT(dimType); + assert(isSparse || isDenseDLT(dlt)); // We can at most have one sparse input, otherwise, a while loop is required // to co-iterate multiple sparse tensors. assert(!isSparseInput || !isSparse); if (isSparse) { tid = t; - dim = d; + dstLvl = l; } isSparseInput = isSparseInput || isSparse; } - auto enc = getSparseTensorEncoding(tensors[tid].getType()); - const auto reassoc = getCollapseReassociation(tid, dim); - dim = reassoc.front(); + const auto enc = getSparseTensorEncoding(tensors[tid].getType()); + const auto reassoc = getCollapseReassociation(tid, dstLvl); + // FIXME(wrengr): REBASE START: I made a bunch of changes below, but they need + // to be rebased against + // + // + // Use the first source-level here to build the loop bound (which is + // also the biggest range). + const Level srcLvl = reassoc.front(); // TODO: support dynamic slices. - Value step = constantIndex(builder, loc, 1); - Value lo = isSparseInput ? pidxs[tid][dim] // current offset - : loopSeqStack.back(); // universal index - Value hi = highs[tid][dim]; + const Value step = constantIndex(builder, loc, 1); + /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. + const Value lo = isSparseInput ? posits[tid][srcLvl] // current position + : loopSeqStack.back(); // universal index + const Value hi = highs[tid][srcLvl]; Operation *loop = nullptr; Value iv; @@ -426,7 +455,7 @@ // In-place update on the reduction variable vector. 
// Note that the init vals is not the actual reduction variables but instead - // used as a `special handle` to (temporarily) represent them. The + // used as a "special handle" to (temporarily) represent them. The // expression on init vals will be moved into scf.reduce and replaced with // the block arguments when exiting the loop (see exitForLoop). This is // needed as we can not build the actual reduction block and get the actual @@ -451,9 +480,10 @@ if (isSparseInput) { assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); // For COO, the position is the same across consecutive levels. + /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. llvm::for_each(reassoc, - [this, tid, iv](Level lvl) { pidxs[tid][lvl] = iv; }); - crd = genSparseCrd(builder, loc, tid, dim); + [this, tid, iv](Level srcLvl) { posits[tid][srcLvl] = iv; }); + crd = genSparseCrd(builder, loc, tid, srcLvl); } else { // Dense tensor, the coordinate is the inducation variable. crd = iv; @@ -466,7 +496,7 @@ for (Value red : reduc) types.push_back(red.getType()); - auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, enc, dim); + auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, enc, srcLvl); bool hasReduc = !types.empty(); scf::IfOp ifOp = builder.create(loc, types, pred, /*else*/ hasReduc); @@ -488,35 +518,33 @@ } assert(crd); - coord[tid][dim] = crd; - // NOTE: we can also prepare for next dim here in advance + coords[tid][srcLvl] = crd; + // NOTE: we can also prepare for next level here in advance // Push the loop into stack - loopStack.emplace_back(ArrayRef(tid), ArrayRef(dim), loop, - builder.getInsertionBlock(), coord[tid][dim], loopTag); + loopStack.emplace_back(ArrayRef(tid), ArrayRef(srcLvl), loop, + builder.getInsertionBlock(), crd, loopTag); // Emit extra locals. - emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims); + emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls); return loop; } -Operation *LoopEmitter::enterFilterLoopOverTensorAtDim( - OpBuilder &builder, Location loc, size_t tid, size_t dim, AffineExpr affine, - MutableArrayRef reduc) { - assert(!affine.isa() && !isDenseDLT(dimTypes[tid][dim])); - assert(dimTypes[tid].size() > dim); +Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl( + OpBuilder &builder, Location loc, TensorId tid, Level lvl, + AffineExpr affine, MutableArrayRef reduc) { + assert(tid < lvlTypes.size() && lvl < lvlTypes[tid].size()); + assert(!affine.isa() && !isDenseDLT(lvlTypes[tid][lvl])); // We can not re-enter the same level. - assert(!coord[tid][dim]); - - Value step = constantIndex(builder, loc, 1); - - Value lo = pidxs[tid][dim]; - Value hi = highs[tid][dim]; + assert(!coords[tid][lvl]); // TODO: We should instead use a whileOp for filter loop to allow early - // break when exceeding (for ordered dimensions). + // break when exceeding (for ordered levels). // TODO: There are many other potiential opportunities that we might apply in - // the future. E.g., we could use binary search to located the position index. - scf::ForOp forOp = builder.create(loc, lo, hi, step, reduc); + // the future. E.g., we could use binary search to locate positions. + const Value step = constantIndex(builder, loc, 1); + const Value pLo = posits[tid][lvl]; + const Value pHi = highs[tid][lvl]; + scf::ForOp forOp = builder.create(loc, pLo, pHi, step, reduc); // In-place update on the reduction variable vector. 
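Conceptually, the filter loop being constructed here iterates the position range [pLo, pHi) of one sparse level and only executes the body for positions whose stored coordinate equals the value of the affine expression (the coordinate-equality guard is generated just below). A minimal standalone sketch of that pattern, with illustrative names; the real code emits an `scf.for` wrapping an `scf.if`:

#include <cstdint>
#include <iostream>
#include <vector>

using Index = uint64_t;

template <typename Body>
void filterLoop(const std::vector<Index> &coordinates, Index pLo, Index pHi,
                Index expectedCrd, Body body) {
  for (Index pos = pLo; pos < pHi; ++pos)
    if (coordinates[pos] == expectedCrd) // the scf.if guard
      body(pos, coordinates[pos]);
}

int main() {
  // Coordinates of one compressed level; visit entries whose coordinate is 4.
  std::vector<Index> crd = {1, 4, 4, 6};
  filterLoop(crd, 0, crd.size(), 4, [](Index pos, Index c) {
    std::cout << "visit position " << pos << " with coordinate " << c << "\n";
  });
  return 0;
}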
assert(forOp.getNumRegionIterArgs() == reduc.size()); @@ -524,18 +552,19 @@ reduc[i] = forOp.getRegionIterArg(i); builder.setInsertionPointToStart(forOp.getBody()); - Value iv = forOp.getInductionVar(); - - pidxs[tid][dim] = iv; - // Generating a load on the coordinates array yields the coordinate. - Value mem = crdBuffer[tid][dim]; - coord[tid][dim] = genIndexLoad(builder, loc, mem, iv); + // The induction variable gives the position. + const Value pos = forOp.getInductionVar(); + posits[tid][lvl] = pos; + // Generating a load on the coordinates array yields the crd. + const Value mem = coordinatesBuffers[tid][lvl]; + const Value crd = genIndexLoad(builder, loc, mem, pos); + coords[tid][lvl] = crd; // Generate an if-condition to filter out coordinates that are not // equal to the result of the affine expression. - Value expected = genAffine(builder, affine, loc); - auto pred = builder.create(loc, arith::CmpIPredicate::eq, - coord[tid][dim], expected); + Value expected = genAffine(builder, loc, affine); + auto pred = builder.create(loc, arith::CmpIPredicate::eq, crd, + expected); SmallVector types; for (Value red : reduc) { types.push_back(red.getType()); @@ -559,35 +588,49 @@ // Set the insert point to matched branch. builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - // NOTE: we can also prepare for next dim here in advance + // NOTE: we can also prepare for next lvl here in advance // Push the loop into stack - loopStack.emplace_back(ArrayRef(tid), ArrayRef(dim), forOp, - builder.getInsertionBlock(), coord[tid][dim], nullptr); + loopStack.emplace_back(ArrayRef(tid), ArrayRef(lvl), forOp, + builder.getInsertionBlock(), crd, nullptr); return forOp; } -void LoopEmitter::genDenseAffineAddressAtCurLevel(OpBuilder &builder, - Location loc, size_t tid, - size_t dim, - AffineExpr affine) { - Value affineV = genAffine(builder, affine, loc); - pidxs[tid][dim] = genAddress(builder, loc, tid, dim, affineV); +void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc, + TensorId tid, Level lvl, + AffineExpr lvlExpr) { + assert(isDenseDLT(lvlTypes[tid][lvl])); + // For dense levels, the level-coordinate also serves as the position. + Value lvlCrd = genAffine(builder, loc, lvlExpr); + posits[tid][lvl] = genAddress(builder, loc, tid, lvl, lvlCrd); } -Operation *LoopEmitter::enterCoIterationOverTensorsAtDims( - OpBuilder &builder, Location loc, ArrayRef tids, - ArrayRef dims, bool needsUniv, MutableArrayRef reduc) { - assert(tids.size() == dims.size()); +Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls( + OpBuilder &builder, Location loc, ArrayRef tids, + ArrayRef lvls, bool needsUniv, MutableArrayRef reduc) { + assert(tids.size() == lvls.size()); SmallVector types; SmallVector operands; // Construct the while-loop with a parameter for each coordinate. - Type indexType = builder.getIndexType(); - for (auto [tid, dim] : llvm::zip(tids, dims)) { - if (isCompressedDLT(dimTypes[tid][dim]) || - isSingletonDLT(dimTypes[tid][dim])) { - assert(pidxs[tid][dim]); + const Type indexType = builder.getIndexType(); + for (auto [tid, lvl] : llvm::zip(tids, lvls)) { + const auto dlt = lvlTypes[tid][lvl]; + if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) { + // FIXME(wrengr): REBASE START: backporting stuff from + // + const auto reassoc = getCollapseReassociation(tid, lvl); + for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { + if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) { + // This is the segment high for each non-unique levels. 
+ types.push_back(indexType); + operands.push_back(constantIndex(builder, loc, 0)); + } + } + // FIXME(wrengr): REBASE END: backporting stuff from + // + const auto pos = posits[tid][reassoc.front()]; + assert(pos); types.push_back(indexType); - operands.push_back(pidxs[tid][dim]); + operands.push_back(pos); } } // The position where user-supplied reduction variable starts. @@ -612,22 +655,35 @@ builder.setInsertionPointToStart(&whileOp.getBefore().front()); Value cond; unsigned o = 0; - for (auto [t, lvl] : llvm::zip(tids, dims)) { + for (auto [t, lvl] : llvm::zip(tids, lvls)) { unsigned tid = t; // Why `t` can not be captured by lambda? - if (isCompressedDLT(dimTypes[tid][lvl]) || - isSingletonDLT(dimTypes[tid][lvl])) { + const auto dlt = lvlTypes[tid][lvl]; + if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) { + // FIXME(wrengr): REBASE START: backporting stuff from + // + const auto reassoc = getCollapseReassociation(tid, lvl); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { + if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) { + // Links the SSA chain for segHi. + segHi[tid][reassoc[i]] = after->getArgument(o++); + } + } Value op1 = before->getArgument(o); - Value op2 = highs[tid][lvl]; + // We used the first level bound as the bound the collapsed set of levels. + Value op2 = highs[tid][reassoc.front()]; + // FIXME(wrengr): REBASE END: backporting stuff from + // Value opc = builder.create(loc, arith::CmpIPredicate::ult, op1, op2); cond = cond ? builder.create(loc, cond, opc) : opc; // Update positions Value pos = after->getArgument(o++); - const auto reassoc = getCollapseReassociation(tid, lvl); - assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); // For COO, the position is the same across consecutive levels. - llvm::for_each(reassoc, - [this, tid, pos](Level lvl) { pidxs[tid][lvl] = pos; }); + /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. + llvm::for_each(reassoc, [this, tid, pos](Level srcLvl) { + posits[tid][srcLvl] = pos; + }); } } builder.create(loc, cond, before->getArguments()); @@ -637,20 +693,20 @@ SmallVector> slicesPreds; unsigned i = 0; - for (auto [tid, dim] : llvm::zip(tids, dims)) { + for (auto [tid, lvl] : llvm::zip(tids, lvls)) { // Prepares for next level. - if (isCompressedDLT(dimTypes[tid][dim]) || - isSingletonDLT(dimTypes[tid][dim])) { - coord[tid][dim] = genSparseCrd(builder, loc, tid, dim); + const auto dlt = lvlTypes[tid][lvl]; + if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) { + coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl); if (isSparseSlices[tid]) { - Value load = - genIndexLoad(builder, loc, crdBuffer[tid][dim], pidxs[tid][dim]); + Value load = genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], + posits[tid][lvl]); auto enc = getSparseTensorEncoding(tensors[tid].getType()); auto [trans, pred] = - genSliceLegitPredicate(builder, loc, load, enc, dim); + genSliceLegitPredicate(builder, loc, load, enc, lvl); slicesPreds.emplace_back(pred, i); // Updates to the relative coordinate to the slice. - coord[tid][dim] = trans; + coords[tid][lvl] = trans; } i++; } @@ -660,14 +716,16 @@ // Skips invalid loop iteration when slice coordinate is inapplicable. SmallVector yields(after->getArguments()); // Generates a list of if statments - // pidx = in_slice ? pidx : pidx + 1 - // TODO: instead of always picking pidx + 1, we should set pidx = high to + // pos = in_slice ? 
pos : pos + 1 + // TODO: instead of always picking pos + 1, we should set pos = high to // break to loop the coordinates is larger than the slice size. + // + // This "idx" is the index into `llvm::zip(tids, lvls)` for (auto [pred, idx] : slicesPreds) { - Value nextPidx = builder.create( + Value nextPos = builder.create( loc, yields[idx], constantIndex(builder, loc, 1)); yields[idx] = - builder.create(loc, pred, yields[idx], nextPidx); + builder.create(loc, pred, yields[idx], nextPos); } Value pred = slicesPreds.front().first; @@ -690,38 +748,83 @@ Value min; // Finds the minimum coordinate if (!needsUniv) { - for (auto [tid, dim] : llvm::zip(tids, dims)) { - if (isCompressedDLT(dimTypes[tid][dim]) || - isSingletonDLT(dimTypes[tid][dim])) { + for (auto [tid, lvl] : llvm::zip(tids, lvls)) { + const auto dlt = lvlTypes[tid][lvl]; + if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) { + const auto crd = coords[tid][lvl]; if (min) { Value cmp = builder.create( - loc, arith::CmpIPredicate::ult, coord[tid][dim], min); - min = builder.create(loc, cmp, coord[tid][dim], min); + loc, arith::CmpIPredicate::ult, crd, min); + min = builder.create(loc, cmp, crd, min); } else { - min = coord[tid][dim]; + min = crd; } } } } else { assert(!min); - // Otherwise, universal index is the minimal pidx. + // Otherwise, universal index is the minimal pos. min = after->getArguments().back(); } // Sets up the loop stack. - loopStack.emplace_back(tids, dims, whileOp, builder.getInsertionBlock(), min, + loopStack.emplace_back(tids, lvls, whileOp, builder.getInsertionBlock(), min, loopTag); assert(loopStack.size() == loopSeqStack.size()); - for (auto [tid, dim] : llvm::zip(tids, dims)) { - if (!isUniqueDLT(dimTypes[tid][dim])) { - segHi[tid][dim] = genSegmentHigh(builder, loc, tid, dim, pidxs[tid][dim], - highs[tid][dim]); + for (auto [tid, dstLvl] : llvm::zip(tids, lvls)) { + // FIXME(wrengr): REBASE START: backporting stuff from + // + const auto reassoc = getCollapseReassociation(tid, dstLvl); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + // TODO: Refactors this into smaller functions. + // NOTE: For all the collapsed level (except for the last one, that is why + // the loop ends with `reassoc.size() - 1`), as each iteration is advanced + // by the segment size of the last level, which does not always invalidate + // the segment size for the previous levels, thus we need to propagate the + // segment sizes across loop iterations and only forward if needed. + // + // E.g., for a COO tensor with the following coordinates array. + // (0, 0, 1), + // (0, 0, 2), + // (1, 1, 1), + // segHi[lvl=0] = segHi[lvl=1] = 2 + // segHi[lvl=2] = 1, + // the first iteration does not invalidate segHi[0] and segHi[1] + for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { + const Level srcLvl = reassoc[i]; + if (!isUniqueDLT(lvlTypes[tid][srcLvl])) { + const Value pos = posits[tid][srcLvl]; + const auto oldSegHi = segHi[tid][srcLvl]; + assert(oldSegHi); + Value newSegHi = builder.create( + loc, arith::CmpIPredicate::uge, pos, oldSegHi); + auto ifNewSegHi = builder.create(loc, builder.getIndexType(), + newSegHi, true); + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(ifNewSegHi.thenBlock()); + builder.create(loc, + genSegmentHigh(builder, loc, tid, srcLvl, + pos, highs[tid][srcLvl])); + // Else, resues the same segment high. 
+ builder.setInsertionPointToStart(ifNewSegHi.elseBlock()); + builder.create(loc, oldSegHi); + } + highs[tid][srcLvl + 1] = segHi[tid][srcLvl] = ifNewSegHi.getResult(0); + } + }; + const auto srcLvl = reassoc.back(); + if (!isUniqueDLT(lvlTypes[tid][srcLvl])) { + segHi[tid][srcLvl] = genSegmentHigh( + builder, loc, tid, srcLvl, posits[tid][srcLvl], highs[tid][srcLvl]); } + // FIXME(wrengr): REBASE END: backporting stuff from + // } // Emits extra locals - emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims); + emitExtraLocalsForTensorsAtDenseLvls(builder, loc, tids, lvls); // Updates reduction variables assert(after->getNumArguments() == o + reduc.size() + (needsUniv ? 1 : 0)); @@ -732,74 +835,75 @@ return whileOp; } -void LoopEmitter::prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc, - size_t tid, size_t dim) { - assert(dimTypes[tid].size() > dim); - auto dimType = dimTypes[tid][dim]; +void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc, + TensorId tid, Level dstLvl) { + assert(tid < lvlTypes.size() && dstLvl < lvlTypes[tid].size()); + const auto dlt = lvlTypes[tid][dstLvl]; - if (isDenseDLT(dimType)) + if (isDenseDLT(dlt)) return; - for (auto lvl : getCollapseReassociation(tid, dim)) { + const Value c0 = constantIndex(builder, loc, 0); + const Value c1 = constantIndex(builder, loc, 1); + for (const Level srcLvl : getCollapseReassociation(tid, dstLvl)) { // Either the first level, or the previous level has been set. - assert(lvl == 0 || pidxs[tid][lvl - 1]); - Value c0 = constantIndex(builder, loc, 0); - Value c1 = constantIndex(builder, loc, 1); - if (isCompressedDLT(dimType)) { - Value mem = posBuffer[tid][lvl]; + /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. + assert(srcLvl == 0 || posits[tid][srcLvl - 1]); + if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) + continue; + if (isCompressedDLT(dlt)) { + const Value mem = positionsBuffers[tid][srcLvl]; - Value pLo = lvl == 0 ? c0 : pidxs[tid][lvl - 1]; - pidxs[tid][lvl] = genIndexLoad(builder, loc, mem, pLo); + const Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1]; + posits[tid][srcLvl] = genIndexLoad(builder, loc, mem, pLo); - Value pHi = builder.create(loc, pLo, c1); - highs[tid][lvl] = genIndexLoad(builder, loc, mem, pHi); + const Value pHi = builder.create(loc, pLo, c1); + highs[tid][srcLvl] = genIndexLoad(builder, loc, mem, pHi); return; } - if (isSingletonDLT(dimType)) { - Value pLo = lvl == 0 ? c0 : pidxs[tid][lvl - 1]; - Value pHi; - // If this is non-unique, the pHi is bound by the segment high of the - // previous level. - if (!isUniqueDLT(dimTypes[tid][lvl - 1])) - pHi = segHi[tid][lvl - 1]; + if (isSingletonDLT(dlt)) { + const Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1]; + posits[tid][srcLvl] = pLo; - // If pHi is still uninitialized, we set it to one as it is a singleton - // level. - // NOTE: Even if the level is non-unique, the pHi might not have been set - // in the previous statement, as we only compute segment high when we are - // coiterating non-unique levels. - if (!pHi) - pHi = builder.create(loc, pLo, c1); - pidxs[tid][lvl] = pLo; - highs[tid][lvl] = pHi; + // If we are coiterating non-unique levels, then use pHi=segHi; + // otherwise use pHi=pLo+1. + // NOTE: Just because the level is non-unique, that does not + // guarantee that segHi is defined: because we only generate segHi + // whenever coiterating, in order to improve code quality for the + // non-coiterating cases. 
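For a compressed level, the preparation above boils down to CSR-style indexing: the parent position p selects the half-open range [positions[p], positions[p+1]) of this level's positions (somewhat confusingly, the local variable `pLo` in the code above holds that parent position). A minimal standalone sketch with illustrative names and in-memory arrays instead of memref loads:

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using Index = uint64_t;

std::pair<Index, Index> positionRange(const std::vector<Index> &positions,
                                      Index parentPos) {
  const Index pLo = positions[parentPos];     // first stored entry
  const Index pHi = positions[parentPos + 1]; // one past the last stored entry
  return {pLo, pHi};
}

int main() {
  // CSR positions for a 3-row matrix whose rows store 2, 0, and 3 entries.
  std::vector<Index> positions = {0, 2, 2, 5};
  const std::pair<Index, Index> r0 = positionRange(positions, 0);
  assert(r0.first == 0 && r0.second == 2);
  const std::pair<Index, Index> r2 = positionRange(positions, 2);
  assert(r2.first == 2 && r2.second == 5);
  return 0;
}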
+ const auto theSegHi = segHi[tid][srcLvl - 1]; + highs[tid][srcLvl] = (!isUniqueDLT(lvlTypes[tid][srcLvl - 1]) && theSegHi) + ? theSegHi + : builder.create(loc, pLo, c1); return; } } - llvm_unreachable("Unrecognizable dimesion type!"); + llvm_unreachable("Unrecognized level-type!"); } -void LoopEmitter::emitExtraLocalsForTensorsAtDenseDims(OpBuilder &builder, +void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(OpBuilder &builder, Location loc, - ArrayRef tids, - ArrayRef dims) { + ArrayRef tids, + ArrayRef lvls) { // Initialize dense positions. Note that we generate dense coordinates of the // output tensor unconditionally, since they may not appear in the lattice, // but may be needed for linearized codegen. - for (auto [tid, dim] : llvm::zip(tids, dims)) { - if (isDenseDLT(dimTypes[tid][dim])) { + assert(tids.size() == lvls.size()); + for (auto [tid, lvl] : llvm::zip(tids, lvls)) { + if (isDenseDLT(lvlTypes[tid][lvl])) { auto enc = getSparseTensorEncoding(tensors[tid].getType()); if (enc && !isSparseOutput(tid)) { - bool validPidx = dim == 0 || pidxs[tid][dim - 1]; - if (!validPidx) { - // We might not find the pidx for the sparse output tensor as it is + bool validPos = lvl == 0 || posits[tid][lvl - 1]; + if (!validPos) { + // We might not find the pos for the sparse output tensor as it is // unconditionally required by the sparsification. assert(isOutputTensor(tid)); continue; } - pidxs[tid][dim] = - genAddress(builder, loc, tid, dim, loopStack.back().iv); - // NOTE: we can also prepare for next dim here in advance + posits[tid][lvl] = + genAddress(builder, loc, tid, lvl, loopStack.back().iv); + // NOTE: we can also prepare for next lvl here in advance } } } @@ -807,12 +911,9 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc, MutableArrayRef reduc) { - LoopLevelInfo &loopInfo = loopStack.back(); + const LoopInfo &loopInfo = loopStack.back(); rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock); - auto &dims = loopStack.back().dims; - auto &tids = loopStack.back().tids; - auto forOp = llvm::dyn_cast(loopInfo.loop); - if (forOp) { + if (auto forOp = llvm::dyn_cast(loopInfo.loop)) { if (!reduc.empty()) { assert(reduc.size() == forOp.getNumResults()); rewriter.create(loc, reduc); @@ -843,6 +944,7 @@ // One of the operands must be the init value (which is also the // previous reduction value). assert(curVal); +#ifndef NDEBUG // The reduction expression should be the only user of the reduction val // inside the parallel for. unsigned numUsers = 0; @@ -851,7 +953,7 @@ numUsers++; } assert(numUsers == 1); - (void)numUsers; // to silence unused variable warning in release build +#endif // NDEBUG rewriter.setInsertionPointAfter(redExp); auto redOp = rewriter.create(loc, curVal); @@ -877,23 +979,21 @@ // Finished iterating a tensor, clean up // We only do the clean up on for loop as while loops do not necessarily // finish the iteration on a sparse tensor - for (auto [tid, dim] : llvm::zip(tids, dims)) { + for (auto [tid, lvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) { // Reset to null. - coord[tid][dim] = Value(); - pidxs[tid][dim] = Value(); - // Dense dimension, high is fixed. - if (!isDenseDLT(dimTypes[tid][dim])) - highs[tid][dim] = Value(); + coords[tid][lvl] = Value(); + posits[tid][lvl] = Value(); + // Dense level, high is fixed. 
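`genAddress`, used above when emitting the extra locals for dense levels, computes a dense position as the parent position scaled by the current level's size plus the current coordinate (for slices the coordinate is first translated by `toSliceCrd`, which this sketch omits). A minimal standalone sketch of that linearization:

#include <cassert>
#include <cstdint>

using Index = uint64_t;

Index denseAddress(Index parentPos, Index lvlSize, Index crd) {
  // posits[lvl] = posits[lvl - 1] * lvlSize + crd  (with parentPos = 0 at lvl 0)
  return parentPos * lvlSize + crd;
}

int main() {
  // A dense 4x5 tensor: element (2, 3) lives at flat position 2 * 5 + 3 = 13.
  const Index pos0 = denseAddress(/*parentPos=*/0, /*lvlSize=*/4, /*crd=*/2);
  assert(denseAddress(pos0, /*lvlSize=*/5, /*crd=*/3) == 13);
  return 0;
}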
+ if (!isDenseDLT(lvlTypes[tid][lvl])) + highs[tid][lvl] = Value(); } } void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc, MutableArrayRef reduc) { - const LoopLevelInfo &loopInfo = loopStack.back(); + const LoopInfo &loopInfo = loopStack.back(); auto whileOp = llvm::cast(loopInfo.loop); builder.setInsertionPointToEnd(loopInfo.userCodeBlock); - auto &dims = loopInfo.dims; - auto &tids = loopInfo.tids; Value iv = loopInfo.iv; // Finalize the induction. Note that the induction could be performed // in the individual if-branches to avoid re-evaluating the conditions. @@ -903,27 +1003,47 @@ unsigned o = 0; SmallVector operands; Value one = constantIndex(builder, loc, 1); - for (auto [tid, dim] : llvm::zip(tids, dims)) { - if (isCompressedDLT(dimTypes[tid][dim]) || - isSingletonDLT(dimTypes[tid][dim])) { - Value op1 = coord[tid][dim]; - Value op3 = pidxs[tid][dim]; + for (auto [tid, dstLvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) { + const auto dlt = lvlTypes[tid][dstLvl]; + if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) { + // FIXME(wrengr): REBASE START: backporting stuff from + // + const auto reassoc = getCollapseReassociation(tid, dstLvl); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { + const Level srcLvl = reassoc[i]; + if (!isUniqueDLT(lvlTypes[tid][srcLvl])) { + operands.push_back(segHi[tid][srcLvl]); + o++; + } + } + const Value crd = coords[tid][dstLvl]; + const Value pos = posits[tid][dstLvl]; Value cmp = - builder.create(loc, arith::CmpIPredicate::eq, op1, iv); + builder.create(loc, arith::CmpIPredicate::eq, crd, iv); // If the loop contains a coiteration with non-unique level, we fast // forward all the duplicated coords by setting the position to the // segment high. - Value add = !isUniqueDLT(dimTypes[tid][dim]) - ? segHi[tid][dim] - : builder.create(loc, op3, one); - operands.push_back(builder.create(loc, cmp, add, op3)); + Value add = !isUniqueDLT(lvlTypes[tid][reassoc.back()]) + ? segHi[tid][dstLvl] + : builder.create(loc, pos, one); + operands.push_back(builder.create(loc, cmp, add, pos)); // Following loops continue iteration from the break point of the // current while loop. - pidxs[tid][dim] = whileOp->getResult(o++); - // The coordinates are invalid now. - coord[tid][dim] = nullptr; - // The segment high are invalid now - segHi[tid][dim] = nullptr; + const Value newPos = whileOp->getResult(o++); + // We need to define a new local variable for `tid` to avoid + // warnings about "captured structured bindings are a C++20 extension". + // FIXME(wrengr): define a helper function to capture this idiom! + const TensorId newTid = tid; + llvm::for_each(reassoc, [this, newTid, newPos](Level srcLvl) { + posits[newTid][srcLvl] = newPos; + }); + // FIXME(wrengr): REBASE END: backporting stuff from + // + // The coordinate is invalid now. + coords[tid][dstLvl] = nullptr; + // The segment high is invalid now. + segHi[tid][dstLvl] = nullptr; // highs remains unchanged. } } @@ -953,8 +1073,8 @@ MutableArrayRef reduc) { // Clean up the values, it would help use to discover potential bug at a // earlier stage (instead of silently using a wrong value). 
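Taken together, the while-loop condition, the minimum-coordinate selection earlier in this patch, and the position updates in `exitCoIterationLoop` implement the usual sparse co-iteration scheme: keep one cursor per sparse operand, visit the smallest current coordinate, and advance only the cursors that produced it. A minimal standalone sketch over two sorted coordinate arrays (illustrative names; the universal-index case and the segment-high fast-forwarding for non-unique levels are omitted):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

using Index = uint64_t;

void coiterate(const std::vector<Index> &crdA, const std::vector<Index> &crdB) {
  size_t posA = 0, posB = 0;
  // Loop while every cursor is in bounds, mirroring the ANDed ult checks of
  // the generated scf.while condition.
  while (posA < crdA.size() && posB < crdB.size()) {
    const Index minCrd = std::min(crdA[posA], crdB[posB]);
    std::cout << "visit coordinate " << minCrd << "\n";
    // Only the operands whose coordinate matched minCrd move forward.
    if (crdA[posA] == minCrd)
      ++posA;
    if (crdB[posB] == minCrd)
      ++posB;
  }
}

int main() {
  coiterate({0, 2, 5}, {2, 3, 5, 7}); // visits 0, 2, 3, 5
  return 0;
}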
- LoopLevelInfo &loopInfo = loopStack.back(); - assert(loopInfo.tids.size() == loopInfo.dims.size()); + const LoopInfo &loopInfo = loopStack.back(); + assert(loopInfo.tids.size() == loopInfo.lvls.size()); SmallVector red; if (llvm::isa(loopInfo.loop)) { exitCoIterationLoop(rewriter, loc, reduc); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -932,20 +932,21 @@ ValueRange{input}, StringAttr::get(getContext(), ForeachOp::getOperationName())); loopEmitter.initializeLoopEmit(rewriter, loc); - for (Dimension d = 0; d < dimRank; d++) { + for (Level l = 0; l < lvlRank; l++) { // TODO: provide utility function for loop sequences that only contains // one for loop? - const Level l = op.getOrder() ? op.getOrder()->getDimPosition(d) : d; - loopEmitter.enterNewLoopSeq(rewriter, loc, 0, static_cast(l)); + // FIXME(wrengr): what is this "ld" supposed to be really? + const Level ld = op.getOrder() ? op.getOrder()->getDimPosition(l) : l; + loopEmitter.enterNewLoopSeq(rewriter, loc, 0, ld); // Note that reduc will be taken care of by loop emitter and get updated // in place. - loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, l, reduc); + loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, 0, l, reduc); } SmallVector lcvs; lcvs.reserve(lvlRank); - loopEmitter.getCoordinateArray(lcvs); + loopEmitter.getLoopIVs(lcvs); if (op.getOrder()) { // FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank` @@ -956,10 +957,10 @@ } } Value vals = loopEmitter.getValBuffer()[0]; - Value pidx = loopEmitter.getPidxs()[0].back(); + Value pos = loopEmitter.getPosits()[0].back(); // Loads the value from sparse tensor using position-index; // loads the value from dense tensor using coords. - Value val = enc ? rewriter.create(loc, vals, pidx) + Value val = enc ? rewriter.create(loc, vals, pos) : rewriter.create(loc, vals, lcvs); // 2. Inline the block in the foreach operator. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -99,7 +99,9 @@ AffineDimExpr getDimExpr() const { return pickedDim.cast(); } private: - /// The picked AffineDimExpr after visit. + /// The picked AffineDimExpr after visit. This must be stored as + /// `AffineExpr` rather than `AffineDimExpr`, because the latter + /// doesn't have a default ctor. AffineExpr pickedDim; /// The iterator type that we want. utils::IteratorType pickIterType; @@ -113,20 +115,25 @@ // Sparse compiler analysis methods. //===----------------------------------------------------------------------===// +// TODO: the "idx"-vs-"ldx" naming convention is not self-explanatory, +// and those letters are too easy to confuse visually. We should switch +// to a more self-explanatory naming convention like "curLoop"-vs-"prevLoop" +// (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention). + /// Determines if affine expression is invariant. 
-static bool isInvariantAffine(AffineExpr a, ArrayRef loopStack, - unsigned ldx, bool &atLevel) { +static bool isInvariantAffine(AffineExpr a, ArrayRef loopStack, + LoopId ldx, bool &isAtLoop) { switch (a.getKind()) { case AffineExprKind::DimId: { - unsigned idx = a.cast().getPosition(); - if (idx == ldx) { - atLevel = true; - // Must be invariant if we are at the level. + const LoopId i = a.cast().getPosition(); + if (i == ldx) { + isAtLoop = true; + // Must be invariant if we are at the given loop. return true; } bool isInvariant = false; - for (unsigned loop : loopStack) { - isInvariant = (loop == idx); + for (LoopId l : loopStack) { + isInvariant = (l == i); if (isInvariant) break; } @@ -135,8 +142,8 @@ case AffineExprKind::Add: case AffineExprKind::Mul: { auto binOp = a.cast(); - return isInvariantAffine(binOp.getLHS(), loopStack, ldx, atLevel) && - isInvariantAffine(binOp.getRHS(), loopStack, ldx, atLevel); + return isInvariantAffine(binOp.getLHS(), loopStack, ldx, isAtLoop) && + isInvariantAffine(binOp.getRHS(), loopStack, ldx, isAtLoop); } default: { assert(a.isa()); @@ -146,34 +153,42 @@ } /// Determines if affine expression is invariant. -static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, unsigned ldx, - bool &atLevel) { - return isInvariantAffine(a, env.getLoopCurStack(), ldx, atLevel); +static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, LoopId ldx, + bool &isAtLoop) { + return isInvariantAffine(a, env.getCurrentLoopStack(), ldx, isAtLoop); } /// Helper method to construct a permuted dimension ordering /// that adheres to the given topological sort. +// +// FIXME: does the above actually mean "dimensions", or should it say +// "level ordering"? The same dim/lvl confusion applies to all the code +// and comments in the definition below. static AffineMap permute(CodegenEnv &env, AffineMap m) { assert(m.getNumDims() + env.merger().getNumFilterLoops() == env.topSortSize() && "size mismatch"); // Construct the inverse of `m`; to avoid the asymptotic complexity // of calling `m.getPermutedPosition` repeatedly. + // + // The variable `perm` must use `unsigned` rather than `Dimension`/`Level`, + // because that's what `AffineMap::getPermutationMap` requires. + // TODO: however, `perm` should be renamed to make clear what exactly + // it's storing a permutation of. SmallVector perm; - unsigned numResults = m.getNumResults(); + const unsigned numResults = m.getNumResults(); BitVector worklist(numResults, true); - unsigned loopDepth = 1; + LoopOrd loopDepth = 1; // Construct the permutation. while (worklist.any() && loopDepth <= env.topSortSize()) { - unsigned preSize = perm.size(); - for (auto dim : worklist.set_bits()) { - bool atLevel = false; + const unsigned preSize = perm.size(); + for (unsigned dim : worklist.set_bits()) { + bool isAtLoop = false; if (m.getResult(dim).isa() || - (isInvariantAffine(m.getResult(dim), - env.getTopSortSlice(0, loopDepth), - env.topSortAt(loopDepth - 1), atLevel) && - atLevel)) { + (isInvariantAffine(m.getResult(dim), env.getLoopStackUpTo(loopDepth), + env.topSortAt(loopDepth - 1), isAtLoop) && + isAtLoop)) { // If the matching affine is constant expression or just become // invariant. We can visit the dimension now without breaking the // topSort constraint. @@ -185,8 +200,8 @@ for (unsigned i = preSize, e = perm.size(); i < e; i++) worklist.reset(perm[i]); - // Tries to entering the next loop level. - loopDepth += 1; + // Try entering the next loop in the stack. 
+ loopDepth++; } assert(perm.size() == numResults); @@ -199,26 +214,26 @@ /// filterIdx stores the current filter loop idx should be used for the next /// compound affine sparse level, and it will be incremented by one when /// used. -static bool findAffine(Merger &merger, unsigned tensor, unsigned dim, - AffineExpr a, DimLevelType dlt, unsigned &filterLdx, +static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, + DimLevelType dlt, LoopId &filterLdx, bool setLvlFormat = true) { switch (a.getKind()) { case AffineExprKind::DimId: { - unsigned idx = a.cast().getPosition(); - if (!isUndefDLT(merger.getDimLevelType(tensor, idx))) + const LoopId idx = a.cast().getPosition(); + if (!isUndefDLT(merger.getDimLevelType(tid, idx))) return false; // used more than once if (setLvlFormat) - merger.setDimAndDimLevelType(tensor, idx, dim, dlt); + merger.setLevelAndType(tid, idx, lvl, dlt); return true; } case AffineExprKind::Add: case AffineExprKind::Mul: case AffineExprKind::Constant: { if (!isDenseDLT(dlt) && setLvlFormat) { - assert(isUndefDLT(merger.getDimLevelType(tensor, filterLdx))); + assert(isUndefDLT(merger.getDimLevelType(tid, filterLdx))); // Use a filter loop for sparse affine expression. - merger.setDimAndDimLevelType(tensor, filterLdx++, dim, dlt); + merger.setLevelAndType(tid, filterLdx++, lvl, dlt); } if (auto binOp = a.dyn_cast()) { @@ -226,9 +241,9 @@ // either loop index at d0 or d1. // We continue the recursion merely to check whether current affine is // admissible or not. - return findAffine(merger, tensor, dim, binOp.getLHS(), dlt, filterLdx, + return findAffine(merger, tid, lvl, binOp.getLHS(), dlt, filterLdx, false) && - findAffine(merger, tensor, dim, binOp.getRHS(), dlt, filterLdx, + findAffine(merger, tid, lvl, binOp.getRHS(), dlt, filterLdx, false); } // Falls through when it is a constant Affine @@ -239,40 +254,61 @@ } } -/// Get the total number of compound affine expressions in affineMap that are -/// attached to the given tensor. For the following inputs: +/// Get the total number of compound affine expressions in the +/// `getMatchingIndexingMap` for the given tensor. For the following inputs: /// -/// affineMap = (d0, d1, d2) => (d0 + d1, d2) -/// tensor = ["compressed", "compressed"] +/// map = (d0, d1, d2) => (d0 + d1, d2) +/// lvlTypes = ["compressed", "compressed"] /// /// Returns 1 (because the first level is compressed and its corresponding -/// affineMap is d0 + d1) -static unsigned getNumCompoundAffineOnSparseDims(AffineMap affineMap, - Value tensor) { +/// indexing-expression is `d0 + d1`) +static unsigned getNumCompoundAffineOnSparseLvls(AffineMap map, Value tensor) { + // The `tensor` is not guaranted to have `RankedTensorType`, therefore + // we can't use `getRankedTensorType`/`getSparseTensorType` here. + // However, we don't need to handle `StorageSpecifierType`, so we + // can use `SparseTensorType` once we guard against non-tensors. + const auto rtp = tensor.getType().dyn_cast(); + if (!rtp) + return 0; + const SparseTensorType stt(rtp); + + // FIXME: There's some dim/lvl confusion here. The previous version of + // the code asserted that there are `lvlRank`-many expressions, but then + // the `exprs[d]` expression assumes there are in fact `dimRank`-many + // expressions. Even though `ArrayRef::operator[]` will check for OOB, + // the mismatch between the assertion and the usage belies that this code + // cannot support non-permutations. 
+ // + // Elsewhere in this file the maps returned by + // `linalg::GenericOp::getMatchingIndexingMap` are inconsistent about + // whether they're expected to have `lvlRank`-many or `dimRank`-many + // expressions (cf., `genSubscript` vs `findSparseAnnotations`); + // so those are no help in determining which is actually intended. + // + // For now we work around this problem by asserting the two ranks agree. + const Dimension dimRank = stt.getDimRank(); + const Level lvlRank = stt.getLvlRank(); + assert(dimRank == lvlRank && "Non-permutations not currently supported"); + const auto exprs = map.getResults(); + assert(static_cast(exprs.size()) == dimRank && + "AffineMap does not have dimension-rank many results"); + (void)dimRank; unsigned num = 0; - const auto enc = getSparseTensorEncoding(tensor.getType()); - if (enc) { - const ArrayRef exps = affineMap.getResults(); - const Level lvlRank = enc.getLvlRank(); - assert(static_cast(exps.size()) == lvlRank); - for (Level l = 0; l < lvlRank; l++) { - // FIXME: `toOrigDim` is deprecated. - const Dimension d = toOrigDim(enc, l); - // FIXME: there's some dim/lvl confusion here; since `d` isn't - // guaranteed to be in bounds (for non-permutations). - if (!exps[d].isa() && !enc.isDenseLvl(l)) - num++; - } + for (Level l = 0; l < lvlRank; l++) { + // FIXME: `toOrigDim` is deprecated. + const Dimension d = toOrigDim(stt.getEncoding(), l); + if (!exprs[d].isa() && !stt.isDenseLvl(l)) + num++; } return num; } -/// Get the total number of compound affine expressions attached on a sparse -/// level in the given GenericOp. -static unsigned getNumCompoundAffineOnSparseDims(linalg::GenericOp op) { +/// Get the total number of sparse levels with compound affine +/// expressions, summed over all operands of the `GenericOp`. +static unsigned getNumCompoundAffineOnSparseLvls(linalg::GenericOp op) { unsigned num = 0; for (OpOperand &t : op->getOpOperands()) - num += getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(&t), + num += getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(&t), t.get()); return num; } @@ -281,7 +317,7 @@ OpOperand *out = op.getDpsInitOperand(0); if (getSparseTensorType(out->get()).isAllDense()) return false; - return getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(out), + return getNumCompoundAffineOnSparseLvls(op.getMatchingIndexingMap(out), out->get()); } @@ -292,7 +328,8 @@ /// no annotations are found or inadmissible constructs occur. static bool findSparseAnnotations(CodegenEnv &env) { bool annotated = false; - unsigned filterLdx = env.merger().getFilterLoopStartingIdx(); + // `filterLdx` may be mutated by `findAffine`. + LoopId filterLdx = env.merger().getStartingFilterLoopId(); for (OpOperand &t : env.op()->getOpOperands()) { const auto map = env.op().getMatchingIndexingMap(&t); const auto enc = getSparseTensorEncoding(t.get().getType()); @@ -302,10 +339,12 @@ assert(!enc || lvlRank == enc.getLvlRank()); assert(static_cast(env.op().getRank(&t)) == lvlRank); for (Level l = 0; l < lvlRank; l++) { - const unsigned tensor = t.getOperandNumber(); + const TensorId tid = t.getOperandNumber(); // FIXME: `toOrigDim` is deprecated. + // FIXME: above we asserted that there are `lvlRank` many results, + // but this is assuming there are in fact `dimRank` many results instead. 
const AffineExpr a = map.getResult(toOrigDim(enc, l)); - if (!findAffine(env.merger(), tensor, l, a, enc.getLvlType(l), filterLdx)) + if (!findAffine(env.merger(), tid, l, a, enc.getLvlType(l), filterLdx)) return false; // inadmissible affine expression } } @@ -317,14 +356,18 @@ /// as we use adj matrix for the graph. /// The sorted result will put the first Reduction iterator to the /// latest possible index. -static bool topSortOptimal(CodegenEnv &env, unsigned n, +/// FIXME(wrengr): correct the above "index" +/// +/// The `inDegree` is indexed by `LoopId`, and the `adjM` is indexed by +/// `(LoopId,LoopId)`. +static bool topSortOptimal(CodegenEnv &env, LoopId n, ArrayRef iteratorTypes, std::vector &inDegree, std::vector> &adjM) { - std::vector redIt; // reduce iterator with 0 degree - std::vector parIt; // parallel iterator with 0 degree - std::vector filterIt; // filter loop with 0 degree - for (unsigned i = 0; i < n; i++) { + std::vector redIt; // reduce iterator with 0 degree + std::vector parIt; // parallel iterator with 0 degree + std::vector filterIt; // filter loop with 0 degree + for (LoopId i = 0; i < n; i++) { if (inDegree[i] == 0) { if (env.merger().isFilterLoop(i)) filterIt.push_back(i); @@ -360,7 +403,7 @@ env.topSortPushBack(src); it.pop_back(); // Update in-degree, and push 0-degree node into worklist. - for (unsigned dst = 0; dst < n; dst++) { + for (LoopId dst = 0; dst < n; dst++) { if (adjM[src][dst] && --inDegree[dst] == 0) { if (env.merger().isFilterLoop(dst)) filterIt.push_back(dst); @@ -381,14 +424,17 @@ /// b = (i0 + i1) < fidx => i0 < fidx, i1 < fidx. /// The affine expression `b` is empty iff `tidx` have a value, leading to /// tidx < a = (i0 + i1) => tidx < i0, tidx < i1. +/// +/// The `inDegree` is indexed by `LoopId`, and the `adjM` is indexed by +/// `(LoopId,LoopId)`. static void addAffineOrderings(std::vector> &adjM, std::vector &inDegree, AffineExpr a, - AffineExpr b, std::optional fidx, - std::optional tidx) { + AffineExpr b, std::optional fidx, + std::optional tidx) { if (!a && !b) { // Recursion leaf. assert(fidx && tidx); - unsigned f = *fidx, t = *tidx; + const LoopId f = *fidx, t = *tidx; if (!adjM[f][t]) { adjM[f][t] = true; inDegree[t]++; @@ -396,10 +442,10 @@ return; } // Picks an affine expression and expand (recurse into) it. - auto toExpand = a ? a : b; + const auto toExpand = a ? a : b; switch (toExpand.getKind()) { case AffineExprKind::DimId: { - auto idx = toExpand.cast().getPosition(); + std::optional idx = toExpand.cast().getPosition(); if (toExpand == a) addAffineOrderings(adjM, inDegree, AffineExpr(), b, idx, tidx); else // toExpand == b @@ -424,9 +470,9 @@ } static void tryLoosenAffineDenseConstraints(linalg::GenericOp op, - std::optional &fldx, + std::optional &fldx, AffineExpr &fa, - std::optional &tldx, + std::optional &tldx, AffineExpr &ta) { // We use a heuristic here to only pick one dim expression from each // compound affine expression to establish the order between two dense @@ -467,7 +513,7 @@ OpOperand *skip = nullptr) { // Set up an n x n from/to adjacency matrix of the iteration graph // for the implicit loop indices i_0 .. i_n-1. - const unsigned n = env.merger().getNumLoops(); + const LoopId n = env.merger().getNumLoops(); std::vector> adjM(n, std::vector(n, false)); std::vector inDegree(n, 0); // in-degree of each node. const auto iteratorTypes = env.op().getIteratorTypesArray(); @@ -476,7 +522,7 @@ // Get map and encoding. 
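`topSortOptimal` above is essentially a Kahn-style topological sort over the loop-ordering constraints in `adjM`, with the zero-in-degree worklist split by loop kind so that filter loops are scheduled as early as possible and reductions as late as possible. A minimal standalone sketch of that scheme; the tie-breaking here is simplified to "filter, then parallel, then reduction" and may not match every detail of the real heuristic:

#include <cstddef>
#include <iostream>
#include <vector>

enum class LoopKind { Parallel, Reduction, Filter };

std::vector<size_t> topSortOptimal(const std::vector<std::vector<bool>> &adjM,
                                   const std::vector<LoopKind> &kinds) {
  const size_t n = adjM.size();
  std::vector<unsigned> inDegree(n, 0);
  for (size_t i = 0; i < n; ++i)
    for (size_t j = 0; j < n; ++j)
      if (adjM[i][j])
        ++inDegree[j];
  std::vector<size_t> filterIt, parIt, redIt, order;
  auto enqueue = [&](size_t i) {
    (kinds[i] == LoopKind::Filter     ? filterIt
     : kinds[i] == LoopKind::Parallel ? parIt
                                      : redIt).push_back(i);
  };
  for (size_t i = 0; i < n; ++i)
    if (inDegree[i] == 0)
      enqueue(i);
  while (!filterIt.empty() || !parIt.empty() || !redIt.empty()) {
    auto &it = !filterIt.empty() ? filterIt : (!parIt.empty() ? parIt : redIt);
    const size_t src = it.back();
    it.pop_back();
    order.push_back(src);
    // Release the successors of src.
    for (size_t dst = 0; dst < n; ++dst)
      if (adjM[src][dst] && --inDegree[dst] == 0)
        enqueue(dst);
  }
  return order; // shorter than n if the constraint graph has a cycle
}

int main() {
  // Constraint: loop 0 must be outside loop 1; loop 2 is a reduction.
  std::vector<std::vector<bool>> adjM = {{false, true, false},
                                         {false, false, false},
                                         {false, false, false}};
  for (size_t i : topSortOptimal(adjM, {LoopKind::Parallel, LoopKind::Parallel,
                                        LoopKind::Reduction}))
    std::cout << i << " "; // prints "0 1 2"
  return 0;
}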
const auto map = env.op().getMatchingIndexingMap(&t); const auto enc = getSparseTensorEncoding(t.get().getType()); - assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n); + assert(map.getNumDims() + getNumCompoundAffineOnSparseLvls(env.op()) == n); // Skips dense inputs/outputs when not requested. const bool isDenseInput = !enc && env.op().isDpsInput(&t); @@ -489,18 +535,17 @@ // will be skipped more often. // TODO: Do we really need this? if (includesUndef(mask)) { - unsigned tensor = t.getOperandNumber(); - for (unsigned i = 0; i < n; i++) { - if (isCompressedDLT(env.dlt(tensor, i)) || - isSingletonDLT(env.dlt(tensor, i))) { - for (unsigned j = 0; j < n; j++) + const TensorId tensor = t.getOperandNumber(); + for (LoopId i = 0; i < n; i++) { + const auto dltI = env.dlt(tensor, i); + if (isCompressedDLT(dltI) || isSingletonDLT(dltI)) { + for (LoopId j = 0; j < n; j++) if (isUndefDLT(env.dlt(tensor, j))) { adjM[i][j] = true; inDegree[j]++; } } else { - assert(isDenseDLT(env.dlt(tensor, i)) || - isUndefDLT(env.dlt(tensor, i))); + assert(isDenseDLT(dltI) || isUndefDLT(dltI)); } } } @@ -513,9 +558,11 @@ assert(!enc || lvlRank == enc.getLvlRank()); for (Level l = 0; l < lvlRank; l++) { // FIXME: `toOrigDim` is deprecated. + // FIXME: above we asserted that there are `lvlRank` many results, + // but this is assuming there are in fact `dimRank` many results instead. AffineExpr ta = map.getResult(toOrigDim(enc, l)); - std::optional tldx = - env.merger().getLoopIdx(t.getOperandNumber(), l); + std::optional tldx = + env.merger().getLoopId(t.getOperandNumber(), l); // Filter loops should be constructed after all the dependent loops, // i.e., d0 + d1 < filter_loop(d0 + d1) @@ -537,9 +584,11 @@ if (l > 0) { // FIXME: `toOrigDim` is deprecated. + // FIXME: above we asserted that there are `lvlRank` many results, + // but this is assuming there are in fact `dimRank` many results. AffineExpr fa = map.getResult(toOrigDim(enc, l - 1)); - std::optional fldx = - env.merger().getLoopIdx(t.getOperandNumber(), l - 1); + std::optional fldx = + env.merger().getLoopId(t.getOperandNumber(), l - 1); // Applying order constraints on every pair of dimExpr between two // compound affine expressions can sometime too strict: @@ -620,32 +669,37 @@ const Level lvlRank = stt.getLvlRank(); assert(static_cast(map.getNumResults()) == lvlRank); // FIXME: `toOrigDim` is deprecated. + // FIXME: above we asserted that there are `lvlRank` many results, + // but this is assuming there are in fact `dimRank` many results instead. AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), lvlRank - 1)); assert(a.getKind() == AffineExprKind::DimId); - unsigned idx = a.cast().getPosition(); - return env.getLoopIdxValue(idx); + const LoopId idx = a.cast().getPosition(); + return env.getLoopVar(idx); } /// Generates subscript for load/store on a dense or sparse tensor. static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, SmallVectorImpl &args) { - linalg::GenericOp op = env.op(); - unsigned tensor = t->getOperandNumber(); - auto map = op.getMatchingIndexingMap(t); + const Location loc = env.op().getLoc(); + const TensorId tid = t->getOperandNumber(); + const auto map = env.op().getMatchingIndexingMap(t); const auto stt = getSparseTensorType(t->get()); if (stt.hasEncoding()) { - Value pidx = env.emitter().getPidxs()[tensor].back(); - assert(pidx); - args.push_back(pidx); // position index + // For sparse tensors we only push the last-level's position onto `args`. 
+ const auto pos = env.emitter().getPosits()[tid].back(); + assert(pos); + args.push_back(pos); } else { + // For dense tensors we push all level's coordinates onto `args`. const Level lvlRank = stt.getLvlRank(); assert(static_cast(map.getNumResults()) == lvlRank); for (Level l = 0; l < lvlRank; l++) { - AffineExpr a = map.getResult(l); - args.push_back(env.emitter().genAffine(builder, a, op.getLoc())); + const auto lvlExpr = map.getResult(l); + const auto lvlCrd = env.emitter().genAffine(builder, loc, lvlExpr); + args.push_back(lvlCrd); } } - return env.emitter().getValBuffer()[tensor]; + return env.emitter().getValBuffer()[tid]; } /// Generates insertion code to implement dynamic tensor load. @@ -688,19 +742,22 @@ Location loc = op.getLoc(); // Direct insertion in lexicographic coordinate order. if (!env.isExpand()) { - unsigned rank = op.getRank(t); - // FIXME: It's not entirely clear what "indices" means here (i.e., - // are they "coordinates"? and if so, then are they level-coords or - // dim-coords?) - SmallVector indices; - for (unsigned i = 0; i < rank; i++) { - assert(env.emitter().getLoopIV(i)); - indices.push_back(env.emitter().getLoopIV(i)); + const LoopOrd numLoops = op.getRank(t); + // TODO: rewrite this to use `env.emitter().getLoopIVs(ivs)` + // instead. We just need to either assert that `numLoops == + // env.emitter().getCurrentDepth()`, or else update the `getLoopIVs` + // method to take an optional parameter to restrict to a smaller depth. + SmallVector ivs; + ivs.reserve(numLoops); + for (LoopOrd n = 0; n < numLoops; n++) { + const auto iv = env.emitter().getLoopIV(n); + assert(iv); + ivs.push_back(iv); } Value chain = env.getInsertionChain(); if (!env.getValidLexInsert()) { env.updateInsertionChain( - builder.create(loc, rhs, chain, indices)); + builder.create(loc, rhs, chain, ivs)); } else { // Generates runtime check for a valid lex during reduction, // to avoid inserting the identity value for empty reductions. @@ -714,7 +771,7 @@ /*else=*/true); // True branch. builder.setInsertionPointToStart(ifValidLexInsert.thenBlock()); - Value res = builder.create(loc, rhs, chain, indices); + Value res = builder.create(loc, rhs, chain, ivs); builder.create(loc, res); // False branch. builder.setInsertionPointToStart(ifValidLexInsert.elseBlock()); @@ -761,7 +818,7 @@ } /// Generates a load on a dense or sparse tensor. -static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, unsigned exp) { +static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) { // Test if the load was hoisted to a higher loop nest. Value val = env.exp(exp).val; if (val) @@ -782,7 +839,7 @@ } /// Generates a store on a dense or sparse tensor. -static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp, +static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp, Value rhs) { linalg::GenericOp op = env.op(); Location loc = op.getLoc(); @@ -830,7 +887,7 @@ } /// Generates an invariant value. -inline static Value genInvariantValue(CodegenEnv &env, unsigned exp) { +inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) { return env.exp(exp).val; } @@ -840,10 +897,10 @@ /// exception of index computations, which need to be relinked to actual /// inlined cloned code. 
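+/// In particular, a cloned `linalg.index` operation is replaced by the
+/// value of the corresponding loop-variable (via `env.getLoopVar`), and
+/// operands that are themselves defined inside the inlined block are
+/// rewritten recursively.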
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, - Value e, unsigned ldx) { + Value e, LoopId ldx) { if (Operation *def = e.getDefiningOp()) { if (auto indexOp = dyn_cast(def)) - return env.getLoopIdxValue(indexOp.getDim()); + return env.getLoopVar(indexOp.getDim()); if (def->getBlock() == block) { for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) { rewriter.updateRootInPlace(def, [&]() { @@ -857,52 +914,52 @@ } /// Recursively generates tensor expression. -static Value genExp(CodegenEnv &env, RewriterBase &rewriter, unsigned exp, - unsigned ldx) { +static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e, + LoopId ldx) { linalg::GenericOp op = env.op(); Location loc = op.getLoc(); - if (exp == -1u) + if (e == kInvalidId) return Value(); - if (env.exp(exp).kind == Kind::kTensor) - return genTensorLoad(env, rewriter, exp); - if (env.exp(exp).kind == Kind::kInvariant) - return genInvariantValue(env, exp); - if (env.exp(exp).kind == Kind::kIndex) - return env.getLoopIdxValue(env.exp(exp).index); + const TensorExp &exp = env.exp(e); + const auto kind = exp.kind; + if (kind == Kind::kTensor) + return genTensorLoad(env, rewriter, e); + if (kind == Kind::kInvariant) + return genInvariantValue(env, e); + if (kind == Kind::kLoopVar) + return env.getLoopVar(exp.loop); - if (env.exp(exp).kind == Kind::kReduce) - env.startCustomReduc(exp); // enter custom + if (kind == Kind::kReduce) + env.startCustomReduc(e); // enter custom - Value v0 = genExp(env, rewriter, env.exp(exp).children.e0, ldx); - Value v1 = genExp(env, rewriter, env.exp(exp).children.e1, ldx); - Value ee = env.merger().buildExp(rewriter, loc, exp, v0, v1); - if (ee && (env.exp(exp).kind == Kind::kUnary || - env.exp(exp).kind == Kind::kBinary || - env.exp(exp).kind == Kind::kBinaryBranch || - env.exp(exp).kind == Kind::kReduce || - env.exp(exp).kind == Kind::kSelect)) + Value v0 = genExp(env, rewriter, exp.children.e0, ldx); + Value v1 = genExp(env, rewriter, exp.children.e1, ldx); + Value ee = env.merger().buildExp(rewriter, loc, e, v0, v1); + if (ee && (kind == Kind::kUnary || kind == Kind::kBinary || + kind == Kind::kBinaryBranch || kind == Kind::kReduce || + kind == Kind::kSelect)) ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx); - if (env.exp(exp).kind == Kind::kReduce) + if (kind == Kind::kReduce) env.endCustomReduc(); // exit custom - if (env.exp(exp).kind == kSelect) { - assert(!env.exp(exp).val); - env.exp(exp).val = v0; // Preserve value for later use. + if (kind == kSelect) { + assert(!exp.val); + env.exp(e).val = v0; // Preserve value for later use. } return ee; } /// Hoists loop invariant tensor loads for which indices have been exhausted. -static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp, - unsigned ldx, bool atStart) { - if (exp == -1u) +static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp, + LoopId ldx, bool atStart) { + if (exp == kInvalidId) return; if (env.exp(exp).kind == Kind::kTensor) { // Inspect tensor indices. - bool atLevel = ldx == -1u; + bool isAtLoop = ldx == kInvalidId; linalg::GenericOp op = env.op(); OpOperand &t = op->getOpOperand(env.exp(exp).tensor); auto map = op.getMatchingIndexingMap(&t); @@ -911,20 +968,21 @@ assert(static_cast(map.getNumResults()) == lvlRank); for (Level l = 0; l < lvlRank; l++) { // FIXME: `toOrigDim` is deprecated. + // FIXME: above we asserted that there are `lvlRank` many results, + // but this is assuming there are in fact `dimRank` many results instead. 
AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), l)); - std::optional sldx = - env.merger().getLoopIdx(t.getOperandNumber(), l); + const auto sldx = env.merger().getLoopId(t.getOperandNumber(), l); if (sldx && env.merger().isFilterLoop(*sldx)) { - if (!env.getLoopIdxValue(*sldx)) + if (!env.getLoopVar(*sldx)) // The filter loops has not been constructed. return; if (*sldx == ldx) - atLevel = true; - } else if (!isInvariantAffine(env, a, ldx, atLevel)) + isAtLoop = true; + } else if (!isInvariantAffine(env, a, ldx, isAtLoop)) return; // still in play } - // All exhausted at this level (atLevel denotes exactly at this level). - if (!atLevel) + // All exhausted at this level (isAtLoop denotes exactly at this LoopId). + if (!isAtLoop) return; OpOperand *lhs = op.getDpsInitOperand(0); if (lhs == &t) { @@ -944,14 +1002,14 @@ env.exp(exp).val = atStart ? genTensorLoad(env, builder, exp) : Value(); } } else if (env.exp(exp).kind != Kind::kInvariant && - env.exp(exp).kind != Kind::kIndex) { + env.exp(exp).kind != Kind::kLoopVar) { // Traverse into the binary operations. Note that we only hoist // tensor loads, since subsequent MLIR/LLVM passes know how to // deal with all other kinds of derived loop invariants. if (env.exp(exp).kind == Kind::kReduce) env.startCustomReduc(exp); // enter custom - unsigned e0 = env.exp(exp).children.e0; - unsigned e1 = env.exp(exp).children.e1; + const ExprId e0 = env.exp(exp).children.e0; + const ExprId e1 = env.exp(exp).children.e1; genInvariants(env, builder, e0, ldx, atStart); genInvariants(env, builder, e1, ldx, atStart); if (env.exp(exp).kind == Kind::kReduce) @@ -960,7 +1018,7 @@ } /// Generates an expanded access pattern in innermost dimension. -static void genExpand(CodegenEnv &env, OpBuilder &builder, unsigned at, +static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopOrd at, bool atStart) { linalg::GenericOp op = env.op(); OpOperand *lhs = op.getDpsInitOperand(0); @@ -987,7 +1045,7 @@ r.getResult(3)); } else { SmallVector indices; - for (unsigned i = 0; i < at; i++) + for (LoopOrd i = 0; i < at; i++) indices.push_back(env.emitter().getLoopIV(i)); Value values = env.getExpandValues(); Value filled = env.getExpandFilled(); @@ -1029,34 +1087,35 @@ /// Generates a for-loop on a single index. static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter, - bool isInner, unsigned idx, ArrayRef tids, - ArrayRef dims) { + bool isInner, LoopId ldx, ArrayRef tids, + ArrayRef lvls) { linalg::GenericOp op = env.op(); Location loc = op.getLoc(); auto iteratorTypes = op.getIteratorTypesArray(); - bool isSparse = llvm::any_of(tids, [idx, &env](size_t tid) { - return isCompressedDLT(env.dlt(tid, idx)) || - isSingletonDLT(env.dlt(tid, idx)); + bool isSparse = llvm::any_of(tids, [ldx, &env](TensorId tid) { + const auto dlt = env.dlt(tid, ldx); + return isCompressedDLT(dlt) || isSingletonDLT(dlt); }); bool isParallel = isParallelFor(env, isOuter, isSparse); Operation *loop = *env.genLoopBoundary([&](MutableArrayRef reduc) { - if (env.merger().isFilterLoop(idx)) { - size_t tid = tids.front(), dim = dims.front(); - // tids/dims must only have one value because filter loops only + if (env.merger().isFilterLoop(ldx)) { + const TensorId tid = tids.front(); + const Level lvl = lvls.front(); + // tids/lvls must only have one value because filter loops only // corresponding to the one and only sparse tensor level. 
- assert(isSparse && tids.size() == 1 && dims.size() == 1); + assert(isSparse && tids.size() == 1 && lvls.size() == 1); OpOperand *t = &op->getOpOperand(tid); auto enc = getSparseTensorEncoding(t->get().getType()); // Retrieves the affine expression for the filter loop. // FIXME: `toOrigDim` is deprecated. AffineExpr a = - op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim)); - return env.emitter().enterFilterLoopOverTensorAtDim(builder, loc, tid, - dim, a, reduc); + op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, lvl)); + return env.emitter().enterFilterLoopOverTensorAtLvl(builder, loc, tid, + lvl, a, reduc); } - return env.emitter().enterLoopOverTensorAtDim(builder, loc, tids, dims, + return env.emitter().enterLoopOverTensorAtLvl(builder, loc, tids, lvls, reduc, isParallel); }); assert(loop); @@ -1064,14 +1123,14 @@ } /// Emit a while-loop for co-iteration over multiple indices. -static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, unsigned idx, - bool needsUniv, ArrayRef tids, - ArrayRef dims) { +static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, LoopId idx, + bool needsUniv, ArrayRef tids, + ArrayRef lvls) { Operation *loop = *env.genLoopBoundary([&](MutableArrayRef reduc) { // Construct the while-loop with a parameter for each // index. - return env.emitter().enterCoIterationOverTensorsAtDims( - builder, env.op().getLoc(), tids, dims, needsUniv, reduc); + return env.emitter().enterCoIterationOverTensorsAtLvls( + builder, env.op().getLoc(), tids, lvls, needsUniv, reduc); }); assert(loop); return loop; @@ -1079,21 +1138,21 @@ /// Generates a for-loop or a while-loop, depending on whether it implements /// singleton iteration or co-iteration over the given conjunction. -static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, unsigned at, - bool needsUniv, ArrayRef tids, - ArrayRef dims, bool isFor) { - assert(tids.size() == dims.size()); - unsigned idx = env.topSortAt(at); +static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at, + bool needsUniv, ArrayRef tids, + ArrayRef lvls, bool isFor) { + assert(tids.size() == lvls.size()); + const LoopId idx = env.topSortAt(at); if (isFor) { bool isOuter = at == 0; bool isInner = at == env.topSortSize() - 1; - return genFor(env, builder, isOuter, isInner, idx, tids, dims); + return genFor(env, builder, isOuter, isInner, idx, tids, lvls); } - return genWhile(env, builder, idx, needsUniv, tids, dims); + return genWhile(env, builder, idx, needsUniv, tids, lvls); } /// Generates the induction structure for a while-loop. -static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx, +static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx, bool needsUniv, BitVector &induction, scf::WhileOp whileOp) { Location loc = env.op().getLoc(); @@ -1133,26 +1192,26 @@ } /// Generates a single if-statement within a while-loop. 
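+/// The emitted condition is the conjunction of one clause per tensor in the
+/// lattice point: for a compressed or singleton level the clause compares
+/// that tensor's current level-coordinate against the loop-variable, while
+/// dense and undefined levels contribute a constant true.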
-static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, unsigned idx, +static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx, BitVector &conditions) { Location loc = env.op().getLoc(); SmallVector types; Value cond; - for (unsigned b = 0, be = conditions.size(); b < be; b++) { + for (TensorLoopId b = 0, be = conditions.size(); b < be; b++) { if (!conditions[b]) continue; - unsigned tensor = env.merger().tensor(b); - assert(idx == env.merger().index(b)); + const TensorId tid = env.merger().tensor(b); + assert(ldx == env.merger().loop(b)); Value clause; - if (isCompressedDLT(env.dlt(b)) || isSingletonDLT(env.dlt(b))) { - auto dim = *env.merger().getDimNum(tensor, idx); - Value op1 = env.emitter().getCoord()[tensor][dim]; - Value op2 = env.getLoopIdxValue(idx); - clause = builder.create(loc, arith::CmpIPredicate::eq, op1, - op2); + const auto dlt = env.dlt(b); + if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) { + const Level lvl = *env.merger().getLvl(tid, ldx); + const Value crd = env.emitter().getCoords()[tid][lvl]; + const Value lvar = env.getLoopVar(ldx); + clause = builder.create(loc, arith::CmpIPredicate::eq, crd, + lvar); } else { - assert(isDenseDLT(env.merger().getDimLevelType(b)) || - isUndefDLT(env.merger().getDimLevelType(b))); + assert(isDenseDLT(dlt) || isUndefDLT(dlt)); clause = constantI1(builder, loc, true); } cond = cond ? builder.create(loc, cond, clause) : clause; @@ -1202,41 +1261,40 @@ /// Starts a loop sequence at given level. Returns true if /// the universal loop index must be maintained at this level. -static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, - unsigned at, unsigned idx, unsigned ldx, - unsigned lts) { - assert(!env.getLoopIdxValue(idx)); +static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, + LoopOrd at, LoopId idx, LoopId ldx, LatSetId lts) { + assert(!env.getLoopVar(idx)); // Emit invariants at this loop sequence level. genInvariants(env, builder, exp, ldx, /*atStart=*/true); // Emit access pattern expansion for sparse tensor output. genExpand(env, builder, at, /*atStart=*/true); // Emit further intitialization at this loop sequence level. - unsigned l0 = env.set(lts)[0]; + const LatPointId l0 = env.set(lts)[0]; bool needsUniv = false; - SmallVector tids; - SmallVector dims; - env.merger().foreachTidDimPairInBits( - env.lat(l0).bits, [&](unsigned b, unsigned tid, - std::optional dim, DimLevelType dlt) { - assert(env.merger().index(b) == idx); + SmallVector tids; + SmallVector lvls; + env.merger().foreachTensorLoopId( + env.lat(l0).bits, [&](TensorLoopId b, TensorId tid, + std::optional lvl, DimLevelType dlt) { + assert(env.merger().loop(b) == idx); if (isDenseDLT(dlt) || isUndefDLT(dlt)) { needsUniv = true; } else { - // sparse/singleton dim levels. + // sparse/singleton levels. tids.push_back(tid); - dims.push_back(*dim); + lvls.push_back(*lvl); } }); - env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, dims); + env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls); // Maintain the universal index only if it is actually // consumed by a subsequent lattice point. 
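+  // That is, the universal index is kept only when some subsequent lattice
+  // point has no sparse conditions of its own; if every later point is
+  // driven by sparse coordinates, the universal index can be dropped.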
if (needsUniv) { unsigned lsize = env.set(lts).size(); for (unsigned i = 1; i < lsize; i++) { - unsigned li = env.set(lts)[i]; + const LatPointId li = env.set(lts)[i]; if (!env.merger().hasAnySparse(env.lat(li).simple)) return true; } @@ -1245,23 +1303,25 @@ } static void genConstantDenseAddressFromLevel(CodegenEnv &env, - OpBuilder &builder, unsigned tid, - Level lvl) { + OpBuilder &builder, TensorId tid, + Level startLvl) { // TODO: Handle affine expression on output tensor. linalg::GenericOp op = env.op(); assert(tid < op.getNumDpsInputs()); OpOperand *input = op.getDpsInputOperands()[tid]; - ArrayRef affines = op.getMatchingIndexingMap(input).getResults(); + const auto lvlExprs = op.getMatchingIndexingMap(input).getResults(); const auto enc = getSparseTensorEncoding(input->get().getType()); if (enc) { + const Location loc = op.getLoc(); + const TensorId tid = input->getOperandNumber(); const Level lvlRank = enc.getLvlRank(); - assert(affines.size() == static_cast(lvlRank)); - for (Level l = lvl; l < lvlRank; l++) { + assert(lvlExprs.size() == static_cast(lvlRank)); + // FIXME: there is dim/lvl confusion here + for (Level l = startLvl; l < lvlRank; l++) { // FIXME: `toOrigDim` is deprecated. - AffineExpr affine = affines[toOrigDim(enc, l)]; - if (enc.isDenseLvl(l) && affine.isa()) - env.emitter().genDenseAffineAddressAtCurLevel( - builder, op.getLoc(), input->getOperandNumber(), l, affine); + AffineExpr lvlExpr = lvlExprs[toOrigDim(enc, l)]; + if (enc.isDenseLvl(l) && lvlExpr.isa()) + env.emitter().genDenseAffineAddress(builder, loc, tid, l, lvlExpr); else return; // break on first non-dense non-constant level } @@ -1274,43 +1334,44 @@ // starting from the first level as they do not depend on any thing. // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two // levels can be determined before loops. - for (unsigned tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++) + for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++) genConstantDenseAddressFromLevel(env, rewriter, tid, 0); } /// Return true if the lattices bit can be iterated by a for loop. -static bool translateBitsToTidDimPairs( - CodegenEnv &env, unsigned li, unsigned idx, SmallVectorImpl &tids, - SmallVectorImpl &dims, SmallVectorImpl &affineTids, - SmallVectorImpl &affineDims, SmallVectorImpl &exps) { +static bool translateBitsToTidLvlPairs( + CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl &tids, + SmallVectorImpl &lvls, SmallVectorImpl &affineTids, + SmallVectorImpl &affineLvls, SmallVectorImpl &exps) { const BitVector &all = env.lat(li).bits; const BitVector &simple = env.lat(li).simple; + const TensorId outTid = env.merger().getOutTensorID(); + const std::optional outLvl = env.merger().getLvl(outTid, ldx); unsigned numloopCond = 0; - // Converts bits to array + dim pair - env.merger().foreachTidDimPairInBits( - all, [&, idx](unsigned b, unsigned tid, std::optional dim, + env.merger().foreachTensorLoopId( + all, [&, ldx](TensorLoopId b, TensorId tid, std::optional lvl, DimLevelType dlt) { if (simple.test(b)) { if (isUndefDLT(dlt)) { - // An undefined dlt in the lattices, we probably mean to iterate - // based on the dim of output tensor. - // E.g., this could be a synthetic tensor (for invariants and sparse + // An undefined dlt in the lattices, we probably mean to + // iterate based on the level of output tensor. E.g., this + // could be a synthetic tensor (for invariants and sparse // output tensor). 
// out[i][j] = invariant; or a broadcast // out[i][j] = in[i] (j is undef for input) - tid = env.merger().getOutTensorID(); - dim = env.merger().getDimNum(tid, idx); - // Skips invalid dim (e.g., when this is a zero ranked tensor). - if (!dim) + tid = outTid; + lvl = outLvl; + // Skips invalid lvl (e.g., when this is a zero ranked tensor). + if (!lvl) return; } tids.push_back(tid); - dims.push_back(*dim); + lvls.push_back(*lvl); numloopCond++; } else if (isDenseDLT(dlt)) { tids.push_back(tid); - dims.push_back(*dim); + lvls.push_back(*lvl); } else { assert(isUndefDLT(dlt)); linalg::GenericOp op = env.op(); @@ -1330,15 +1391,15 @@ for (Level l = 0; l < lvlRank; l++) { // FIXME: `toOrigDim` is deprecated. AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)]; - // Skip simple affine expression and non dense dimensions (which has - // it own filter loop). + // Skip simple affine expression and non-dense levels (which + // have their own filter loop). if (exp.isa() || !stt.isDenseLvl(l)) continue; // Constant affine expression are handled in genLoop if (!exp.isa()) { - bool atLevel = false; - if (isInvariantAffine(env, exp, idx, atLevel) && atLevel) { + bool isAtLoop = false; + if (isInvariantAffine(env, exp, ldx, isAtLoop) && isAtLoop) { // If the compound affine is invariant and we are right at the // level. We need to generate the address according to the // affine expression. This is also the best place we can do it @@ -1349,7 +1410,7 @@ // might be accepting out-of-order access between consecutive // dense levels. affineTids.push_back(tid); - affineDims.push_back(l); + affineLvls.push_back(l); exps.push_back(exp); } } @@ -1357,13 +1418,12 @@ } }); - if (isDenseDLT(env.dlt(env.merger().getOutTensorID(), idx))) { + if (isDenseDLT(env.dlt(outTid, ldx))) { // Note that we generate dense indices of the output tensor // unconditionally, since they may not appear in the lattice, but may be // needed for linearized env. - auto dim = *env.merger().getDimNum(env.merger().getOutTensorID(), idx); - tids.push_back(env.merger().getOutTensorID()); - dims.push_back(dim); + tids.push_back(outTid); + lvls.push_back(*outLvl); } assert(numloopCond > 0); @@ -1373,33 +1433,33 @@ } /// Starts a single loop in current sequence. -static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at, - unsigned li, bool needsUniv) { - // The set of tensors + dims to generate loops on - SmallVector tids, dims; +static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, LoopOrd at, + LatPointId li, bool needsUniv) { + // The set of tensors + lvls to generate loops on + SmallVector tids, affineTids; + SmallVector lvls, affineLvls; // The set of dense tensors with non-trivial affine expression that just // becomes invariant and the address shall now be generated at the current // level. - SmallVector affineTids, affineDims; SmallVector affines; - bool isFor = translateBitsToTidDimPairs( - env, li, env.topSortAt(at), tids, dims, affineTids, affineDims, affines); + bool isFor = translateBitsToTidLvlPairs( + env, li, env.topSortAt(at), tids, lvls, affineTids, affineLvls, affines); // Emit the for/while-loop control. 
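+  // `isFor` selects between singleton iteration (a simple for-loop) and
+  // co-iteration over multiple sparse tensors (a while-loop) in genLoop.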
- Operation *loop = genLoop(env, builder, at, needsUniv, tids, dims, isFor); - for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) { - env.emitter().genDenseAffineAddressAtCurLevel(builder, env.op().getLoc(), - tid, dim, exp); + Operation *loop = genLoop(env, builder, at, needsUniv, tids, lvls, isFor); + Location loc = env.op().getLoc(); + for (auto [tid, lvl, exp] : llvm::zip(affineTids, affineLvls, affines)) { + env.emitter().genDenseAffineAddress(builder, loc, tid, lvl, exp); } - // Until now, we have entered every pair in {cond, extra, - // affine}Tids/Dims. The addresses of the upcoming levels which are dependent + // Until now, we have entered every pair in {cond, extra, + // affine}Tids/Lvls. The addresses of the upcoming levels which are dependent // on constant affines expression may now be determined. - auto allTids = llvm::concat(tids, affineTids); - auto allDims = llvm::concat(dims, affineDims); - for (auto [tid, dim] : llvm::zip(allTids, allDims)) { + auto allTids = llvm::concat(tids, affineTids); + auto allLvls = llvm::concat(lvls, affineLvls); + for (auto [tid, lvl] : llvm::zip(allTids, allLvls)) { if (tid != env.merger().getOutTensorID()) - genConstantDenseAddressFromLevel(env, builder, tid, dim + 1); + genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1); } return loop; @@ -1407,7 +1467,7 @@ /// Ends a single loop in current sequence. Returns new values for needsUniv. static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, - unsigned idx, unsigned li, bool needsUniv) { + LoopId idx, LatPointId li, bool needsUniv) { // End a while-loop. if (auto whileOp = dyn_cast(loop)) { finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp); @@ -1428,9 +1488,9 @@ } /// Ends a loop sequence at given level. -static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, - unsigned at, unsigned idx, unsigned ldx) { - assert(env.getLoopIdxValue(idx) == nullptr); +static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, + LoopOrd at, LoopId idx, LoopId ldx) { + assert(!env.getLoopVar(idx)); env.emitter().exitCurrentLoopSeq(); // Unmark bookkeeping of invariants and loop index. genInvariants(env, builder, exp, ldx, /*atStart=*/false); @@ -1441,20 +1501,21 @@ /// Recursively generates code while computing iteration lattices in order /// to manage the complexity of implementing co-iteration over unions /// and intersections of sparse iterations spaces. -static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp, - unsigned at) { +static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp, + LoopOrd at) { // At each leaf, assign remaining tensor (sub)expression to output tensor. if (at == env.topSortSize()) { - unsigned ldx = env.topSortAt(at - 1); + const LoopId ldx = env.topSortAt(at - 1); Value rhs = genExp(env, rewriter, exp, ldx); genTensorStore(env, rewriter, exp, rhs); return; } // Construct iteration lattices for current loop index, with L0 at top. - unsigned idx = env.topSortAt(at); - unsigned ldx = at == 0 ? -1u : env.topSortAt(at - 1); - unsigned lts = env.merger().optimizeSet(env.merger().buildLattices(exp, idx)); + const LoopId idx = env.topSortAt(at); + const LoopId ldx = at == 0 ? kInvalidId : env.topSortAt(at - 1); + const LatSetId lts = + env.merger().optimizeSet(env.merger().buildLattices(exp, idx)); // Start a loop sequence. 
bool needsUniv = startLoopSeq(env, rewriter, exp, at, idx, ldx, lts); @@ -1463,7 +1524,7 @@ unsigned lsize = env.set(lts).size(); for (unsigned i = 0; i < lsize; i++) { // Start a loop. - unsigned li = env.set(lts)[i]; + const LatPointId li = env.set(lts)[i]; Operation *loop = startLoop(env, rewriter, at, li, needsUniv); // Visit all lattices points with Li >= Lj to generate the @@ -1473,8 +1534,8 @@ Value insInput = env.getInsertionChain(); bool isWhile = dyn_cast(loop) != nullptr; for (unsigned j = 0; j < lsize; j++) { - unsigned lj = env.set(lts)[j]; - unsigned ej = env.lat(lj).exp; + const LatPointId lj = env.set(lts)[j]; + const ExprId ej = env.lat(lj).exp; if (li == lj || env.merger().latGT(li, lj)) { // Recurse into body of each branch. if (isWhile) { @@ -1539,12 +1600,12 @@ return failure(); // Sets up a code generation environment. - unsigned numTensors = op->getNumOperands(); - unsigned numLoops = op.getNumLoops(); - unsigned numFilterLoops = getNumCompoundAffineOnSparseDims(op); + const unsigned numTensors = op->getNumOperands(); + const unsigned numLoops = op.getNumLoops(); + const unsigned numFilterLoops = getNumCompoundAffineOnSparseLvls(op); CodegenEnv env(op, options, numTensors, numLoops, numFilterLoops); - // Detects sparse annotations and translates the per-dimension sparsity + // Detects sparse annotations and translates the per-level sparsity // information for all tensors to loop indices in the kernel. if (!findSparseAnnotations(env)) return failure(); @@ -1566,11 +1627,11 @@ // computation. Must be ordered from more strict to less strict. // Ideally (though might not be guaranteed), the eariler a constraint mask // can be satisfied, the faster the generated kernel will be. - const auto allMask = { + const auto allMasks = { SortMask::kIncludeAll, SortMask::kIncludeDense, SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput, SortMask::kIncludeUndef, SortMask::kSparseOnly}; - for (auto mask : allMask) { + for (const SortMask mask : allMasks) { if (computeIterationGraph(env, mask)) { hasCycle = false; if (env.isAdmissibleTopoOrder()) { @@ -1601,7 +1662,7 @@ // sparse input tensor in succession until an acylic // iteration graph results. for (OpOperand *t : env.op().getDpsInputOperands()) { - unsigned tensor = t->getOperandNumber(); + const TensorId tid = t->getOperandNumber(); Value tval = t->get(); auto srcEnc = getSparseTensorEncoding(tval.getType()); if (!srcEnc || !computeIterationGraph(env, SortMask::kSparseOnly, t)) @@ -1623,7 +1684,7 @@ srcTp.getElementType(), dstEnc); auto convert = rewriter.create(tval.getLoc(), dstTp, tval); rewriter.updateRootInPlace( - env.op(), [&]() { env.op()->setOperand(tensor, convert); }); + env.op(), [&]() { env.op()->setOperand(tid, convert); }); rewriter.setInsertionPointAfter(env.op()); rewriter.create(tval.getLoc(), convert); return success(); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -30,7 +30,7 @@ // Leaf. case kTensor: case kInvariant: - case kIndex: + case kLoopVar: return ExpArity::kNullary; case kAbsF: case kAbsC: @@ -98,20 +98,20 @@ // Constructors. //===----------------------------------------------------------------------===// -TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o) +TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o) : kind(k), val(v), op(o) { switch (kind) { // Leaf. 
case kTensor: - assert(x != -1u && y == -1u && !v && !o); + assert(x != kInvalidId && y == kInvalidId && !v && !o); tensor = x; break; case kInvariant: - assert(x == -1u && y == -1u && v && !o); + assert(x == kInvalidId && y == kInvalidId && v && !o); break; - case kIndex: - assert(x != -1u && y == -1u && !v && !o); - index = x; + case kLoopVar: + assert(x != kInvalidId && y == kInvalidId && !v && !o); + loop = x; break; // Unary operations. case kAbsF: @@ -134,7 +134,7 @@ case kNegI: case kCIm: case kCRe: - assert(x != -1u && y == -1u && !v && !o); + assert(x != kInvalidId && y == kInvalidId && !v && !o); children.e0 = x; children.e1 = y; break; @@ -149,20 +149,20 @@ case kCastIdx: case kTruncI: case kBitCast: - assert(x != -1u && y == -1u && v && !o); + assert(x != kInvalidId && y == kInvalidId && v && !o); children.e0 = x; children.e1 = y; break; case kBinaryBranch: case kSelect: - assert(x != -1u && y == -1u && !v && o); + assert(x != kInvalidId && y == kInvalidId && !v && o); children.e0 = x; children.e1 = y; break; case kUnary: // No assertion on y can be made, as the branching paths involve both // a unary (mapSet) and binary (takeDisj) pathway. - assert(x != -1u && !v && o); + assert(x != kInvalidId && !v && o); children.e0 = x; children.e1 = y; break; @@ -186,82 +186,89 @@ case kShrS: case kShrU: case kShlI: - assert(x != -1u && y != -1u && !v && !o); + assert(x != kInvalidId && y != kInvalidId && !v && !o); children.e0 = x; children.e1 = y; break; case kBinary: case kReduce: - assert(x != -1u && y != -1u && !v && o); + assert(x != kInvalidId && y != kInvalidId && !v && o); children.e0 = x; children.e1 = y; break; } } -LatPoint::LatPoint(unsigned n, unsigned e, unsigned b) - : bits(n, false), exp(e) { +LatPoint::LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {} + +LatPoint::LatPoint(unsigned numTensors, unsigned numLoops, TensorId t, LoopId i, + ExprId e) + : bits(numLoops * numTensors, false), exp(e) { + assert(t < numTensors && i < numLoops); + const TensorLoopId b = numTensors * i + t; bits.set(b); } -LatPoint::LatPoint(const BitVector &b, unsigned e) : bits(b), exp(e) {} - -Merger::Merger(unsigned t, unsigned l, unsigned fl) - : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), - numNativeLoops(l), numLoops(l + fl), hasSparseOut(false), - dimTypes(numTensors, +Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops, + unsigned numFilterLoops) + : outTensor(numInputOutputTensors - 1), + syntheticTensor(numInputOutputTensors), + numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops), + numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false), + lvlTypes(numTensors, std::vector(numLoops, DimLevelType::Undef)), - loopIdxToDim(numTensors, std::vector>( - numLoops, std::nullopt)), - dimToLoopIdx(numTensors, std::vector>( - numLoops, std::nullopt)) {} + loopToLvl(numTensors, + std::vector>(numLoops, std::nullopt)), + lvlToLoop(numTensors, + std::vector>(numLoops, std::nullopt)) {} //===----------------------------------------------------------------------===// // Lattice methods. 
//===----------------------------------------------------------------------===// -unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v, - Operation *op) { - unsigned e = tensorExps.size(); - tensorExps.push_back(TensorExp(k, e0, e1, v, op)); +ExprId Merger::addExp(Kind k, unsigned x, ExprId y, Value v, Operation *op) { + const ExprId e = tensorExps.size(); + assert((k != kTensor || x < numTensors) && (k != kLoopVar || x < numLoops)); + tensorExps.emplace_back(k, x, y, v, op); return e; } -unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) { +LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) { assert(t < numTensors && i < numLoops); - unsigned p = latPoints.size(); - latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t)); + const LatPointId p = latPoints.size(); + latPoints.emplace_back(numTensors, numLoops, t, i, e); return p; } -unsigned Merger::addSet() { - unsigned s = latSets.size(); +LatSetId Merger::addSet() { + const LatSetId s = latSets.size(); latSets.emplace_back(); return s; } -unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1, - Operation *op) { - unsigned p = latPoints.size(); - BitVector nb = BitVector(latPoints[p0].bits); - nb |= latPoints[p1].bits; - unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op); - latPoints.push_back(LatPoint(nb, e)); +LatPointId Merger::conjLatPoint(Kind kind, LatPointId p0, LatPointId p1, + Operation *op) { + const LatPointId p = latPoints.size(); + BitVector bits(latPoints[p0].bits); + bits |= latPoints[p1].bits; + const ExprId e = + addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op); + latPoints.emplace_back(bits, e); return p; } -unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) { - unsigned s = addSet(); - for (unsigned p0 : latSets[s0]) - for (unsigned p1 : latSets[s1]) +LatSetId Merger::takeConj(Kind kind, LatSetId s0, LatSetId s1, Operation *op) { + const LatSetId s = addSet(); + for (const LatPointId p0 : latSets[s0]) + for (const LatPointId p1 : latSets[s1]) latSets[s].push_back(conjLatPoint(kind, p0, p1, op)); return s; } -unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) { - unsigned s = takeConj(kind, s0, s1, op); +LatSetId Merger::takeDisj(Kind kind, LatSetId s0, LatSetId s1, Operation *op) { + const LatSetId s = takeConj(kind, s0, s1, op); // Followed by all in s0. - for (unsigned p : latSets[s0]) + for (const LatPointId p : latSets[s0]) latSets[s].push_back(p); // Map binary 0-y to unary -y. // TODO: move this if-else logic into buildLattices @@ -272,56 +279,56 @@ else if (kind == kSubI) s1 = mapSet(kNegI, s1); // Followed by all in s1. - for (unsigned p : latSets[s1]) + for (const LatPointId p : latSets[s1]) latSets[s].push_back(p); return s; } -unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig, +LatSetId Merger::takeCombi(Kind kind, LatSetId s0, LatSetId s1, Operation *orig, bool includeLeft, Kind ltrans, Operation *opleft, bool includeRight, Kind rtrans, Operation *opright) { - unsigned s = takeConj(kind, s0, s1, orig); + const LatSetId s = takeConj(kind, s0, s1, orig); // Left Region. if (includeLeft) { if (opleft) s0 = mapSet(ltrans, s0, Value(), opleft); - for (unsigned p : latSets[s0]) + for (const LatPointId p : latSets[s0]) latSets[s].push_back(p); } // Right Region. 
if (includeRight) { if (opright) s1 = mapSet(rtrans, s1, Value(), opright); - for (unsigned p : latSets[s1]) + for (const LatPointId p : latSets[s1]) latSets[s].push_back(p); } return s; } -unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) { +LatSetId Merger::mapSet(Kind kind, LatSetId s0, Value v, Operation *op) { assert(kAbsF <= kind && kind <= kSelect); - unsigned s = addSet(); - for (unsigned p : latSets[s0]) { - unsigned e = addExp(kind, latPoints[p].exp, v, op); - latPoints.push_back(LatPoint(latPoints[p].bits, e)); + const LatSetId s = addSet(); + for (const LatPointId p : latSets[s0]) { + const ExprId e = addExp(kind, latPoints[p].exp, v, op); + latPoints.emplace_back(latPoints[p].bits, e); latSets[s].push_back(latPoints.size() - 1); } return s; } -unsigned Merger::optimizeSet(unsigned s0) { - unsigned s = addSet(); +LatSetId Merger::optimizeSet(LatSetId s0) { + const LatSetId s = addSet(); assert(!latSets[s0].empty()); - unsigned p0 = latSets[s0][0]; - for (unsigned p1 : latSets[s0]) { + const LatPointId p0 = latSets[s0][0]; + for (const LatPointId p1 : latSets[s0]) { bool add = true; if (p0 != p1) { - // Is this a straightforward copy? - unsigned e = latPoints[p1].exp; + // Check whether this is a straightforward copy. + const ExprId e = latPoints[p1].exp; if (expIsTensor(e, outTensor)) continue; - // Conjunction already covered? - for (unsigned p2 : latSets[s]) { + // Check whether this conjunction is already covered. + for (const LatPointId p2 : latSets[s]) { assert(!latGT(p1, p2)); // Lj => Li would be bad if (onlyDenseDiff(p2, p1)) { add = false; @@ -333,30 +340,30 @@ if (add) latSets[s].push_back(p1); } - for (unsigned p : latSets[s]) + for (const LatPointId p : latSets[s]) latPoints[p].simple = simplifyCond(s, p); return s; } -BitVector Merger::simplifyCond(unsigned s0, unsigned p0) { +BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) { // First determine if this lattice point is a *singleton*, i.e., // the last point in a lattice, no other is less than this one. bool isSingleton = true; - for (unsigned p1 : latSets[s0]) { + for (const LatPointId p1 : latSets[s0]) { if (p0 != p1 && latGT(p0, p1)) { isSingleton = false; break; } } - BitVector simple = latPoints[p0].bits; + BitVector simple(latPoints[p0].bits); bool reset = isSingleton && hasAnySparse(simple); - unsigned be = simple.size(); - unsigned offset = 0; // relative to the end + const TensorLoopId be = simple.size(); + TensorLoopId offset = 0; // relative to the end if (!reset) - // Starts resetting from a dense dimension, so that the first bit (if kept) - // is not undefined dimension type. - for (unsigned b = 0; b < be; b++) { + // Starts resetting from a dense level, so that the first bit (if kept) + // is not undefined level-type. + for (TensorLoopId b = 0; b < be; b++) { if (simple[b] && isDenseDLT(getDimLevelType(b))) { offset = be - b - 1; // relative to the end break; @@ -365,24 +372,26 @@ // Now apply the two basic rules. We also iterate the bits reversely to always // keep the rightmost bit (which could possibly be a synthetic tensor). - for (unsigned b = be - 1 - offset, i = 0; i < be; + for (TensorLoopId b = be - 1 - offset, i = 0; i < be; b = b == 0 ? 
be - 1 : b - 1, i++) { - if (simple[b] && (!isCompressedDLT(getDimLevelType(b)) && - !isSingletonDLT(getDimLevelType(b)))) { - if (reset) - simple.reset(b); - reset = true; + if (simple[b]) { + const auto dlt = getDimLevelType(b); + if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) { + if (reset) + simple.reset(b); + reset = true; + } } } return simple; } -bool Merger::latGT(unsigned i, unsigned j) const { +bool Merger::latGT(LatPointId i, LatPointId j) const { const BitVector &bitsi = latPoints[i].bits; const BitVector &bitsj = latPoints[j].bits; assert(bitsi.size() == bitsj.size()); if (bitsi.count() > bitsj.count()) { - for (unsigned b = 0, be = bitsj.size(); b < be; b++) + for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++) if (bitsj[b] && !bitsi[b]) return false; return true; @@ -390,13 +399,13 @@ return false; } -bool Merger::onlyDenseDiff(unsigned i, unsigned j) { - BitVector tmp = latPoints[j].bits; +bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const { + BitVector tmp(latPoints[j].bits); tmp ^= latPoints[i].bits; return !hasAnySparse(tmp); } -bool Merger::expContainsTensor(unsigned e, unsigned t) const { +bool Merger::expContainsTensor(ExprId e, TensorId t) const { if (tensorExps[e].kind == kTensor) return tensorExps[e].tensor == t; @@ -404,23 +413,23 @@ case ExpArity::kNullary: return false; case ExpArity::kUnary: { - unsigned op = tensorExps[e].children.e0; - if (expIsTensor(op, t)) + const ExprId e0 = tensorExps[e].children.e0; + if (expIsTensor(e0, t)) return true; - return expContainsTensor(op, t); + return expContainsTensor(e0, t); } case ExpArity::kBinary: { - unsigned op1 = tensorExps[e].children.e0; - unsigned op2 = tensorExps[e].children.e1; - if (expIsTensor(op1, t) || expIsTensor(op2, t)) + const ExprId e0 = tensorExps[e].children.e0; + const ExprId e1 = tensorExps[e].children.e1; + if (expIsTensor(e0, t) || expIsTensor(e1, t)) return true; - return expContainsTensor(op1, t) || expContainsTensor(op2, t); + return expContainsTensor(e0, t) || expContainsTensor(e1, t); } } llvm_unreachable("unexpected arity"); } -bool Merger::hasNegateOnOut(unsigned e) const { +bool Merger::hasNegateOnOut(ExprId e) const { switch (tensorExps[e].kind) { case kNegF: case kNegC: @@ -446,13 +455,14 @@ llvm_unreachable("unexpected kind"); } -bool Merger::isSingleCondition(unsigned t, unsigned e) const { +bool Merger::isSingleCondition(TensorId t, ExprId e) const { + assert(t < numTensors && e < tensorExps.size()); switch (tensorExps[e].kind) { // Leaf. case kTensor: return tensorExps[e].tensor == t; case kInvariant: - case kIndex: + case kLoopVar: return false; // Unary operations. case kAbsF: @@ -531,10 +541,12 @@ } bool Merger::hasAnySparse(const BitVector &bits) const { - for (unsigned b = 0, be = bits.size(); b < be; b++) - if (bits[b] && (isCompressedDLT(getDimLevelType(b)) || - isSingletonDLT(getDimLevelType(b)))) - return true; + for (TensorLoopId b = 0, be = bits.size(); b < be; b++) + if (bits[b]) { + const auto dlt = getDimLevelType(b); + if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) + return true; + } return false; } @@ -551,7 +563,7 @@ return "tensor"; case kInvariant: return "invariant"; - case kIndex: + case kLoopVar: return "index"; // Unary operations. case kAbsF: @@ -641,7 +653,7 @@ llvm_unreachable("unexpected kind for symbol"); } -void Merger::dumpExp(unsigned e) const { +void Merger::dumpExp(ExprId e) const { switch (tensorExps[e].kind) { // Leaf. 
case kTensor: @@ -654,8 +666,10 @@ case kInvariant: llvm::dbgs() << "invariant"; break; - case kIndex: - llvm::dbgs() << "index_" << tensorExps[e].index; + case kLoopVar: + // TODO(wrengr): should we change this to "loop_" or "loopvar_"? Will that + // break any tests? + llvm::dbgs() << "index_" << tensorExps[e].loop; break; // Unary operations. case kAbsF: @@ -725,7 +739,7 @@ } } -void Merger::dumpLat(unsigned p) const { +void Merger::dumpLat(LatPointId p) const { llvm::dbgs() << "lat("; dumpBits(latPoints[p].bits); llvm::dbgs() << " :"; @@ -735,9 +749,9 @@ llvm::dbgs() << " )\n"; } -void Merger::dumpSet(unsigned s) const { +void Merger::dumpSet(LatSetId s) const { llvm::dbgs() << "{ #" << latSets[s].size() << "\n"; - for (unsigned p : latSets[s]) { + for (const LatPointId p : latSets[s]) { llvm::dbgs() << " "; dumpLat(p); } @@ -745,11 +759,11 @@ } void Merger::dumpBits(const BitVector &bits) const { - for (unsigned b = 0, be = bits.size(); b < be; b++) { + for (TensorLoopId b = 0, be = bits.size(); b < be; b++) { if (bits[b]) { - unsigned t = tensor(b); - unsigned i = index(b); - DimLevelType dlt = dimTypes[t][i]; + const TensorId t = tensor(b); + const LoopId i = loop(b); + const auto dlt = lvlTypes[t][i]; llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt); } } @@ -761,20 +775,20 @@ // Builder methods. //===----------------------------------------------------------------------===// -unsigned Merger::buildLattices(unsigned e, unsigned i) { - Kind kind = tensorExps[e].kind; +LatSetId Merger::buildLattices(ExprId e, LoopId i) { + const Kind kind = tensorExps[e].kind; switch (kind) { // Leaf. case kTensor: case kInvariant: - case kIndex: { - // Either the index is really used in the tensor expression, or it is - // set to the undefined index in that dimension. An invariant expression, + case kLoopVar: { + // Either the loop-var is really used in the tensor expression, or it is + // set to the undefined loop-var in that level. An invariant expression, // a proper index value, and a truly dynamic sparse output tensor are set // to a synthetic tensor with undefined indices only to ensure the // iteration space is not skipped as a result of their contents. - unsigned s = addSet(); - unsigned t = syntheticTensor; + const LatSetId s = addSet(); + TensorId t = syntheticTensor; if (kind == kTensor) { t = tensorExps[e].tensor; if (hasSparseOut && t == outTensor) @@ -836,7 +850,7 @@ // ----+----------+------------+ // | absent() | present(y) | { - unsigned child0 = buildLattices(tensorExps[e].children.e0, i); + const LatSetId child0 = buildLattices(tensorExps[e].children.e0, i); UnaryOp unop = cast(tensorExps[e].op); Region &absentRegion = unop.getAbsentRegion(); @@ -848,7 +862,7 @@ Block &absentBlock = absentRegion.front(); YieldOp absentYield = cast(absentBlock.getTerminator()); Value absentVal = absentYield.getResult(); - unsigned rhs = addExp(kInvariant, absentVal); + const ExprId rhs = addExp(kInvariant, absentVal); return takeDisj(kind, child0, buildLattices(rhs, i), unop); } // Binary operations. 
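The identifier type aliases used throughout these hunks are easiest to see in a small worked example. The sketch below is illustrative only: it assumes the `Merger` entry points shown here (`addExp`, `buildLattices`, `optimizeSet`) keep their usual default arguments, and the tensor/loop numbering is chosen purely for exposition. For a kernel such as `out(i) = a(i) * b(i)`:

  // Two inputs plus one output; one native loop, no filter loops.
  Merger merger(/*numInputOutputTensors=*/3, /*numNativeLoops=*/1,
                /*numFilterLoops=*/0);
  const TensorId tA = 0, tB = 1;                  // operand numbers of a, b
  const LoopId i = 0;                             // the single loop-variable
  const ExprId eA = merger.addExp(kTensor, tA);   // leaf expression for a
  const ExprId eB = merger.addExp(kTensor, tB);   // leaf expression for b
  const ExprId eMul = merger.addExp(kMulF, eA, eB);
  // buildLattices conjoins the two leaf sets into one lattice point whose
  // bits cover (tA, i) and (tB, i); the point's expression is a fresh kMulF
  // node rebuilt by conjLatPoint. optimizeSet then computes the simplified
  // condition bits used for loop emission.
  const LatSetId lts = merger.optimizeSet(merger.buildLattices(eMul, i));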
@@ -925,8 +939,8 @@ // !x | empty | right(y) | // x | left(x) | overlap(x,y) | { - unsigned child0 = buildLattices(tensorExps[e].children.e0, i); - unsigned child1 = buildLattices(tensorExps[e].children.e1, i); + const LatSetId child0 = buildLattices(tensorExps[e].children.e0, i); + const LatSetId child1 = buildLattices(tensorExps[e].children.e1, i); BinaryOp binop = cast(tensorExps[e].op); Region &leftRegion = binop.getLeftRegion(); Region &rightRegion = binop.getRightRegion(); @@ -957,7 +971,7 @@ llvm_unreachable("unexpected expression kind"); } -std::optional Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { +std::optional Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { // Build the linalg semantics backward from yield. Operation *yield = op.getRegion().front().getTerminator(); assert(isa(yield)); @@ -965,7 +979,7 @@ } /// Only returns false if we are certain this is a nonzero. -bool Merger::maybeZero(unsigned e) const { +bool Merger::maybeZero(ExprId e) const { if (tensorExps[e].kind == kInvariant) { if (auto c = tensorExps[e].val.getDefiningOp()) { ArrayAttr arrayAttr = c.getValue(); @@ -980,11 +994,11 @@ return true; } -bool Merger::isInvariant(unsigned e) const { +bool Merger::isInvariant(ExprId e) const { return tensorExps[e].kind == kInvariant; } -Type Merger::inferType(unsigned e, Value src) { +Type Merger::inferType(ExprId e, Value src) const { // Obtain the destination type from the cast node. Type dtp = tensorExps[e].val.getType(); // Inspect source type. For vector types, apply the same @@ -997,7 +1011,7 @@ /// Ensures that sparse compiler can generate code for expression. static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) { // Arguments are always admissible. - if (auto arg = v.dyn_cast()) + if (v.isa()) return true; // Accept index anywhere. Operation *def = v.getDefiningOp(); @@ -1024,9 +1038,9 @@ return isAdmissibleBranchExp(op, ®ion.front(), yield->getOperand(0)); } -std::optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { +std::optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { if (auto arg = v.dyn_cast()) { - unsigned argN = arg.getArgNumber(); + const TensorId argN = arg.getArgNumber(); // Any argument of the generic op that is not marked as a scalar // argument is considered a tensor, indexed by the implicit loop // bounds. This includes rank-0 tensor arguments. @@ -1047,13 +1061,13 @@ // Construct index operations. if (def->getNumOperands() == 0) { if (auto indexOp = dyn_cast(def)) - return addExp(kIndex, indexOp.getDim()); + return addExp(kLoopVar, indexOp.getDim()); } // Construct unary operations if subexpression can be built. if (def->getNumOperands() == 1) { - auto x = buildTensorExp(op, def->getOperand(0)); + const auto x = buildTensorExp(op, def->getOperand(0)); if (x.has_value()) { - unsigned e = *x; + const ExprId e = *x; if (isa(def)) return addExp(kAbsF, e); if (isa(def)) @@ -1129,11 +1143,11 @@ // See buildLattices() for an explanation of rejecting certain // division and shift operations. 
if (def->getNumOperands() == 2) { - auto x = buildTensorExp(op, def->getOperand(0)); - auto y = buildTensorExp(op, def->getOperand(1)); + const auto x = buildTensorExp(op, def->getOperand(0)); + const auto y = buildTensorExp(op, def->getOperand(1)); if (x.has_value() && y.has_value()) { - unsigned e0 = *x; - unsigned e1 = *y; + const ExprId e0 = *x; + const ExprId e1 = *y; if (isa(def)) return addExp(kMulF, e0, e1); if (isa(def)) @@ -1184,12 +1198,12 @@ } // Construct ternary operations if subexpressions can be built. if (def->getNumOperands() == 3) { - auto x = buildTensorExp(op, def->getOperand(0)); - auto y = buildTensorExp(op, def->getOperand(1)); - auto z = buildTensorExp(op, def->getOperand(2)); + const auto x = buildTensorExp(op, def->getOperand(0)); + const auto y = buildTensorExp(op, def->getOperand(1)); + const auto z = buildTensorExp(op, def->getOperand(2)); if (x.has_value() && y.has_value() && z.has_value()) { - unsigned e0 = *x; - unsigned e1 = *y; + const ExprId e0 = *x; + const ExprId e1 = *y; if (auto redop = dyn_cast(def)) { if (isAdmissibleBranch(redop, redop.getRegion())) return addExp(kReduce, e0, e1, Value(), def); @@ -1245,13 +1259,13 @@ return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1}); } -Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e, - Value v0, Value v1) { +Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, + Value v1) { switch (tensorExps[e].kind) { // Leaf. case kTensor: case kInvariant: - case kIndex: + case kLoopVar: llvm_unreachable("unexpected non-op"); // Unary operations. case kAbsF: