diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -13,6 +13,8 @@ #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_ #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_ +#include "mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h" + #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" @@ -23,65 +25,6 @@ namespace mlir { namespace sparse_tensor { -// TODO: These type aliases currently only serve to make the code more -// self-documenting, however because they are not type-checked they can -// do nothing to prevent mixups. We should really change them from mere -// aliases to actual struct definitions, so that we can type-check them. - -/// Tensor identifiers. The valid set of identifiers is defined by the -/// first argument passed to the `Merger` ctor. -using TensorId = unsigned; - -/// Loop identifiers. The valid set of identifiers is defined by the -/// second two arguments to the `Merger` ctor. -/// -/// These identifiers serve as proxies for the `$dim` argument to -/// `linalg::IndexOp`, however the numerical value of a `LoopId` should -/// not necessarily be equated with the numerical value of the corresponding -/// `$dim` argument. The `$dim` arguments are De Bruijn indices: that -/// is, they identify the loop which binds the loop-variable by counting -/// the enclosing loops from innermost to outermost, starting from zero. -/// Whereas `LoopId` are considered to be arbitrary names for identifying -/// loops; since the `Merger` does not care about the actual ordering of -/// loops, and leaves it up to the `LoopEmitter` to specify the actual -/// loop ordering (`LoopOrd`). 
-/// -/// TODO: Despite the above claim that `$dim` and `LoopId` need not be -/// numerically equal, some code in the `Merger` class does equate them -/// (e.g., `buildTensorExp`). So we need to explicate the exact relationship -/// between `$dim`, `LoopId`, and `LoopOrd`; especially with regards to their -/// providence. If `LoopId` really is supposed to be equated with `$dim`, -/// then we should change the name to `LoopIdx` or similar, to capture the -/// fact that its numerical value is not invariant when entering/exiting -/// loops (unlike `TensorId`, `ExprId`, `LatPointId`, and `LatSetId` which -/// are invariant identifiers). -using LoopId = unsigned; - -/// A compressed representation of `std::pair`. -/// The compression scheme is such that this also serves as an index -/// into the bitvector stored in `LatPoint` (since that bitvector is -/// just the implementation for a set of `TensorLoopId` values). -using TensorLoopId = unsigned; - -/// `TensorExp` identifiers. These are allocated by `Merger::addExp`, -/// and serve as unique identifiers for the corresponding `TensorExp` object. -using ExprId = unsigned; - -/// `LatPoint` identifiers. These are allocated by `Merger::addLat`, -/// and serve as unique identifiers for the corresponding `LatPoint` object. -using LatPointId = unsigned; - -/// `LatSet` identifiers. These are allocated by `Merger::addSet` (and -/// by other methods calling that one), and serve as unique identifiers -/// for the corresponding `SmallVector` object. -using LatSetId = unsigned; - -namespace detail { -/// A constant serving as the canonically invalid identifier, regardless -/// of the identifier type. -static constexpr unsigned kInvalidId = -1u; -} - /// Tensor expression. Represents an MLIR expression in tensor index notation. struct TensorExp final { enum class Kind; @@ -207,18 +150,17 @@ kReduce, // semiring reduction op }; +//===----------------------------------------------------------------------===// /// Lattice point. 
Each lattice point consists of a formal conjunction /// of `TensorLoopId`s, together with the identifier of the corresponding /// tensor expression. The formal conjunction is represented as a set of /// `TensorLoopId`, where that set is implemented as a `BitVector`. struct LatPoint final { - /// Construct the lattice point from a given set of `TensorLoopId`s. - LatPoint(const BitVector &bits, ExprId e); + /// Construct a lattice point with the empty set of `TensorLoopId`s. + LatPoint(unsigned size, ExprId e) : bits(size, false), exp(e) {} - /// Construct a lattice point with `(t,i)` as the only `TensorLoopId`, - /// where `(t,i) < (numTensors,numLoops)`. - LatPoint(unsigned numTensors, unsigned numLoops, TensorId t, LoopId i, - ExprId e); + /// Construct a lattice point from the given set of `TensorLoopId`s. + LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {} /// Conjunction of all `TensorLoopId`s involved in the tensor expression. BitVector bits; @@ -232,6 +174,7 @@ ExprId exp; }; +//===----------------------------------------------------------------------===// /// A class to handle all iteration lattice operations. This class abstracts /// away from some implementation details of storing iteration lattices and /// tensor expressions. This allows for fine-tuning performance characteristics @@ -268,17 +211,45 @@ Merger(unsigned numInputOutputTensors, unsigned numNativeLoops, unsigned numFilterLoops); + // + // Constructing valid tensor and loop identifiers. + // + + /// Safely converts the argument to a tensor identifier. + constexpr TensorId makeTensorId(unsigned t) const { + assert(isValidTensorId(t)); + return t; + } + + /// Safely converts the argument to a loop identifier. + constexpr LoopId makeLoopId(unsigned i) const { + assert(isValidLoopId(i)); + return i; + } + + /// Safely converts the arguments to a pair of (tensor,loop) identifiers. 
+ constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const { + assert(isValidTensorId(t) && isValidLoopId(i)); + return numTensors * i + t; + } + + // + // Allocating new expressions, points, and sets. + // + /// Constructs a new tensor expression, and returns its identifier. - /// The type of the `e0` argument varies according to the value of the - /// `k` argument, as described by the `TensorExp` ctor. - ExprId addExp(TensorExp::Kind k, unsigned e0, ExprId e1 = detail::kInvalidId, - Value v = Value(), Operation *op = nullptr); - ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr) { - return addExp(k, e, detail::kInvalidId, v, op); - } - ExprId addExp(TensorExp::Kind k, Value v, Operation *op = nullptr) { - return addExp(k, detail::kInvalidId, detail::kInvalidId, v, op); - } + ExprId addTensorExp(TensorId t); + /// Constructs a new loop-variable expression, and returns its identifier. + ExprId addLoopVarExp(LoopId i); + /// Constructs a new invariant expression, and returns its identifier. + ExprId addInvariantExp(Value v); + /// Constructs a new unary or binary expression, and returns its identifier. + ExprId addExp(TensorExp::Kind k, ExprId e0, ExprId e1 = detail::kInvalidId, + Operation *op = nullptr); + /// Constructs a new sesquinary expression, and returns its identifier. + /// Currently no sesquinary `Kind` allows specifying the `op`, but we + /// allow it anyways because `mapSet` is designed to allow it. + ExprId addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op = nullptr); /// Constructs a new iteration lattice point, and returns its identifier. LatPointId addLat(TensorId t, LoopId i, ExprId e); @@ -336,51 +307,43 @@ bool onlyDenseDiff(LatPointId p0, LatPointId p1) const; /// Gets the tensor-identifier of the `TensorLoopId`. 
- TensorId tensor(TensorLoopId b) const { return b % numTensors; } + constexpr TensorId tensor(TensorLoopId b) const { return b % numTensors; } /// Gets the loop-identifier of the `TensorLoopId`. - LoopId loop(TensorLoopId b) const { return b / numTensors; } + constexpr LoopId loop(TensorLoopId b) const { return b / numTensors; } /// Get the total number of tensors (including the output-tensor and - /// synthetic-tensor). The result is given the type `TensorId` since - /// the result is primarily used as an upper bound for `TensorId`s. - TensorId getNumTensors() const { return numTensors; } + /// synthetic-tensor). + constexpr unsigned getNumTensors() const { return numTensors; } /// Get the total number of loops (native loops + filter loops). - /// The result is given the type `LoopId` since the result will - /// generally be used as a for-loop upper bound. - LoopId getNumLoops() const { return numLoops; } - /// Get the number of native loops. The result is given the type - /// `LoopId` since the result will generally be used as a for-loop - /// upper bound. - LoopId getNumNativeLoops() const { return numNativeLoops; } - /// Get the number of filter loops. The result is given the type - /// `LoopId` since the result will generally be used as a for-loop - /// upper bound. - LoopId getNumFilterLoops() const { return numLoops - numNativeLoops; } + constexpr unsigned getNumLoops() const { return numLoops; } + /// Get the number of native loops. + constexpr unsigned getNumNativeLoops() const { return numNativeLoops; } + /// Get the number of filter loops. + constexpr unsigned getNumFilterLoops() const { return numLoops - numNativeLoops; } /// Get the identifier of the first filter-loop. - LoopId getStartingFilterLoopId() const { return getNumNativeLoops(); } + constexpr LoopId getStartingFilterLoopId() const { return getNumNativeLoops(); } /// Returns true if `b` is the `i`th loop of the output tensor. 
- bool isOutTensor(TensorLoopId b, LoopId i) const { - assert(i < numLoops); - return b == numTensors * i + outTensor; + constexpr bool isOutTensor(TensorLoopId b, LoopId i) const { + return b == makeTensorLoopId(outTensor, i); } /// Get the output tensor's identifier. - TensorId getOutTensorID() const { return outTensor; } + constexpr TensorId getOutTensorID() const { return outTensor; } /// Get the synthetic tensor's identifier (used for all invariant /// tensor expressions). - TensorId getSynTensorID() const { return syntheticTensor; } + constexpr TensorId getSynTensorID() const { return syntheticTensor; } - bool isFilterLoop(LoopId i) const { - assert(i < numLoops); + constexpr bool isFilterLoop(LoopId i) const { + assert(isValidLoopId(i)); return i >= numNativeLoops; } /// Returns true if the expression is `(kTensor t)`. bool expIsTensor(ExprId e, TensorId t) const { - return tensorExps[e].kind == TensorExp::Kind::kTensor && - tensorExps[e].tensor == t; + const auto &expr = exp(e); + return expr.kind == TensorExp::Kind::kTensor && expr.tensor == t; } /// Returns true if the expression contains the tensor as an operand. @@ -408,7 +371,7 @@ /// Gets the level-type of the `t`th tensor on `i`th loop. DimLevelType getDimLevelType(TensorId t, LoopId i) const { - assert(t < numTensors && i < numLoops); + assert(isValidTensorId(t) && isValidLoopId(i)); return lvlTypes[t][i]; } @@ -419,13 +382,13 @@ /// Gets the loop identifier for the `lvl`th level of the `t`th tensor. std::optional getLoopId(TensorId t, Level lvl) const { - assert(t < numTensors && lvl < lvlToLoop[t].size()); + assert(isValidLevel(t, lvl)); return lvlToLoop[t][lvl]; } /// Gets the level number of the the `t`th tensor on `i`th loop. 
std::optional getLvl(TensorId t, LoopId i) const { - assert(t < numTensors && i < numLoops); + assert(isValidTensorId(t) && isValidLoopId(i)); return loopToLvl[t][i]; } std::optional getLvl(TensorLoopId b) const { @@ -435,31 +398,42 @@ /// Sets the level number and level-type of the `t`th tensor on /// `i`th loop. void setLevelAndType(TensorId t, LoopId i, Level lvl, DimLevelType dlt) { - assert(t < numTensors && i < numLoops && lvl < lvlToLoop[t].size() && - isValidDLT(dlt)); + assert(isValidLevel(t, lvl) && isValidLoopId(i) && isValidDLT(dlt)); lvlTypes[t][i] = dlt; loopToLvl[t][i] = lvl; lvlToLoop[t][lvl] = i; } + using ForeachTensorLoopIdCallback = function_ref, DimLevelType, bool)>; + /// Iterates over a set of `TensorLoopId`s, invoking the callback /// for each `TensorLoopId` and passing it the corresponding tensor /// identifier, level, and level-type, following with a boolean value /// indicating whether it is a dependent index reduction loop condition. + void foreachTensorLoopId(LatPointId p, ForeachTensorLoopIdCallback callback) const { + // TODO: the default ought to be simple=true; but we'll need to make + // sure to update all the tests to make sure they do the right thing. + foreachTensorLoopId(p, /*simple=*/false, callback); + } void foreachTensorLoopId( - LatPointId p, function_ref, DimLevelType, bool)> - callback) { - for (const TensorLoopId b : latPoints[p].bits.set_bits()) { - TensorId t = tensor(b); + LatPointId p, + bool simple, + ForeachTensorLoopIdCallback callback) const { + const auto &point = lat(p); + const auto &bits = simple ? point.simple : point.bits; + for (const TensorLoopId b : bits.set_bits()) { + const TensorId t = tensor(b); + const auto optLvl = getLvl(b); + const auto lvlTp = getDimLevelType(b); if (isLvlWithNonTrivialIdxExp(b)) { // This must be an undefined level. - assert(!getLvl(b).has_value()); + assert(!optLvl.has_value()); // Slice the tid along the dependent level to iterate current loop. 
- callback(b, t, loopToDependencies[loop(b)][t], getDimLevelType(b), + callback(b, t, loopToDependencies[loop(b)][t], lvlTp, /*isIdxReduc=*/true); } else { - callback(b, t, getLvl(b), getDimLevelType(b), /*isIdxReduc=*/false); + callback(b, t, optLvl, lvlTp, /*isIdxReduc=*/false); } } } @@ -469,31 +443,37 @@ /// Establishes the two-way map that i <-> . void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl) { - assert(lvl < numLoops); + assert(isValidLoopId(i) && isValidLevel(t, lvl)); loopToDependencies[i][t] = lvl; levelToDependentIdx[t][lvl].push_back(i); } /// Whether the loop has dependent slice. - bool hasDependentLvl(LoopId i, TensorId tid) { - return loopToDependencies[i][tid].has_value(); + bool hasDependentLvl(LoopId i, TensorId t) { + assert(isValidTensorId(t) && isValidLoopId(i)); + return loopToDependencies[i][t].has_value(); } /// Returns the list of loop indices which appear in the non-trivial index /// expression on t_l, e.g., A[i+j] => {i, j} std::vector &getDependentLoops(TensorId t, Level lvl) { + assert(isValidLevel(t, lvl)); return levelToDependentIdx[t][lvl]; } /// Returns the defining [tid, lvl] for the loop. std::pair getLoopDefiningLvl(LoopId i) const { + assert(isValidLoopId(i)); return loopBounds[i]; } /// Checks whether the TensorLoopId represents a tensor level with /// non-trivial index expression on it. bool isLvlWithNonTrivialIdxExp(TensorLoopId b) const { - return loopToDependencies[loop(b)][tensor(b)].has_value(); + const TensorId t = tensor(b); + const LoopId i = loop(b); + assert(isValidTensorId(t) && isValidLoopId(i)); + return loopToDependencies[i][t].has_value(); } /// Convenience getters to immediately access the stored nodes. @@ -509,20 +489,28 @@ /// references, but also applies to the `ArrayRef`. In particular, /// using `for (LatPointId p : merger.set(s))` will run into the same /// dangling-reference problems if the loop body inserts new sets. 
- const TensorExp &exp(ExprId e) const { return tensorExps[e]; } - const LatPoint &lat(LatPointId p) const { return latPoints[p]; } - ArrayRef set(LatSetId s) const { return latSets[s]; } + const TensorExp &exp(ExprId e) const { + assert(isValidExprId(e)); + return tensorExps[e]; + } + const LatPoint &lat(LatPointId p) const { + assert(isValidLatPointId(p)); + return latPoints[p]; + } + ArrayRef set(LatSetId s) const { + assert(isValidLatSetId(s)); + return latSets[s]; + } /// Checks whether the given expression has an associated value. - bool hasExprValue(ExprId e) const { - return static_cast(tensorExps[e].val); - } + bool hasExprValue(ExprId e) const { return static_cast(exp(e).val); } /// Sets the expression to have the associated value. Asserts that /// the new value is defined, and that the expression does not already /// have a value. If you want to overwrite a previous associated value, /// use `updateExprValue` instead. void setExprValue(ExprId e, Value v) { + assert(isValidExprId(e)); assert(v && "Got an undefined value"); auto &val = tensorExps[e].val; assert(!val && "Expression already has an associated value"); @@ -534,6 +522,7 @@ /// If you don't want to check for a previous associated value first, /// then use `updateExprValue` instead. void clearExprValue(ExprId e) { + assert(isValidExprId(e)); auto &val = tensorExps[e].val; assert(val && "Expression does not have an associated value to clear"); val = Value(); @@ -550,7 +539,10 @@ // the semantics `{ clearExprValue(e); setExprValue(e, v); }` or // `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those // provide better invariants. - void updateExprValue(ExprId e, Value v) { tensorExps[e].val = v; } + void updateExprValue(ExprId e, Value v) { + assert(isValidExprId(e)); + tensorExps[e].val = v; + } #ifndef NDEBUG /// Print methods (for debugging). @@ -575,8 +567,28 @@ private: /// Private helpers. 
+ constexpr bool isValidTensorId(TensorId t) const { + return t < numTensors; + } + constexpr bool isValidLoopId(LoopId i) const { + return i != detail::kInvalidId && i < numLoops; + } + bool isValidLevel(TensorId t, Level lvl) const { + return isValidTensorId(t) && lvl < lvlToLoop[t].size(); + } + bool isValidExprId(ExprId e) const { + return e != detail::kInvalidId && e < tensorExps.size(); + } + bool isValidLatPointId(LatPointId p) const { + return p != detail::kInvalidId && p < latPoints.size(); + } + bool isValidLatSetId(LatSetId s) const { + return s != detail::kInvalidId && s < latSets.size(); + } bool maybeZero(ExprId e) const; - bool isInvariant(ExprId e) const; + bool isInvariant(ExprId e) const { + return exp(e).kind == TensorExp::Kind::kInvariant; + } Type inferType(ExprId e, Value src) const; /// Traverses the SSA tree (possibly a DAG) to build a tensor expression. diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h @@ -0,0 +1,97 @@ +//===- MergerNewtypes.h - Newtypes for the `Merger` class -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TODO: This header currently defines some typedefs to avoid confusion +// between several different things which are all represented as `unsigned`. +// Over the next few commits, these typedefs will be replaced with "newtypes" +// (i.e., data types which are zero-cost abstractions for wrapping some +// underlying type while ensuring that the compiler keeps the new type +// distinct from the old type), along with related classes for iterating +// over them, etc. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGERNEWTYPES_H_ +#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGERNEWTYPES_H_ + +#include +#include + +namespace mlir { +namespace sparse_tensor { + +namespace detail { +/// A constant serving as the canonically invalid identifier, +/// regardless of the identifier type. +static constexpr unsigned kInvalidId = -1u; +} // namespace detail + +//===----------------------------------------------------------------------===// +/// Tensor identifiers. +/// +/// Semantically, tensor identifiers could be chosen to be anything; +/// but operationally, they must be chosen such that the `Merger` +/// and `GenericOpSparsifier` agree. Therefore, the numeric values of +/// tensor identifiers are chosen to be the `BlockArgument::getArgNumber` +/// of the value passed to `Merger::buildTensorExp`, which ranges from +/// zero to `linalg::GenericOp::getNumOperands` for the op passed to +/// `GenericOpSparsifier::matchAndRewrite`. +using TensorId = unsigned; + +//===----------------------------------------------------------------------===// +/// Loop identifiers. +/// +/// These identifiers serve as proxies for the `$dim` argument to +/// `linalg::IndexOp`, however the numerical value of a `LoopId` should +/// not necessarily be equated with the numerical value of the corresponding +/// `$dim` argument. The `$dim` arguments are De Bruijn indices: that +/// is, they identify the loop which binds the loop-variable by counting +/// the enclosing loops from innermost to outermost, starting from zero. +/// Whereas `LoopId` are considered to be arbitrary names for identifying +/// loops; since the `Merger` does not care about the actual ordering of +/// loops, and leaves it up to the `LoopEmitter` to specify the actual +/// loop ordering (`LoopOrd`). 
+/// +/// TODO: Despite the above claim that `$dim` and `LoopId` need not be +/// numerically equal, some code in the `Merger` class does equate them +/// (e.g., `buildTensorExp`). So we need to explicate the exact relationship +/// between `$dim`, `LoopId`, and `LoopOrd`; especially with regard to their +/// provenance. If `LoopId` really is supposed to be equated with `$dim`, +/// then we should change the name to `LoopIdx` or similar, to capture the +/// fact that its numerical value is not invariant when entering/exiting +/// loops (unlike `TensorId`, `ExprId`, `LatPointId`, and `LatSetId` which +/// are invariant identifiers). +using LoopId = unsigned; + +//===----------------------------------------------------------------------===// +/// A compressed representation of `std::pair`. +/// The compression scheme is such that this also serves as an index +/// into the bitvector stored in `LatPoint` (since that bitvector is +/// just the implementation for a set of `TensorLoopId` values). +using TensorLoopId = unsigned; + +//===----------------------------------------------------------------------===// +/// `TensorExp` identifiers. These are allocated by the `Merger::add*Exp` +/// methods, and serve as unique identifiers for each `TensorExp` object. +using ExprId = unsigned; + +//===----------------------------------------------------------------------===// +/// `LatPoint` identifiers. These are allocated by `Merger::addLat`, +/// and serve as unique identifiers for the corresponding `LatPoint` object. +using LatPointId = unsigned; + +//===----------------------------------------------------------------------===// +/// `LatSet` identifiers. These are allocated by `Merger::addSet` (and +/// by other methods calling that one), and serve as unique identifiers +/// for the corresponding `SmallVector` object.
+using LatSetId = unsigned; + +} // namespace sparse_tensor +} // namespace mlir + +#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGERNEWTYPES_H_ diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -65,6 +65,15 @@ // Merger delegates. // + constexpr TensorId makeTensorId(unsigned t) const { + return latticeMerger.makeTensorId(t); + } + constexpr LoopId makeLoopId(unsigned i) const { + return latticeMerger.makeLoopId(i); + } + constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const { + return latticeMerger.makeTensorLoopId(t, i); + } const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); } const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); } ArrayRef set(LatSetId s) const { return latticeMerger.set(s); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -86,11 +86,13 @@ SmallVector tensors; // input tensors passed to loop emitter for (OpOperand &t : linalgOp->getOpOperands()) { tensors.push_back(t.get()); - Level rank = linalgOp.getMatchingIndexingMap(&t).getNumResults(); - for (Level lvl = 0; lvl < rank; lvl++) { - sortArrayBasedOnOrder( - latticeMerger.getDependentLoops(t.getOperandNumber(), lvl), topSort); - } + const TensorId tid = makeTensorId(t.getOperandNumber()); + const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults(); + const auto enc = getSparseTensorEncoding(t.get().getType()); + (void)enc; + assert(!enc || lvlRank == enc.getLvlRank()); + for (Level lvl = 0; lvl < lvlRank; lvl++) + sortArrayBasedOnOrder(latticeMerger.getDependentLoops(tid, lvl), topSort); } loopEmitter.initialize( @@ -162,10 +164,7 @@ 
} OpOperand *lhs = linalgOp.getDpsInitOperand(0); - // That the operand number is a valid `TensorId` will be verified - // by the call to `isSingleCondition` below; though we may want to add - // assertions to check it here, in order to give better error messages. - const TensorId tensor = lhs->getOperandNumber(); + const TensorId tensor = makeTensorId(lhs->getOperandNumber()); // An non-annotated output tensor is assumed dense, and becomes a random // access n-dim memref. Admissible since insertions cannot occur. if (getSparseTensorType(lhs->get()).isAllDense()) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h @@ -255,10 +255,10 @@ Location loc, Value crd, TensorId tid, Level lvl); - TensorId getNumTensors() const { return tensors.size(); } + unsigned getNumTensors() const { return tensors.size(); } bool isOutputTensor(TensorId tid) const { - return hasOutput && tid == static_cast(getNumTensors() - 1); + return hasOutput && tid == getNumTensors() - 1; } bool isSparseOutput(TensorId tid) const { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -221,7 +221,7 @@ this->hasOutput = hasOutput; this->isSparseOut = isSparseOut; - const TensorId numTensors = ts.size(); + const unsigned numTensors = ts.size(); this->tensors.assign(ts.begin(), ts.end()); this->lvlTypes.assign(numTensors, std::vector()); this->lvlSizes.assign(numTensors, std::vector()); @@ -414,8 +414,9 @@ // level-expression, the `getPosition` must in fact be a `Dimension`. // However, elsewhere we have been lead to expect that `loopIdToOrd` // should be indexed by `LoopId`... 
- const LoopId i = a.cast().getPosition(); - return loopStack[loopIdToOrd[i]].iv; + const auto loopId = a.cast().getPosition(); + assert(loopId < loopIdToOrd.size()); + return loopStack[loopIdToOrd[loopId]].iv; } case AffineExprKind::Add: { auto binOp = a.cast(); @@ -515,8 +516,8 @@ assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); // For COO, the position is the same across consecutive levels. /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. - llvm::for_each(reassoc, - [this, tid, iv](Level srcLvl) { posits[tid][srcLvl] = iv; }); + llvm::for_each( + reassoc, [this, tid, iv](Level srcLvl) { posits[tid][srcLvl] = iv; }); crd = genSparseCrd(builder, loc, tid, dstLvl); } else { // Dense tensor, the coordinate is the inducation variable. @@ -686,7 +687,7 @@ Value cond; unsigned o = 0; for (auto [t, lvl] : llvm::zip(tids, lvls)) { - unsigned tid = t; // Why `t` can not be captured by lambda? + const TensorId tid = t; // Why `t` can not be captured by lambda? const auto lvlTp = lvlTypes[tid][lvl]; if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) { const auto reassoc = getCollapseReassociation(tid, lvl); @@ -825,9 +826,9 @@ { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(ifNewSegHi.thenBlock()); - builder.create(loc, - genSegmentHigh(builder, loc, tid, srcLvl, - pos, highs[tid][srcLvl])); + builder.create( + loc, genSegmentHigh(builder, loc, tid, srcLvl, pos, + highs[tid][srcLvl])); // Else, resues the same segment high. builder.setInsertionPointToStart(ifNewSegHi.elseBlock()); builder.create(loc, oldSegHi); @@ -890,10 +891,11 @@ // guarantee that segHi is defined: because we only generate segHi // whenever coiterating, in order to improve code quality for the // non-coiterating cases. - const auto theSegHi = segHi[tid][srcLvl - 1]; - highs[tid][srcLvl] = (!isUniqueDLT(lvlTypes[tid][srcLvl - 1]) && theSegHi) - ? 
theSegHi - : builder.create(loc, pLo, c1); + const auto parentSegHi = segHi[tid][srcLvl - 1]; + highs[tid][srcLvl] = + (!isUniqueDLT(lvlTypes[tid][srcLvl - 1]) && parentSegHi) + ? parentSegHi + : builder.create(loc, pLo, c1); return; } } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -943,11 +943,12 @@ // one for loop? // FIXME(wrengr): what is this "ld" supposed to be really? const Level ld = op.getOrder() ? op.getOrder()->getDimPosition(l) : l; - loopEmitter.enterNewLoopSeq(rewriter, loc, 0, ld); + const SmallVector tids{0}; + loopEmitter.enterNewLoopSeq(rewriter, loc, tids, ld); // Note that reduc will be taken care of by loop emitter and get updated // in place. - loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, 0, l, reduc); + loopEmitter.enterLoopOverTensorAtLvl(rewriter, loc, tids, l, reduc); } SmallVector lcvs; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -225,7 +225,7 @@ bool setLvlFormat = true) { switch (a.getKind()) { case AffineExprKind::DimId: { - const LoopId idx = a.cast().getPosition(); + const LoopId idx = merger.makeLoopId(a.cast().getPosition()); if (!isUndefDLT(merger.getDimLevelType(tid, idx))) return false; // used more than once @@ -239,7 +239,8 @@ if (!isDenseDLT(dlt) && setLvlFormat) { assert(isUndefDLT(merger.getDimLevelType(tid, filterLdx))); // Use a filter loop for sparse affine expression. 
- merger.setLevelAndType(tid, filterLdx++, lvl, dlt); + merger.setLevelAndType(tid, filterLdx, lvl, dlt); + ++filterLdx; } if (auto binOp = a.dyn_cast()) { @@ -279,7 +280,7 @@ bool isSubExp = false) { switch (a.getKind()) { case AffineExprKind::DimId: { - LoopId ldx = a.cast().getPosition(); + const LoopId ldx = merger.makeLoopId(a.cast().getPosition()); if (!isUndefDLT(merger.getDimLevelType(tensor, ldx))) return false; // used more than once, e.g., A[i][i] @@ -408,6 +409,7 @@ // `filterLdx` may be mutated by `findAffine`. LoopId filterLdx = env.merger().getStartingFilterLoopId(); for (OpOperand &t : env.op()->getOpOperands()) { + const TensorId tid = env.makeTensorId(t.getOperandNumber()); const auto map = env.op().getMatchingIndexingMap(&t); const auto enc = getSparseTensorEncoding(t.get().getType()); if (enc) @@ -426,9 +428,9 @@ // If then current tensor being inspected requires affine index, it need // to be sliced. for (Level l = 0; l < lvlRank; l++) { - const TensorId tid = t.getOperandNumber(); - AffineExpr a = map.getResult(toOrigDim(enc, l)); - DimLevelType dlt = enc.getLvlType(l); + // FIXME: `toOrigDim` is deprecated. + const AffineExpr a = map.getResult(toOrigDim(enc, l)); + const DimLevelType dlt = enc.getLvlType(l); if (idxReducBased && needIdxReduc) { if (!findDepIdxSet(env.merger(), tid, l, a, dlt)) return false; // inadmissible affine expression @@ -445,19 +447,19 @@ /// A helper to compute a topological sort. O(n^2) time complexity /// as we use adj matrix for the graph. /// The sorted result will put the first Reduction iterator to the -/// latest possible index. -/// FIXME(wrengr): correct the above "index" +/// latest possible `LoopOrd`. /// /// The `inDegree` is indexed by `LoopId`, and the `adjM` is indexed by /// `(LoopId,LoopId)`. 
-static bool topSortOptimal(CodegenEnv &env, LoopId n, +static bool topSortOptimal(CodegenEnv &env, ArrayRef iteratorTypes, std::vector &inDegree, std::vector> &adjM) { std::vector redIt; // reduce iterator with 0 degree std::vector parIt; // parallel iterator with 0 degree std::vector filterIt; // filter loop with 0 degree - for (LoopId i = 0; i < n; i++) { + const LoopId numLoops = env.merger().getNumLoops(); + for (LoopId i = 0; i < numLoops; i++) { if (inDegree[i] == 0) { if (env.merger().isFilterLoop(i)) filterIt.push_back(i); @@ -493,7 +495,7 @@ env.topSortPushBack(src); it.pop_back(); // Update in-degree, and push 0-degree node into worklist. - for (LoopId dst = 0; dst < n; dst++) { + for (LoopId dst = 0; dst < numLoops; dst++) { if (adjM[src][dst] && --inDegree[dst] == 0) { if (env.merger().isFilterLoop(dst)) filterIt.push_back(dst); @@ -504,7 +506,7 @@ } } } - return env.topSortSize() == n; + return env.topSortSize() == numLoops; } /// Helper method to add all constraints from the indices in one affine @@ -535,7 +537,7 @@ const auto toExpand = a ? a : b; switch (toExpand.getKind()) { case AffineExprKind::DimId: { - std::optional idx = toExpand.cast().getPosition(); + const std::optional idx{toExpand.cast().getPosition()}; if (toExpand == a) addAffineOrderings(adjM, inDegree, AffineExpr(), b, idx, tidx); else // toExpand == b @@ -597,22 +599,25 @@ OpOperand *skip, SortMask mask, std::vector> &adjM, std::vector &inDegree) { - // Get map and encoding. - auto map = env.op().getMatchingIndexingMap(&t); - auto enc = getSparseTensorEncoding(t.get().getType()); + // Get map, encoding, and tensor-identifier. + const auto map = env.op().getMatchingIndexingMap(&t); + const auto enc = getSparseTensorEncoding(t.get().getType()); + const TensorId tid = env.makeTensorId(t.getOperandNumber()); // Each tensor expression and optional dimension ordering (row-major // by default) puts an ordering constraint on the loop indices. 
For // example, the tensor expresion A_ijk forces the ordering i < j < k // on the loop indices if no explicit dimension ordering is given. - for (Level l = 0, rank = map.getNumResults(); l < rank; l++) { - AffineExpr ta = map.getResult(toOrigDim(enc, l)); - std::optional tldx = - env.merger().getLoopId(t.getOperandNumber(), l); + const Level lvlRank = map.getNumResults(); + assert(!enc || lvlRank == enc.getLvlRank()); + for (Level lvl = 0; lvl < lvlRank; lvl++) { + // FIXME: `toOrigDim` is deprecated. + AffineExpr ta = map.getResult(toOrigDim(enc, lvl)); + std::optional tldx = env.merger().getLoopId(tid, lvl); // Filter loops should be constructed after all the dependent loops, // i.e., d0 + d1 < filter_loop(d0 + d1) if (tldx && env.merger().isFilterLoop(*tldx)) { - assert(!ta.isa() && !isDenseDLT(enc.getDimLevelType()[l])); + assert(!ta.isa() && !isDenseDLT(enc.getDimLevelType()[lvl])); addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx); // Now that the ordering of affine expression is captured by filter // loop idx, we only need to ensure the affine ordering against filter @@ -626,10 +631,10 @@ if (&t == skip) continue; - if (l > 0) { - AffineExpr fa = map.getResult(toOrigDim(enc, l - 1)); - std::optional fldx = - env.merger().getLoopId(t.getOperandNumber(), l - 1); + if (lvl > 0) { + // FIXME: `toOrigDim` is deprecated. + AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1)); + std::optional fldx = env.merger().getLoopId(tid, lvl - 1); // Applying order constraints on every pair of dimExpr between two // compound affine expressions can sometime too strict: @@ -657,8 +662,8 @@ std::vector> &adjM, std::vector &inDegree) { // Get map and encoding. - auto map = env.op().getMatchingIndexingMap(&t); - auto enc = getSparseTensorEncoding(t.get().getType()); + const auto map = env.op().getMatchingIndexingMap(&t); + const auto enc = getSparseTensorEncoding(t.get().getType()); // No special treatment for simple indices. 
if (getNumNonTrivialIdxExpOnSparseLvls(map, t.get()) == 0) @@ -674,19 +679,22 @@ // To compute iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6], // we requires there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6}, // and d_x > d_y && {d0, d1, d3} - d_x > {d4, d5, d6} - d_y - for (Level lvl = 1, rank = map.getNumResults(); lvl < rank; lvl++) { - AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1)); - AffineExpr ta = map.getResult(toOrigDim(enc, lvl)); + const Level lvlRank = map.getNumResults(); + assert(!enc || lvlRank == enc.getLvlRank()); + for (Level lvl = 1; lvl < lvlRank; lvl++) { + // FIXME: `toOrigDim` is deprecated. + const AffineExpr fa = map.getResult(toOrigDim(enc, lvl - 1)); + const AffineExpr ta = map.getResult(toOrigDim(enc, lvl)); // This is a heuristic, we pick an abitrary reduction loop from lhs and // rhs and use them as d_x and d_y. finder.walkPostOrder(fa); - AffineDimExpr fexp = finder.getDimExpr(); - LoopId fldx = fexp.getPosition(); + const AffineDimExpr fexp = finder.getDimExpr(); + const LoopId fldx = env.makeLoopId(fexp.getPosition()); finder.walkPostOrder(ta); - AffineDimExpr texp = finder.getDimExpr(); - LoopId tldx = texp.getPosition(); + const AffineDimExpr texp = finder.getDimExpr(); + const LoopId tldx = env.makeLoopId(texp.getPosition()); // d_x > d_y if (!adjM[fldx][tldx]) { @@ -701,7 +709,7 @@ // make sure dx and dy is the last; for (auto fd : fCollector.dims) { - LoopId f = fd.getPosition(); + const LoopId f = env.makeLoopId(fd.getPosition()); if (f == fldx) continue; if (!adjM[f][fldx]) { @@ -710,7 +718,7 @@ } } for (auto td : tCollector.dims) { - LoopId t = td.getPosition(); + const LoopId t = env.makeLoopId(td.getPosition()); if (t == tldx) continue; if (!adjM[t][tldx]) { @@ -728,12 +736,12 @@ // TODO: the evaluation order need to be ensure to // support affine multiplication. 
for (auto fd : fCollector.dims) { - LoopId f = fd.getPosition(); + const LoopId f = env.makeLoopId(fd.getPosition()); if (f == fldx) // skip d_x continue; for (auto td : tCollector.dims) { - LoopId t = td.getPosition(); + const LoopId t = env.makeLoopId(td.getPosition()); if (t == tldx) // skip d_y continue; if (!adjM[f][t]) { @@ -755,9 +763,9 @@ OpOperand *skip, bool idxReducBased = false) { // Set up an n x n from/to adjacency matrix of the iteration graph // for the implicit loop indices i_0 .. i_n-1. - const LoopId n = env.merger().getNumLoops(); - std::vector> adjM(n, std::vector(n, false)); - std::vector inDegree(n, 0); // in-degree of each node. + const unsigned numLoops = env.merger().getNumLoops(); + std::vector> adjM(numLoops, std::vector(numLoops, false)); + std::vector inDegree(numLoops, 0); // in-degree of each node. const auto iteratorTypes = env.op().getIteratorTypesArray(); // Iterate over the indexing maps of every tensor in the tensor expression. for (OpOperand &t : env.op()->getOpOperands()) { @@ -765,7 +773,7 @@ const auto enc = getSparseTensorEncoding(t.get().getType()); assert(env.op().getMatchingIndexingMap(&t).getNumDims() + getNumNonTrivialIdxExpOnSparseLvls(env.op()) == - n); + numLoops); // Skips dense inputs/outputs when not requested. const bool isDenseInput = !enc && env.op().isDpsInput(&t); @@ -778,12 +786,12 @@ // will be skipped more often. // TODO: Do we really need this? 
if (includesUndef(mask)) { - const TensorId tensor = t.getOperandNumber(); - for (LoopId i = 0; i < n; i++) { - const auto dltI = env.dlt(tensor, i); + const TensorId tid = env.makeTensorId(t.getOperandNumber()); + for (LoopId i = 0; i < numLoops; i++) { + const auto dltI = env.dlt(tid, i); if (isCompressedDLT(dltI) || isSingletonDLT(dltI)) { - for (LoopId j = 0; j < n; j++) - if (isUndefDLT(env.dlt(tensor, j))) { + for (LoopId j = 0; j < numLoops; j++) + if (isUndefDLT(env.dlt(tid, j))) { adjM[i][j] = true; inDegree[j]++; } @@ -801,8 +809,8 @@ } // Topologically sort the iteration graph to determine loop order. // Report failure for a cyclic iteration graph. - env.topSortClear(n); - return topSortOptimal(env, n, iteratorTypes, inDegree, adjM); + env.topSortClear(numLoops); + return topSortOptimal(env, iteratorTypes, inDegree, adjM); } //===----------------------------------------------------------------------===// @@ -856,16 +864,16 @@ // a "coordinate", or "Ldx", or what). So the function should be renamed // and/or the documentation expanded in order to clarify. static Value genIndex(CodegenEnv &env, OpOperand *t) { - auto map = env.op().getMatchingIndexingMap(t); + const auto map = env.op().getMatchingIndexingMap(t); const auto stt = getSparseTensorType(t->get()); const Level lvlRank = stt.getLvlRank(); assert(static_cast(map.getNumResults()) == lvlRank); // FIXME: `toOrigDim` is deprecated. // FIXME: above we asserted that there are `lvlRank` many results, // but this is assuming there are in fact `dimRank` many results instead. 
- AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), lvlRank - 1)); + const AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), lvlRank - 1)); assert(a.getKind() == AffineExprKind::DimId); - const LoopId idx = a.cast().getPosition(); + const LoopId idx = env.makeLoopId(a.cast().getPosition()); return env.getLoopVar(idx); } @@ -873,7 +881,7 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, SmallVectorImpl &args) { const Location loc = env.op().getLoc(); - const TensorId tid = t->getOperandNumber(); + const TensorId tid = env.makeTensorId(t->getOperandNumber()); const auto map = env.op().getMatchingIndexingMap(t); const auto stt = getSparseTensorType(t->get()); if (stt.hasEncoding()) { @@ -1092,7 +1100,7 @@ Value e, LoopId ldx) { if (Operation *def = e.getDefiningOp()) { if (auto indexOp = dyn_cast(def)) - return env.getLoopVar(indexOp.getDim()); + return env.getLoopVar(env.makeLoopId(indexOp.getDim())); if (def->getBlock() == block) { for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) { rewriter.updateRootInPlace(def, [&]() { @@ -1153,7 +1161,7 @@ bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId; linalg::GenericOp op = env.op(); OpOperand &t = op->getOpOperand(env.exp(exp).tensor); - auto map = op.getMatchingIndexingMap(&t); + const auto map = op.getMatchingIndexingMap(&t); const auto stt = getSparseTensorType(t.get()); const Level lvlRank = stt.getLvlRank(); assert(static_cast(map.getNumResults()) == lvlRank); @@ -1161,8 +1169,9 @@ // FIXME: `toOrigDim` is deprecated. // FIXME: above we asserted that there are `lvlRank` many results, // but this is assuming there are in fact `dimRank` many results instead. 
- AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), l)); - const auto sldx = env.merger().getLoopId(t.getOperandNumber(), l); + const AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), l)); + const auto sldx = env.merger().getLoopId( + env.makeTensorId(t.getOperandNumber()), l); if (sldx && env.merger().isFilterLoop(*sldx)) { if (!env.getLoopVar(*sldx)) // The filter loops has not been constructed. @@ -1386,29 +1395,28 @@ /// Generates a single if-statement within a while-loop. static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx, - const BitVector &conditions) { + LatPointId p) { Location loc = env.op().getLoc(); SmallVector types; Value cond; - for (TensorLoopId b = 0, be = conditions.size(); b < be; b++) { - if (!conditions[b]) - continue; - const TensorId tid = env.merger().tensor(b); - assert(ldx == env.merger().loop(b)); - Value clause; - const auto dlt = env.dlt(b); - if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) { - const Level lvl = *env.merger().getLvl(tid, ldx); - const Value crd = env.emitter().getCoords()[tid][lvl]; - const Value lvar = env.getLoopVar(ldx); - clause = builder.create(loc, arith::CmpIPredicate::eq, crd, - lvar); - } else { - assert(isDenseDLT(dlt) || isUndefDLT(dlt)); - clause = constantI1(builder, loc, true); - } - cond = cond ? builder.create(loc, cond, clause) : clause; - } + env.merger().foreachTensorLoopId( + p, /*simple=*/true, + [&](TensorLoopId b, TensorId tid, std::optional lvl, + DimLevelType dlt, bool /*unused*/) { + assert(ldx == env.merger().loop(b)); + Value clause; + if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) { + assert(lvl.has_value()); + const Value crd = env.emitter().getCoords()[tid][*lvl]; + const Value lvar = env.getLoopVar(ldx); + clause = builder.create(loc, arith::CmpIPredicate::eq, + crd, lvar); + } else { + assert(isDenseDLT(dlt) || isUndefDLT(dlt)); + clause = constantI1(builder, loc, true); + } + cond = cond ? 
builder.create(loc, cond, clause) : clause; + }); if (env.isReduc()) { types.push_back(env.getReduc().getType()); if (env.getValidLexInsert()) @@ -1505,7 +1513,7 @@ const auto enc = getSparseTensorEncoding(input->get().getType()); if (enc) { const Location loc = op.getLoc(); - const TensorId tid = input->getOperandNumber(); + const TensorId tid = env.makeTensorId(input->getOperandNumber()); const Level lvlRank = enc.getLvlRank(); assert(lvlExprs.size() == static_cast(lvlRank)); // FIXME: there is dim/lvl confusion here @@ -1545,7 +1553,7 @@ env.merger().foreachTensorLoopId( li, [&, ldx](TensorLoopId b, TensorId tid, std::optional lvl, DimLevelType dlt, bool isIdxReduc) { - if (simple.test(b)) { + if (simple[b]) { if (isIdxReduc) { tids.push_back(tid); lvls.push_back(*lvl); @@ -1634,8 +1642,8 @@ /// Starts a single loop in current sequence. static std::pair startLoop(CodegenEnv &env, - OpBuilder &builder, unsigned at, - unsigned li, bool needsUniv) { + OpBuilder &builder, LoopOrd at, + LatPointId li, bool needsUniv) { // The set of tensors + lvls to generate loops on SmallVector tids, affineTids; SmallVector lvls, affineLvls; @@ -1747,7 +1755,7 @@ if (li == lj || env.merger().latGT(li, lj)) { // Recurse into body of each branch. if (!isSingleCond) { - scf::IfOp ifOp = genIf(env, rewriter, idx, env.lat(lj).simple); + scf::IfOp ifOp = genIf(env, rewriter, idx, lj); genStmt(env, rewriter, ej, at + 1); endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput); } else { @@ -1884,7 +1892,7 @@ // sparse input tensor in succession until an acylic // iteration graph results. 
for (OpOperand *t : env.op().getDpsInputOperands()) { - const TensorId tid = t->getOperandNumber(); + const TensorId tid = env.makeTensorId(t->getOperandNumber()); Value tval = t->get(); auto srcEnc = getSparseTensorEncoding(tval.getType()); if (!srcEnc || !computeIterationGraph(env, SortMask::kSparseOnly, t)) @@ -1905,8 +1913,8 @@ auto dstTp = RankedTensorType::get(srcTp.getShape(), srcTp.getElementType(), dstEnc); auto convert = rewriter.create(tval.getLoc(), dstTp, tval); - rewriter.updateRootInPlace(env.op(), - [&]() { env.op()->setOperand(tid, convert); }); + rewriter.updateRootInPlace( + env.op(), [&]() { env.op()->setOperand(tid, convert); }); rewriter.setInsertionPointAfter(env.op()); rewriter.create(tval.getLoc(), convert); return success(); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -98,7 +98,7 @@ // Constructors. //===----------------------------------------------------------------------===// -TensorExp::TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *o) +TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v, Operation *o) : kind(k), val(v), op(o) { switch (kind) { // Leaf. @@ -200,16 +200,6 @@ llvm_unreachable("unexpected kind"); } -LatPoint::LatPoint(const BitVector &bits, ExprId e) : bits(bits), exp(e) {} - -LatPoint::LatPoint(unsigned numTensors, unsigned numLoops, TensorId t, LoopId i, - ExprId e) - : bits(numLoops * numTensors, false), exp(e) { - assert(t < numTensors && i < numLoops); - const TensorLoopId b = numTensors * i + t; - bits.set(b); -} - Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops, unsigned numFilterLoops) : outTensor(numInputOutputTensors - 1), @@ -232,61 +222,89 @@ // Lattice methods. 
//===----------------------------------------------------------------------===// -ExprId Merger::addExp(TensorExp::Kind k, unsigned x, ExprId y, Value v, - Operation *op) { - const ExprId e = tensorExps.size(); - assert((k != TensorExp::Kind::kTensor || x < numTensors) && - (k != TensorExp::Kind::kLoopVar || x < numLoops)); - tensorExps.emplace_back(k, x, y, v, op); - return e; +ExprId Merger::addTensorExp(TensorId t) { + assert(isValidTensorId(t)); + const ExprId eNew(tensorExps.size()); + tensorExps.emplace_back(TensorExp::Kind::kTensor, t, detail::kInvalidId, Value(), nullptr); + return eNew; +} + +ExprId Merger::addLoopVarExp(LoopId i) { + assert(isValidLoopId(i)); + const ExprId eNew(tensorExps.size()); + tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, detail::kInvalidId, Value(), nullptr); + return eNew; +} + +ExprId Merger::addInvariantExp(Value v) { + const ExprId eNew(tensorExps.size()); + tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId, detail::kInvalidId, v, nullptr); + return eNew; +} + +ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op) { + assert(k > TensorExp::Kind::kLoopVar); + const ExprId eNew(tensorExps.size()); + tensorExps.emplace_back(k, e0, e1, Value(), op); + return eNew; +} + +ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op) { + assert(k > TensorExp::Kind::kLoopVar); + const ExprId eNew(tensorExps.size()); + tensorExps.emplace_back(k, e, detail::kInvalidId, v, op); + return eNew; } LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) { - assert(t < numTensors && i < numLoops); - const LatPointId p = latPoints.size(); - latPoints.emplace_back(numTensors, numLoops, t, i, e); - return p; + const LatPointId pNew(latPoints.size()); + const unsigned size = numLoops * numTensors; + const TensorLoopId b = makeTensorLoopId(t, i); + latPoints.emplace_back(size, e); + latPoints[pNew].bits.set(b); + return pNew; } LatPointId Merger::addLat(const BitVector &bits, 
ExprId e) { assert(bits.size() == numLoops * numTensors); - const LatPointId p = latPoints.size(); + const LatPointId pNew(latPoints.size()); latPoints.emplace_back(bits, e); - return p; + return pNew; } LatSetId Merger::addSet() { - const LatSetId s = latSets.size(); + const LatSetId sNew(latSets.size()); latSets.emplace_back(); - return s; + return sNew; } LatPointId Merger::conjLat(TensorExp::Kind kind, LatPointId p0, LatPointId p1, Operation *op) { - const LatPointId p = latPoints.size(); - BitVector bits(latPoints[p0].bits); - bits |= latPoints[p1].bits; - const ExprId e = - addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op); + const LatPointId pNew(latPoints.size()); + const auto &point0 = lat(p0); + const auto &point1 = lat(p1); + BitVector bits(point0.bits); + bits |= point1.bits; + const ExprId e = addExp(kind, point0.exp, point1.exp, op); latPoints.emplace_back(bits, e); - return p; + return pNew; } LatSetId Merger::conjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1, Operation *op) { - const LatSetId s = addSet(); - for (const LatPointId p0 : latSets[s0]) - for (const LatPointId p1 : latSets[s1]) - latSets[s].push_back(conjLat(kind, p0, p1, op)); - return s; + const LatSetId sNew = addSet(); + auto &setNew = latSets[sNew]; + for (const LatPointId p0 : set(s0)) + for (const LatPointId p1 : set(s1)) + setNew.push_back(conjLat(kind, p0, p1, op)); + return sNew; } LatSetId Merger::disjSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1, Operation *op) { - const LatSetId s = conjSet(kind, s0, s1, op); + const LatSetId sNew = conjSet(kind, s0, s1, op); // Followed by all in s0. - for (const LatPointId p : latSets[s0]) - latSets[s].push_back(p); + latSets[sNew].append(latSets[s0]); // Map binary 0-y to unary -y. // TODO: move this if-else logic into buildLattices if (kind == TensorExp::Kind::kSubF) @@ -296,9 +314,8 @@ else if (kind == TensorExp::Kind::kSubI) s1 = mapSet(TensorExp::Kind::kNegI, s1); // Followed by all in s1. 
- for (const LatPointId p : latSets[s1]) - latSets[s].push_back(p); - return s; + latSets[sNew].append(latSets[s1]); + return sNew; } LatSetId Merger::combiSet(TensorExp::Kind kind, LatSetId s0, LatSetId s1, @@ -306,48 +323,48 @@ TensorExp::Kind ltrans, Operation *opleft, bool includeRight, TensorExp::Kind rtrans, Operation *opright) { - const LatSetId s = conjSet(kind, s0, s1, orig); + const LatSetId sNew = conjSet(kind, s0, s1, orig); // Left Region. if (includeLeft) { if (opleft) s0 = mapSet(ltrans, s0, Value(), opleft); - for (const LatPointId p : latSets[s0]) - latSets[s].push_back(p); + latSets[sNew].append(latSets[s0]); } // Right Region. if (includeRight) { if (opright) s1 = mapSet(rtrans, s1, Value(), opright); - for (const LatPointId p : latSets[s1]) - latSets[s].push_back(p); + latSets[sNew].append(latSets[s1]); } - return s; + return sNew; } LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v, Operation *op) { assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect); - const LatSetId s = addSet(); - for (const LatPointId p : latSets[s0]) { - const ExprId e = addExp(kind, latPoints[p].exp, v, op); - latSets[s].push_back(addLat(latPoints[p].bits, e)); + const LatSetId sNew = addSet(); + auto &setNew = latSets[sNew]; + for (const LatPointId p : set(s0)) { + const auto &point = latPoints[p]; + setNew.push_back(addLat(point.bits, addExp(kind, point.exp, v, op))); } - return s; + return sNew; } LatSetId Merger::optimizeSet(LatSetId s0) { - const LatSetId s = addSet(); - assert(!latSets[s0].empty()); - const LatPointId p0 = latSets[s0][0]; - for (const LatPointId p1 : latSets[s0]) { + const LatSetId sNew = addSet(); + auto &setNew = latSets[sNew]; + const auto &set0 = set(s0); + assert(!set0.empty()); + const LatPointId p0 = set0[0]; + for (const LatPointId p1 : set0) { bool add = true; if (p0 != p1) { // Check whether this is a straightforward copy. 
- const ExprId e = latPoints[p1].exp; - if (expIsTensor(e, outTensor)) + if (expIsTensor(latPoints[p1].exp, outTensor)) continue; // Check whether this conjunction is already covered. - for (const LatPointId p2 : latSets[s]) { + for (const LatPointId p2 : setNew) { assert(!latGT(p1, p2)); // Lj => Li would be bad if (onlyDenseDiff(p2, p1)) { add = false; @@ -357,34 +374,38 @@ assert(!add || latGT(p0, p1)); } if (add) - latSets[s].push_back(p1); + setNew.push_back(p1); } - for (const LatPointId p : latSets[s]) - latPoints[p].simple = simplifyCond(s, p); - return s; + for (const LatPointId p : setNew) + latPoints[p].simple = simplifyCond(sNew, p); + return sNew; } BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) { // First determine if this lattice point is a *singleton*, i.e., // the last point in a lattice, no other is less than this one. bool isSingleton = true; - for (const LatPointId p1 : latSets[s0]) { + for (const LatPointId p1 : set(s0)) { if (p0 != p1 && latGT(p0, p1)) { isSingleton = false; break; } } - BitVector simple(latPoints[p0].bits); + BitVector simple(lat(p0).bits); bool reset = isSingleton && (hasAnySparse(simple) || hasSparseIdxReduction(simple)); - const TensorLoopId be = simple.size(); - TensorLoopId offset = 0; // relative to the end + // `be`, `b`, and `offset` are `TensorLoopId` in spirit; but we avoid + // using that class in this function because we need to do a bunch of + // arithmetic on them, so using the newtype would introduce too much + // boilerplate. + const unsigned be = simple.size(); + unsigned offset = 0; // relative to the end if (!reset) // Starts resetting from a dense level, so that the first bit (if kept) // is not undefined level-type. 
- for (TensorLoopId b = 0; b < be; b++) { - if (simple[b] && isDenseDLT(getDimLevelType(b))) { + for (unsigned b = 0; b < be; b++) { + if (simple[b] && isDenseDLT(getDimLevelType(TensorLoopId{b}))) { offset = be - b - 1; // relative to the end break; } @@ -392,12 +413,12 @@ // Now apply the two basic rules. We also iterate the bits reversely to always // keep the rightmost bit (which could possibly be a synthetic tensor). - for (TensorLoopId b = be - 1 - offset, i = 0; i < be; + for (unsigned b = be - 1 - offset, i = 0; i < be; b = b == 0 ? be - 1 : b - 1, i++) { // FIXME: better name? also slice on dense level has locate property as // well. Handle it correctly! - if (simple[b] && !isLvlWithNonTrivialIdxExp(b)) { - const auto dlt = getDimLevelType(b); + if (simple[b] && !isLvlWithNonTrivialIdxExp(TensorLoopId{b})) { + const auto dlt = getDimLevelType(TensorLoopId{b}); if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) { if (reset) simple.reset(b); @@ -409,8 +430,8 @@ } bool Merger::latGT(LatPointId i, LatPointId j) const { - const BitVector &bitsi = latPoints[i].bits; - const BitVector &bitsj = latPoints[j].bits; + const BitVector &bitsi = lat(i).bits; + const BitVector &bitsj = lat(j).bits; assert(bitsi.size() == bitsj.size()); if (bitsi.count() > bitsj.count()) { for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++) @@ -422,27 +443,28 @@ } bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const { - BitVector tmp(latPoints[j].bits); - tmp ^= latPoints[i].bits; + BitVector tmp(lat(j).bits); + tmp ^= lat(i).bits; return !hasAnySparse(tmp) && !hasSparseIdxReduction(tmp); } bool Merger::expContainsTensor(ExprId e, TensorId t) const { - if (tensorExps[e].kind == TensorExp::Kind::kTensor) - return tensorExps[e].tensor == t; + const auto &expr = exp(e); + if (expr.kind == TensorExp::Kind::kTensor) + return expr.tensor == t; - switch (getExpArity(tensorExps[e].kind)) { + switch (getExpArity(expr.kind)) { case ExpArity::kNullary: return false; case 
ExpArity::kUnary: { - const ExprId e0 = tensorExps[e].children.e0; + const ExprId e0 = expr.children.e0; if (expIsTensor(e0, t)) return true; return expContainsTensor(e0, t); } case ExpArity::kBinary: { - const ExprId e0 = tensorExps[e].children.e0; - const ExprId e1 = tensorExps[e].children.e1; + const ExprId e0 = expr.children.e0; + const ExprId e1 = expr.children.e1; if (expIsTensor(e0, t) || expIsTensor(e1, t)) return true; return expContainsTensor(e0, t) || expContainsTensor(e1, t); @@ -452,25 +474,26 @@ } bool Merger::hasNegateOnOut(ExprId e) const { - switch (tensorExps[e].kind) { + const auto &expr = exp(e); + switch (expr.kind) { case TensorExp::Kind::kNegF: case TensorExp::Kind::kNegC: case TensorExp::Kind::kNegI: - return expContainsTensor(tensorExps[e].children.e0, outTensor); + return expContainsTensor(expr.children.e0, outTensor); case TensorExp::Kind::kSubF: case TensorExp::Kind::kSubC: case TensorExp::Kind::kSubI: - return expContainsTensor(tensorExps[e].children.e1, outTensor) || - hasNegateOnOut(tensorExps[e].children.e0); + return expContainsTensor(expr.children.e1, outTensor) || + hasNegateOnOut(expr.children.e0); default: { - switch (getExpArity(tensorExps[e].kind)) { + switch (getExpArity(expr.kind)) { case ExpArity::kNullary: return false; case ExpArity::kUnary: - return hasNegateOnOut(tensorExps[e].children.e0); + return hasNegateOnOut(expr.children.e0); case ExpArity::kBinary: - return hasNegateOnOut(tensorExps[e].children.e0) || - hasNegateOnOut(tensorExps[e].children.e1); + return hasNegateOnOut(expr.children.e0) || + hasNegateOnOut(expr.children.e1); } } } @@ -478,11 +501,12 @@ } bool Merger::isSingleCondition(TensorId t, ExprId e) const { - assert(t < numTensors && e < tensorExps.size()); - switch (tensorExps[e].kind) { + assert(isValidTensorId(t)); + const auto &expr = exp(e); + switch (expr.kind) { // Leaf. 
case TensorExp::Kind::kTensor: - return tensorExps[e].tensor == t; + return expr.tensor == t; case TensorExp::Kind::kInvariant: case TensorExp::Kind::kLoopVar: return false; @@ -518,7 +542,7 @@ case TensorExp::Kind::kCIm: case TensorExp::Kind::kCRe: case TensorExp::Kind::kBitCast: - return isSingleCondition(t, tensorExps[e].children.e0); + return isSingleCondition(t, expr.children.e0); case TensorExp::Kind::kBinaryBranch: case TensorExp::Kind::kUnary: case TensorExp::Kind::kSelect: @@ -528,28 +552,28 @@ case TensorExp::Kind::kDivC: case TensorExp::Kind::kDivS: case TensorExp::Kind::kDivU: - assert(!maybeZero(tensorExps[e].children.e1)); - return isSingleCondition(t, tensorExps[e].children.e0); + assert(!maybeZero(expr.children.e1)); + return isSingleCondition(t, expr.children.e0); case TensorExp::Kind::kShrS: // note: x >> inv only case TensorExp::Kind::kShrU: case TensorExp::Kind::kShlI: - assert(isInvariant(tensorExps[e].children.e1)); - return isSingleCondition(t, tensorExps[e].children.e0); + assert(isInvariant(expr.children.e1)); + return isSingleCondition(t, expr.children.e0); case TensorExp::Kind::kMulF: case TensorExp::Kind::kMulC: case TensorExp::Kind::kMulI: case TensorExp::Kind::kAndI: - if (isSingleCondition(t, tensorExps[e].children.e0)) - return isSingleCondition(t, tensorExps[e].children.e1) || - isInvariant(tensorExps[e].children.e1); - if (isSingleCondition(t, tensorExps[e].children.e1)) - return isInvariant(tensorExps[e].children.e0); + if (isSingleCondition(t, expr.children.e0)) + return isSingleCondition(t, expr.children.e1) || + isInvariant(expr.children.e1); + if (isSingleCondition(t, expr.children.e1)) + return isInvariant(expr.children.e0); return false; case TensorExp::Kind::kAddF: case TensorExp::Kind::kAddC: case TensorExp::Kind::kAddI: - return isSingleCondition(t, tensorExps[e].children.e0) && - isSingleCondition(t, tensorExps[e].children.e1); + return isSingleCondition(t, expr.children.e0) && + isSingleCondition(t, expr.children.e1); 
case TensorExp::Kind::kSubF: case TensorExp::Kind::kSubC: case TensorExp::Kind::kSubI: @@ -684,20 +708,21 @@ } void Merger::dumpExp(ExprId e) const { - switch (tensorExps[e].kind) { + const auto &expr = exp(e); + switch (expr.kind) { // Leaf. case TensorExp::Kind::kTensor: - if (tensorExps[e].tensor == syntheticTensor) + if (expr.tensor == syntheticTensor) llvm::dbgs() << "synthetic_"; - else if (tensorExps[e].tensor == outTensor) + else if (expr.tensor == outTensor) llvm::dbgs() << "output_"; - llvm::dbgs() << "tensor_" << tensorExps[e].tensor; + llvm::dbgs() << "tensor_" << expr.tensor; break; case TensorExp::Kind::kInvariant: llvm::dbgs() << "invariant"; break; case TensorExp::Kind::kLoopVar: - llvm::dbgs() << "loopvar_" << tensorExps[e].loop; + llvm::dbgs() << "loopvar_" << expr.loop; break; // Unary operations. case TensorExp::Kind::kAbsF: @@ -734,8 +759,8 @@ case TensorExp::Kind::kBinaryBranch: case TensorExp::Kind::kUnary: case TensorExp::Kind::kSelect: - llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " "; - dumpExp(tensorExps[e].children.e0); + llvm::dbgs() << kindToOpSymbol(expr.kind) << " "; + dumpExp(expr.children.e0); break; // Binary operations. 
case TensorExp::Kind::kMulF: @@ -760,26 +785,28 @@ case TensorExp::Kind::kBinary: case TensorExp::Kind::kReduce: llvm::dbgs() << "("; - dumpExp(tensorExps[e].children.e0); - llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " "; - dumpExp(tensorExps[e].children.e1); + dumpExp(expr.children.e0); + llvm::dbgs() << " " << kindToOpSymbol(expr.kind) << " "; + dumpExp(expr.children.e1); llvm::dbgs() << ")"; } } void Merger::dumpLat(LatPointId p) const { + const auto &point = lat(p); llvm::dbgs() << "lat("; - dumpBits(latPoints[p].bits); + dumpBits(point.bits); llvm::dbgs() << " :"; - dumpBits(latPoints[p].simple); + dumpBits(point.simple); llvm::dbgs() << " : "; - dumpExp(latPoints[p].exp); + dumpExp(point.exp); llvm::dbgs() << " )\n"; } void Merger::dumpSet(LatSetId s) const { - llvm::dbgs() << "{ #" << latSets[s].size() << "\n"; - for (const LatPointId p : latSets[s]) { + const auto &ss = set(s); + llvm::dbgs() << "{ #" << ss.size() << "\n"; + for (const LatPointId p : ss) { llvm::dbgs() << " "; dumpLat(p); } @@ -807,7 +834,12 @@ //===----------------------------------------------------------------------===// LatSetId Merger::buildLattices(ExprId e, LoopId i) { - const TensorExp::Kind kind = tensorExps[e].kind; + // NOTE: The `expr` reference will be invalidated by recursive calls + // (and any other method that may add new expressions); therefore, the + // code below must make sure to copy fields of `expr` into local variables + // before making any recursive calls. + const auto &expr = exp(e); + const TensorExp::Kind kind = expr.kind; switch (kind) { // Leaf. 
case TensorExp::Kind::kTensor: @@ -821,7 +853,7 @@ const LatSetId s = addSet(); TensorId t = syntheticTensor; if (kind == TensorExp::Kind::kTensor) { - t = tensorExps[e].tensor; + t = expr.tensor; if (hasSparseOut && t == outTensor) t = syntheticTensor; } @@ -866,14 +898,20 @@ // -y|!y | y | // --+---+---+ // | 0 |-y | - return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), - tensorExps[e].val); + { + const ExprId e0 = expr.children.e0; + const Value v = expr.val; + return mapSet(kind, buildLattices(e0, i), v); + } case TensorExp::Kind::kBinaryBranch: case TensorExp::Kind::kSelect: // The left or right half of a binary operation which has already // been split into separate operations for each region. - return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(), - tensorExps[e].op); + { + const ExprId e0 = expr.children.e0; + Operation *const op = expr.op; + return mapSet(kind, buildLattices(e0, i), Value(), op); + } case TensorExp::Kind::kUnary: // A custom unary operation. // @@ -881,8 +919,9 @@ // ----+----------+------------+ // | absent() | present(y) | { - const LatSetId child0 = buildLattices(tensorExps[e].children.e0, i); - UnaryOp unop = cast(tensorExps[e].op); + const ExprId e0 = expr.children.e0; + UnaryOp unop = cast(expr.op); + const LatSetId child0 = buildLattices(e0, i); Region &absentRegion = unop.getAbsentRegion(); if (absentRegion.empty()) { @@ -892,8 +931,8 @@ // invariant on the right. Block &absentBlock = absentRegion.front(); YieldOp absentYield = cast(absentBlock.getTerminator()); - Value absentVal = absentYield.getResult(); - const ExprId rhs = addExp(TensorExp::Kind::kInvariant, absentVal); + const Value absentVal = absentYield.getResult(); + const ExprId rhs = addInvariantExp(absentVal); return disjSet(kind, child0, buildLattices(rhs, i), unop); } // Binary operations. @@ -910,8 +949,11 @@ // x | 0 |x*y| // // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored. 
- return conjSet(kind, buildLattices(tensorExps[e].children.e0, i), - buildLattices(tensorExps[e].children.e1, i)); + { + const ExprId e0 = expr.children.e0; + const ExprId e1 = expr.children.e1; + return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i)); + } case TensorExp::Kind::kDivF: case TensorExp::Kind::kDivC: case TensorExp::Kind::kDivS: @@ -929,9 +971,12 @@ // during expression building, so that the conjunction // rules applies (viz. x/c = x*(1/c) as far as lattice // construction is concerned). - assert(!maybeZero(tensorExps[e].children.e1)); - return conjSet(kind, buildLattices(tensorExps[e].children.e0, i), - buildLattices(tensorExps[e].children.e1, i)); + { + const ExprId e0 = expr.children.e0; + const ExprId e1 = expr.children.e1; + assert(!maybeZero(e1)); + return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i)); + } case TensorExp::Kind::kAddF: case TensorExp::Kind::kAddC: case TensorExp::Kind::kAddI: @@ -947,17 +992,23 @@ // ---+---+---+ ---+---+---+ // !x | 0 | y | !x | 0 |-y | // x | x |x+y| x | x |x-y| - return disjSet(kind, buildLattices(tensorExps[e].children.e0, i), - buildLattices(tensorExps[e].children.e1, i)); + { + const ExprId e0 = expr.children.e0; + const ExprId e1 = expr.children.e1; + return disjSet(kind, buildLattices(e0, i), buildLattices(e1, i)); + } case TensorExp::Kind::kShrS: case TensorExp::Kind::kShrU: case TensorExp::Kind::kShlI: // A shift operation by an invariant amount (viz. tensor expressions // can only occur at the left-hand-side of the operator) can be handled // with the conjuction rule. - assert(isInvariant(tensorExps[e].children.e1)); - return conjSet(kind, buildLattices(tensorExps[e].children.e0, i), - buildLattices(tensorExps[e].children.e1, i)); + { + const ExprId e0 = expr.children.e0; + const ExprId e1 = expr.children.e1; + assert(isInvariant(e1)); + return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i)); + } case TensorExp::Kind::kBinary: // A custom binary operation. 
// @@ -966,9 +1017,11 @@ // !x | empty | right(y) | // x | left(x) | overlap(x,y) | { - const LatSetId child0 = buildLattices(tensorExps[e].children.e0, i); - const LatSetId child1 = buildLattices(tensorExps[e].children.e1, i); - BinaryOp binop = cast(tensorExps[e].op); + const ExprId e0 = expr.children.e0; + const ExprId e1 = expr.children.e1; + BinaryOp binop = cast(expr.op); + const LatSetId child0 = buildLattices(e0, i); + const LatSetId child1 = buildLattices(e1, i); Region &leftRegion = binop.getLeftRegion(); Region &rightRegion = binop.getRightRegion(); // Left Region. @@ -991,9 +1044,12 @@ } case TensorExp::Kind::kReduce: // A custom reduce operation. - return conjSet(kind, buildLattices(tensorExps[e].children.e0, i), - buildLattices(tensorExps[e].children.e1, i), - tensorExps[e].op); + { + const ExprId e0 = expr.children.e0; + const ExprId e1 = expr.children.e1; + Operation *const op = expr.op; + return conjSet(kind, buildLattices(e0, i), buildLattices(e1, i), op); + } } llvm_unreachable("unexpected expression kind"); } @@ -1007,27 +1063,24 @@ /// Only returns false if we are certain this is a nonzero. bool Merger::maybeZero(ExprId e) const { - if (tensorExps[e].kind == TensorExp::Kind::kInvariant) { - if (auto c = tensorExps[e].val.getDefiningOp()) { + const auto &expr = exp(e); + if (expr.kind == TensorExp::Kind::kInvariant) { + if (auto c = expr.val.getDefiningOp()) { ArrayAttr arrayAttr = c.getValue(); return arrayAttr[0].cast().getValue().isZero() && arrayAttr[1].cast().getValue().isZero(); } - if (auto c = tensorExps[e].val.getDefiningOp()) + if (auto c = expr.val.getDefiningOp()) return c.value() == 0; - if (auto c = tensorExps[e].val.getDefiningOp()) + if (auto c = expr.val.getDefiningOp()) return c.value().isZero(); } return true; } -bool Merger::isInvariant(ExprId e) const { - return tensorExps[e].kind == TensorExp::Kind::kInvariant; -} - Type Merger::inferType(ExprId e, Value src) const { // Obtain the destination type from the cast node. 
- Type dtp = tensorExps[e].val.getType(); + Type dtp = exp(e).val.getType(); // Inspect source type. For vector types, apply the same // vectorization to the destination type. if (auto vtp = src.getType().dyn_cast()) @@ -1067,28 +1120,28 @@ std::optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { if (auto arg = v.dyn_cast()) { - const TensorId argN = arg.getArgNumber(); + const TensorId tid = makeTensorId(arg.getArgNumber()); // Any argument of the generic op that is not marked as a scalar // argument is considered a tensor, indexed by the implicit loop // bounds. This includes rank-0 tensor arguments. if (arg.getOwner()->getParentOp() == op) { - OpOperand &t = op->getOpOperand(argN); + OpOperand &t = op->getOpOperand(tid); if (!op.isScalar(&t)) - return addExp(TensorExp::Kind::kTensor, argN); + return addTensorExp(tid); v = t.get(); // get scalar value } // Any other argument (marked as scalar argument for the generic op // or belonging to an enveloping op) is considered invariant. - return addExp(TensorExp::Kind::kInvariant, v); + return addInvariantExp(v); } // Something defined outside is invariant. Operation *def = v.getDefiningOp(); if (def->getBlock() != &op.getRegion().front()) - return addExp(TensorExp::Kind::kInvariant, v); + return addInvariantExp(v); // Construct index operations. if (def->getNumOperands() == 0) { if (auto indexOp = dyn_cast(def)) - return addExp(TensorExp::Kind::kLoopVar, indexOp.getDim()); + return addLoopVarExp(makeLoopId(indexOp.getDim())); } // Construct unary operations if subexpression can be built. 
if (def->getNumOperands() == 1) { @@ -1219,7 +1272,7 @@ isAdmissibleBranch(binop, binop.getLeftRegion())) && (binop.getRightIdentity() || isAdmissibleBranch(binop, binop.getRightRegion()))) - return addExp(TensorExp::Kind::kBinary, e0, e1, Value(), def); + return addExp(TensorExp::Kind::kBinary, e0, e1, def); } } } @@ -1233,7 +1286,7 @@ const ExprId e1 = *y; if (auto redop = dyn_cast(def)) { if (isAdmissibleBranch(redop, redop.getRegion())) - return addExp(TensorExp::Kind::kReduce, e0, e1, Value(), def); + return addExp(TensorExp::Kind::kReduce, e0, e1, def); } } } @@ -1288,7 +1341,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, Value v1) const { - switch (tensorExps[e].kind) { + const auto &expr = exp(e); + switch (expr.kind) { // Leaf. case TensorExp::Kind::kTensor: case TensorExp::Kind::kInvariant: @@ -1410,17 +1464,17 @@ case TensorExp::Kind::kShlI: return rewriter.create(loc, v0, v1); case TensorExp::Kind::kBinaryBranch: // semi-ring ops with custom logic. 
- return insertYieldOp(rewriter, loc, - *tensorExps[e].op->getBlock()->getParent(), {v0}); + return insertYieldOp(rewriter, loc, *expr.op->getBlock()->getParent(), + {v0}); case TensorExp::Kind::kUnary: - return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0); + return buildUnaryPresent(rewriter, loc, expr.op, v0); case TensorExp::Kind::kSelect: - return insertYieldOp(rewriter, loc, - cast(tensorExps[e].op).getRegion(), {v0}); + return insertYieldOp(rewriter, loc, cast(expr.op).getRegion(), + {v0}); case TensorExp::Kind::kBinary: - return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1); + return buildBinaryOverlap(rewriter, loc, expr.op, v0, v1); case TensorExp::Kind::kReduce: { - ReduceOp redOp = cast(tensorExps[e].op); + ReduceOp redOp = cast(expr.op); return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1}); } } diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -139,21 +139,15 @@ : merger(numTensors, numLoops, /*numFilterLoops=*/0) { tensors.reserve(numTensors); for (unsigned t = 0; t < numTensors; t++) - tensors.push_back(merger.addExp(TensorExp::Kind::kTensor, tid(t))); + tensors.push_back(merger.addTensorExp(tid(t))); } /// /// Expression construction helpers. /// - TensorId tid(unsigned t) const { - assert(t < merger.getNumTensors()); - return t; - } - LoopId lid(unsigned i) const { - assert(i < merger.getNumLoops()); - return i; - } + TensorId tid(unsigned t) const { return merger.makeTensorId(t); } + LoopId lid(unsigned i) const { return merger.makeLoopId(i); } ExprId tensor(unsigned t) const { assert(t < tensors.size()); return tensors[t]; @@ -207,11 +201,9 @@ /// Converts a vector of (loop, tensor) pairs to a bitvector with the /// corresponding bits set. 
BitVector loopsToBits(const std::vector> &loops) { - // NOTE: this `numTensors` includes both the output- and synthetic-tensors. - const auto numTensors = merger.getNumTensors(); - BitVector testBits = BitVector(numTensors, false); + BitVector testBits = BitVector(merger.getNumTensors(), false); for (auto [loop, tensor] : loops) - testBits.set(numTensors * loop + tensor); + testBits.set(merger.makeTensorLoopId(tensor, loop)); return testBits; }