diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -31,21 +31,21 @@ /// Child subexpressions for non-leaf expressions. struct Children final { + explicit constexpr Children() = default; + explicit constexpr Children(ExprId e0, ExprId e1) : e0(e0), e1(e1) {} ExprId e0; ExprId e1; }; - // The `x` parameter has different types depending on the value of the - // `k` parameter. The correspondences are: - // * `kTensor` -> `TensorId` - // * `kInvariant` -> `kInvalidId` - // * `kLoopVar` -> `LoopId` - // * else -> `ExprId` - // - // The `y`, `v`, and `op` parameters either must or must not be - // `kInvalidId`/`nullptr`, depending on the value of the `k` parameter; - // however, they have uniform C++ types regardless of the value of `k`. - TensorExp(Kind k, unsigned x, ExprId y, Value v, Operation *op); + explicit constexpr TensorExp(TensorId t); + explicit constexpr TensorExp(LoopId i); + explicit constexpr TensorExp(Value v); + + /// This constructor handles all unary, sesquinary, and binary + /// expressions. The `e0` argument must always be a valid identifier. + /// Whereas, the other three arguments either must or must not be valid, + /// depending on the value of `k`. + constexpr TensorExp(Kind k, ExprId e0, ExprId e1, Value v, Operation *op); /// Tensor expression kind. Kind kind; @@ -217,20 +217,20 @@ /// Safely converts the argument to a tensor identifier. constexpr TensorId makeTensorId(unsigned t) const { - assert(isValidTensorId(t)); - return t; + return factory.makeTensorId(t); } /// Safely converts the argument to a loop identifier. constexpr LoopId makeLoopId(unsigned i) const { - assert(isValidLoopId(i)); - return i; + return factory.makeLoopId(i); } /// Safely converts the arguments to a pair of (tensor,loop) identifiers. 
constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const { - assert(isValidTensorId(t) && isValidLoopId(i)); - return numTensors * i + t; + return factory.makeTensorLoopId(t, i); + } + constexpr TensorLoopId makeTensorLoopId(TensorId t, LoopId i) const { + return factory.makeTensorLoopId(t, i); } // @@ -307,37 +307,35 @@ bool onlyDenseDiff(LatPointId p0, LatPointId p1) const; /// Gets the tensor-identifier of the `TensorLoopId`. - constexpr TensorId tensor(TensorLoopId b) const { return b % numTensors; } + constexpr TensorId tensor(TensorLoopId b) const { + return factory.getTensorId(b); + } /// Gets the loop-identifier of the `TensorLoopId`. - constexpr LoopId loop(TensorLoopId b) const { return b / numTensors; } + constexpr LoopId loop(TensorLoopId b) const { return factory.getLoopId(b); } /// Get the total number of tensors (including the output-tensor and /// synthetic-tensor). - constexpr unsigned getNumTensors() const { return numTensors; } + constexpr unsigned getNumTensors() const { return factory.getNumTensors(); } /// Get the range of all tensor identifiers. - constexpr tensor_id::Range getTensorIds() const { - return tensor_id::Range(0, numTensors); - } + constexpr TensorId::Range getTensorIds() const { return factory.tensorIds(); } /// Get the total number of loops (native loops + filter loops). - constexpr unsigned getNumLoops() const { return numLoops; } + constexpr unsigned getNumLoops() const { return factory.getNumLoops(); } /// Get the number of native loops. constexpr unsigned getNumNativeLoops() const { return numNativeLoops; } /// Get the number of filter loops. constexpr unsigned getNumFilterLoops() const { - return numLoops - numNativeLoops; + return getNumLoops() - numNativeLoops; } /// Get the range of all loop identifiers. 
- constexpr loop_id::Range getLoopIds() const { - return loop_id::Range(0, numLoops); - } + constexpr LoopId::Range getLoopIds() const { return factory.loopIds(); } /// Get the range of native-loop identifiers. - constexpr loop_id::Range getNativeLoopIds() const { - return loop_id::Range(0, numNativeLoops); + constexpr LoopId::Range getNativeLoopIds() const { + return factory.loopIdsUpTo(numNativeLoops); } /// Get the range of filter-loop identifiers. - constexpr loop_id::Range getFilterLoopIds() const { - return loop_id::Range(numNativeLoops, numLoops); + constexpr LoopId::Range getFilterLoopIds() const { + return factory.loopIdsFrom(numNativeLoops); } /// Returns true if `b` is the `i`th loop of the output tensor. @@ -353,7 +351,7 @@ constexpr bool isFilterLoop(LoopId i) const { assert(isValidLoopId(i)); - return i >= numNativeLoops; + return i.value >= numNativeLoops; } /// Returns true if the expression is `(kTensor t)`. @@ -388,7 +386,7 @@ /// Gets the level-type of the `t`th tensor on `i`th loop. DimLevelType getDimLevelType(TensorId t, LoopId i) const { assert(isValidTensorId(t) && isValidLoopId(i)); - return lvlTypes[t][i]; + return lvlTypes[t.value][i.value]; } /// Gets the level-type of the TensorLoopId. @@ -399,13 +397,13 @@ /// Gets the loop identifier for the `lvl`th level of the `t`th tensor. std::optional getLoopId(TensorId t, Level lvl) const { assert(isValidLevel(t, lvl)); - return lvlToLoop[t][lvl]; + return lvlToLoop[t.value][lvl]; } /// Gets the level number of the the `t`th tensor on `i`th loop. std::optional getLvl(TensorId t, LoopId i) const { assert(isValidTensorId(t) && isValidLoopId(i)); - return loopToLvl[t][i]; + return loopToLvl[t.value][i.value]; } std::optional getLvl(TensorLoopId b) const { return getLvl(tensor(b), loop(b)); @@ -415,9 +413,9 @@ /// `i`th loop. 
void setLevelAndType(TensorId t, LoopId i, Level lvl, DimLevelType dlt) { assert(isValidLevel(t, lvl) && isValidLoopId(i) && isValidDLT(dlt)); - lvlTypes[t][i] = dlt; - loopToLvl[t][i] = lvl; - lvlToLoop[t][lvl] = i; + lvlTypes[t.value][i.value] = dlt; + loopToLvl[t.value][i.value] = lvl; + lvlToLoop[t.value][lvl] = i; } using ForeachTensorLoopIdCallback = function_ref . void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl) { assert(isValidLoopId(i) && isValidLevel(t, lvl)); - loopToDependencies[i][t] = lvl; - levelToDependentIdx[t][lvl].push_back(i); + loopToDependencies[i.value][t.value] = lvl; + levelToDependentIdx[t.value][lvl].push_back(i); } /// Whether the loop has dependent slice. bool hasDependentLvl(LoopId i, TensorId t) { assert(isValidTensorId(t) && isValidLoopId(i)); - return loopToDependencies[i][t].has_value(); + return loopToDependencies[i.value][t.value].has_value(); } /// Returns the list of loop indices which appear in the non-trivial index /// expression on t_l, e.g., A[i+j] => {i, j} std::vector &getDependentLoops(TensorId t, Level lvl) { assert(isValidLevel(t, lvl)); - return levelToDependentIdx[t][lvl]; + return levelToDependentIdx[t.value][lvl]; } /// Returns the defining [tid, lvl] for the loop. std::pair getLoopDefiningLvl(LoopId i) const { assert(isValidLoopId(i)); - return loopBounds[i]; + return loopBounds[i.value]; } /// Checks whether the TensorLoopId represents a tensor level with @@ -488,7 +486,7 @@ const TensorId t = tensor(b); const LoopId i = loop(b); assert(isValidTensorId(t) && isValidLoopId(i)); - return loopToDependencies[i][t].has_value(); + return loopToDependencies[i.value][t.value].has_value(); } /// Convenience getters to immediately access the stored nodes. @@ -582,12 +580,14 @@ private: /// Private helpers. 
- constexpr bool isValidTensorId(TensorId t) const { return t < numTensors; } + constexpr bool isValidTensorId(TensorId t) const { + return factory.isValidTensorId(t); + } constexpr bool isValidLoopId(LoopId i) const { - return i != detail::kInvalidId && i < numLoops; + return factory.isValidLoopId(i); } bool isValidLevel(TensorId t, Level lvl) const { - return isValidTensorId(t) && lvl < lvlToLoop[t].size(); + return isValidTensorId(t) && lvl < lvlToLoop[t.value].size(); } bool isValidExprId(ExprId e) const { return e.isValid() && e.value < tensorExps.size(); @@ -610,9 +610,8 @@ /// Merger data structures. const TensorId outTensor; const TensorId syntheticTensor; - const unsigned numTensors; + const TensorLoopFactory factory; const unsigned numNativeLoops; - const unsigned numLoops; bool hasSparseOut; // Below we use `std::vector` for things which have a priori fixed diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h @@ -45,7 +45,13 @@ } // namespace detail //===----------------------------------------------------------------------===// -/// Tensor identifiers. +/// Tensor identifiers. The `TensorLoopFactory` class can be used to +/// construct identifiers from `unsigned` values while checking validity. +/// (For more details, see that class.) Since client code often needs +/// to construct tensor identifiers from arbitrary values, we provide +/// the `TensorId::Unchecked` factory-method for constructing arbitrary +/// identifiers. Unlike the other newtypes, we don't have a dedicated +/// known-invalid tensor identifier (because it isn't needed so far). 
/// /// Semantically, tensor identifiers could be chosen to be anything; /// but operationally, they must be chosen such that the `Merger` @@ -54,31 +60,61 @@ /// of the value passed to `Merger::buildTensorExp`, which ranges from /// zero to `linalg::GenericOp::getNumOperands` for the op passed to /// `GenericOpSparsifier::matchAndRewrite`. -using TensorId = unsigned; - -// NOTE: We use this namespace to simulate having turned `TensorId` -// into a newtype, so that we can split the patch for adding the iterators -// from the patch for actually making it a newtype. -namespace tensor_id { -class Iterator; -class Range; -} // namespace tensor_id +class TensorId final { +public: + // Must forward-declare, because they can only be defined once the + // `TensorId` definition is complete. + class Iterator; + class Range; + +private: + friend class TensorId::Iterator; + friend class TensorLoopFactory; + friend class Merger; + explicit constexpr TensorId(unsigned value) : value(value) {} + +public: + /// Constructs a new tensor identifier without performing any validity + /// checks. We provide this as a factory rather than as a ctor so that + /// the potential invalidity is explicit at all callsites. + static constexpr TensorId Unchecked(unsigned value) { + return TensorId{value}; + } + + // Since the set of valid identifiers and their numeric values are both + // fixed at the time the `Merger` (or `TensorLoopFactory`) is constructed, + // it's perfectly meaningful to allow client code to extract those values. + constexpr unsigned operator*() const { return value; } + + constexpr bool operator==(TensorId rhs) const { return value == rhs.value; } + constexpr bool operator!=(TensorId rhs) const { return value != rhs.value; } + + /// Returns the tensor identifiers `[0, hi)` as an iterable range. + static constexpr Range fromZeroUpTo(TensorId hi); + /// Returns the tensor identifiers `[this, hi)` as an iterable range. 
+ constexpr Range upTo(TensorId hi) const; + +private: + unsigned value; +}; +static_assert(std::is_trivially_copyable_v && + std::is_trivially_destructible_v); /// An iterator for `TensorId`. We define this as a separate class because /// it wouldn't be generally safe/meaningful to define `TensorId::operator++`. /// The ctor is private for similar reasons, so client code should create -/// iterators via `tensor_id::Range` instead. -class tensor_id::Iterator final +/// iterators via `TensorId::Range` instead. +class TensorId::Iterator final : public llvm::iterator_facade_base { - friend class tensor_id::Range; + friend class TensorId::Range; explicit constexpr Iterator(TensorId tid) : tid(tid) {} public: using llvm::iterator_facade_base::operator++; Iterator &operator++() { - ++tid; + ++tid.value; return *this; } const TensorId *operator->() const { return &tid; } @@ -89,14 +125,14 @@ private: TensorId tid; }; -static_assert(std::is_trivially_copyable_v && - std::is_trivially_destructible_v); +static_assert(std::is_trivially_copyable_v && + std::is_trivially_destructible_v); /// An iterator range for `TensorId`. -class tensor_id::Range final { +class TensorId::Range final { public: explicit constexpr Range(TensorId lo, TensorId hi) - : begin_(lo <= hi ? lo : hi), end_(hi) {} + : begin_(*lo <= *hi ? lo : hi), end_(hi) {} constexpr Iterator begin() const { return begin_; } constexpr Iterator end() const { return end_; } constexpr bool empty() const { return begin_ == end_; } @@ -105,11 +141,26 @@ Iterator begin_; Iterator end_; }; -static_assert(std::is_trivially_copyable_v && - std::is_trivially_destructible_v); +static_assert(std::is_trivially_copyable_v && + std::is_trivially_destructible_v); + +// These two can only be defined once the definitions of both `TensorId` +// and `TensorId::Range` are complete. 
+inline constexpr TensorId::Range TensorId::fromZeroUpTo(TensorId hi) { + return TensorId{0}.upTo(hi); +} +inline constexpr TensorId::Range TensorId::upTo(TensorId hi) const { + return TensorId::Range(*this, hi); +} //===----------------------------------------------------------------------===// -/// Loop identifiers. +/// Loop identifiers. The `TensorLoopFactory` class can be used to +/// construct identifiers from `unsigned` values while checking validity. +/// (For more details, see that class.) Since client code often needs +/// to construct loop identifiers from arbitrary values, we provide +/// the `LoopId::Unchecked` factory-method for constructing arbitrary +/// identifiers. In addition, the `LoopId::Invalid` factory-method can +/// be used to construct a dedicated known-invalid loop identifier. /// /// These identifiers serve as proxies for the `$dim` argument to /// `linalg::IndexOp`, however the numerical value of a `LoopId` should @@ -131,31 +182,65 @@ /// fact that its numerical value is not invariant when entering/exiting /// loops (unlike `TensorId`, `ExprId`, `LatPointId`, and `LatSetId` which /// are invariant identifiers). -using LoopId = unsigned; - -// NOTE: We use this namespace to simulate having turned `LoopId` into -// a newtype, so that we can split the patch for adding the iterators from -// the patch for actually making it a newtype. -namespace loop_id { -class Iterator; -class Range; -} // namespace loop_id +class LoopId final { +public: + // Must forward-declare, because they can only be defined once the + // `LoopId` definition is complete. + class Iterator; + class Range; + +private: + friend class LoopId::Iterator; + friend class TensorLoopFactory; + friend class Merger; + explicit constexpr LoopId(unsigned value) : value(value) {} + +public: + /// Constructs a new loop identifier without performing any validity + /// checks. 
We provide this as a factory rather than as a ctor so that + /// the potential invalidity is explicit at all callsites. + static constexpr LoopId Unchecked(unsigned value) { return LoopId{value}; } + + /// Constructs a new loop identifier with a dedicated known-invalid value. + static constexpr LoopId Invalid() { return LoopId{detail::kInvalidId}; } + + /// Checks whether the loop identifier has the dedicated known-invalid value. + constexpr bool isValid() const { return value != detail::kInvalidId; } + + // Since the set of valid identifiers and their numeric values are both + // fixed at the time the `Merger` (or `TensorLoopFactory`) is constructed, + // it's perfectly meaningful to allow client code to extract those values. + constexpr unsigned operator*() const { return value; } + + constexpr bool operator==(LoopId rhs) const { return value == rhs.value; } + constexpr bool operator!=(LoopId rhs) const { return value != rhs.value; } + + /// Returns the loop identifiers `[0, hi)` as an iterable range. + static constexpr Range fromZeroUpTo(LoopId hi); + /// Returns the loop identifiers `[this, hi)` as an iterable range. + constexpr Range upTo(LoopId hi) const; + +private: + unsigned value; +}; +static_assert(std::is_trivially_copyable_v && + std::is_trivially_destructible_v); /// An iterator for `LoopId`. We define this as a separate class because /// it wouldn't be generally safe/meaningful to define `LoopId::operator++`. /// The ctor is private for similar reasons, so client code should create -/// iterators via `loop_id::Range` instead. -class loop_id::Iterator final +/// iterators via `LoopId::Range` instead. 
+class LoopId::Iterator final : public llvm::iterator_facade_base { - friend class loop_id::Range; + friend class LoopId::Range; explicit constexpr Iterator(LoopId i) : loop(i) {} public: using llvm::iterator_facade_base::operator++; Iterator &operator++() { - ++loop; + ++loop.value; return *this; } const LoopId *operator->() const { return &loop; } @@ -166,17 +251,16 @@ private: LoopId loop; }; -static_assert(std::is_trivially_copyable_v && - std::is_trivially_destructible_v); +static_assert(std::is_trivially_copyable_v && + std::is_trivially_destructible_v); /// An iterator range for `LoopId`. -class loop_id::Range final { +class LoopId::Range final { public: explicit constexpr Range(LoopId lo, LoopId hi) - : begin_((lo != detail::kInvalidId && hi != detail::kInvalidId) - ? (lo <= hi ? lo : hi) - : detail::kInvalidId), - end_(lo != detail::kInvalidId ? hi : detail::kInvalidId) {} + : begin_((lo.isValid() && hi.isValid()) ? (*lo <= *hi ? lo : hi) + : LoopId::Invalid()), + end_(lo.isValid() ? hi : LoopId::Invalid()) {} constexpr Iterator begin() const { return begin_; } constexpr Iterator end() const { return end_; } constexpr bool empty() const { return begin_ == end_; } @@ -185,26 +269,55 @@ Iterator begin_; Iterator end_; }; -static_assert(std::is_trivially_copyable_v && - std::is_trivially_destructible_v); +static_assert(std::is_trivially_copyable_v && + std::is_trivially_destructible_v); + +// These two can only be defined once the definitions of both `LoopId` +// and `LoopId::Range` are complete. +inline constexpr LoopId::Range LoopId::fromZeroUpTo(LoopId hi) { + return LoopId{0}.upTo(hi); +} +inline constexpr LoopId::Range LoopId::upTo(LoopId hi) const { + return LoopId::Range(*this, hi); +} //===----------------------------------------------------------------------===// /// A compressed representation of `std::pair`. 
/// The compression scheme is such that this also serves as an index /// into the bitvector stored in `LatPoint` (since that bitvector is /// just the implementation for a set of `TensorLoopId` values). -using TensorLoopId = unsigned; - -// NOTE: We use this namespace to simulate having turned `TensorLoopId` into -// a newtype, so that we can split the patch for adding the iterators from -// the patch for actually making it a newtype. -namespace tensor_loop_id { -class Iterator; -class Range; -} // namespace tensor_loop_id +class TensorLoopId final { +public: + // Must forward-declare, because they can only be defined once the + // `TensorLoopId` definition is complete. + class Iterator; + class Range; + +private: + friend class TensorLoopId::Iterator; + friend class TensorLoopFactory; + friend class Merger; + explicit constexpr TensorLoopId(unsigned value) : value(value) {} + +public: + // We provide this getter because some code in Sparsification.cpp + // needs it for indexing into `BitVector`s. Unfortunately this leaks + // implementation details about the compression scheme; but so long as + // client code only interprets it as an arbitrary index into `BitVector` + // rather than any meaningful number, then that's still safe. + constexpr unsigned operator*() const { return value; } + + constexpr bool operator==(TensorLoopId b) const { return value == b.value; } + constexpr bool operator!=(TensorLoopId b) const { return value != b.value; } + +private: + unsigned value; +}; +static_assert(std::is_trivially_copyable_v && + std::is_trivially_destructible_v); /// An iterator of the `TensorLoopId`s which are included/set in a `BitVector`.
-class tensor_loop_id::Iterator final +class TensorLoopId::Iterator final : public llvm::iterator_adaptor_base< Iterator, llvm::BitVector::const_set_bits_iterator, // Since `const_set_bits_iterator` doesn't define its own @@ -215,7 +328,7 @@ // random-access /*pointer=*/const unsigned *, /*reference=*/const unsigned &> { - friend class tensor_loop_id::Range; + friend class TensorLoopId::Range; explicit Iterator(llvm::BitVector::const_set_bits_iterator I) : Iterator::iterator_adaptor_base(I) {} @@ -225,7 +338,7 @@ /// An iterator range for the `TensorLoopId`s which are included/set /// in a `BitVector`. -class tensor_loop_id::Range final { +class TensorLoopId::Range final { public: explicit Range(const llvm::BitVector &bits) : set_bits(bits.set_bits()) {} Iterator begin() const { return Iterator{set_bits.begin()}; } @@ -235,6 +348,117 @@ llvm::iterator_range set_bits; }; +//===----------------------------------------------------------------------===// +/// A factory class for constructing `TensorId`, `LoopId`, and `TensorLoopId` +/// while ensuring their validity. Unfortunately, since we can't "brand" the +/// resulting identifiers as having come from a particular factory instance, +/// this construction-time validity checking does not ensure use-time validity. +/// (The only way to fix that would be to template the identifier classes +/// on `numTensors`/`numLoops`, which would introduce too much complexity.) +class TensorLoopFactory { +public: + constexpr TensorLoopFactory(unsigned numTensors, unsigned numLoops) + : numTensors(numTensors), numLoops(numLoops) {} + + /// Get the number of tensors for this factory. + constexpr unsigned getNumTensors() const { return numTensors; } + + /// Get the number of loops for this factory. + constexpr unsigned getNumLoops() const { return numLoops; } + + /// Checks whether the tensor identifier is valid for this factory. 
+ constexpr bool isValidTensorId(TensorId t) const { + return t.value < numTensors; + } + + /// Checks whether the loop identifier is valid for this factory. + constexpr bool isValidLoopId(LoopId i) const { + return i.isValid() && i.value < numLoops; + } + + /// Projects the tensor-identifier out of the `TensorLoopId`. + /// NOTE: This method doesn't/can't check that the `TensorLoopId` + /// was constructed for the same `numTensors` as the factory. + constexpr TensorId getTensorId(TensorLoopId b) const { + return TensorId{b.value % numTensors}; + } + + /// Projects the loop-identifier out of the `TensorLoopId`. + /// NOTE: This method doesn't/can't check that the `TensorLoopId` + /// was constructed for the same `numTensors` as the factory. + constexpr LoopId getLoopId(TensorLoopId b) const { + return LoopId{b.value / numTensors}; + } + + /// Safely converts the argument to a tensor identifier. + constexpr TensorId makeTensorId(unsigned t) const { + const TensorId tid{t}; + assert(isValidTensorId(tid)); + return tid; + } + + /// Returns the iterable range `[0, numTensors)`. + constexpr TensorId::Range tensorIds() const { + return TensorId::fromZeroUpTo(TensorId{numTensors}); + } + /// Returns the iterable range `[lo, numTensors)`. + constexpr TensorId::Range tensorIdsFrom(TensorId lo) const { + // The `TensorId::Range` ctor will handle the `lo > numTensors` case. + return TensorId::Range(lo, TensorId{numTensors}); + } + constexpr TensorId::Range tensorIdsFrom(unsigned lo) const { + return tensorIdsFrom(TensorId{lo}); + } + /// Returns the iterable range `[0, min(hi,numTensors))`. + constexpr TensorId::Range tensorIdsUpTo(TensorId hi) const { + return TensorId::fromZeroUpTo( + hi.value <= numTensors ? hi : TensorId{numTensors}); + } + constexpr TensorId::Range tensorIdsUpTo(unsigned hi) const { + return tensorIdsUpTo(TensorId{hi}); + } + + /// Safely converts the argument to a loop identifier. 
+ constexpr LoopId makeLoopId(unsigned i) const { + const LoopId lid{i}; + assert(isValidLoopId(lid)); + return lid; + } + + /// Returns the iterable range `[0, numLoops)`. + constexpr LoopId::Range loopIds() const { + return LoopId::fromZeroUpTo(LoopId{numLoops}); + } + /// Returns the iterable range `[lo, numLoops)`. + constexpr LoopId::Range loopIdsFrom(LoopId lo) const { + // The `LoopId::Range` ctor will handle the `lo > numLoops` case. + return LoopId::Range(lo, LoopId{numLoops}); + } + constexpr LoopId::Range loopIdsFrom(unsigned lo) const { + return loopIdsFrom(LoopId{lo}); + } + /// Returns the iterable range `[0, min(hi,numLoops))`. + constexpr LoopId::Range loopIdsUpTo(LoopId hi) const { + return LoopId::fromZeroUpTo(hi.value <= numLoops ? hi : LoopId{numLoops}); + } + constexpr LoopId::Range loopIdsUpTo(unsigned hi) const { + return loopIdsUpTo(LoopId{hi}); + } + + /// Safely converts the arguments to a pair of (tensor,loop) identifiers. + constexpr TensorLoopId makeTensorLoopId(TensorId t, LoopId i) const { + assert(isValidTensorId(t) && isValidLoopId(i)); + return TensorLoopId{numTensors * i.value + t.value}; + } + constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const { + return makeTensorLoopId(TensorId{t}, LoopId{i}); + } + +private: + unsigned numTensors; + unsigned numLoops; +}; + //===----------------------------------------------------------------------===// /// `TensorExp` identifiers. These are allocated by `Merger::addExp`, /// and serve as unique identifiers for the corresponding `TensorExp` object. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -196,7 +196,7 @@ if (!latticeMerger.isFilterLoop(i)) { // We only count non-filter loops as filter loops should be considered // a special type of parallel loops. 
- if (linalg::isReductionIterator(iteratorTypes[i])) + if (linalg::isReductionIterator(iteratorTypes[*i])) break; // terminate at first reduction nest++; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h @@ -257,12 +257,12 @@ unsigned getNumTensors() const { return tensors.size(); } - tensor_id::Range getTensorIds() const { - return tensor_id::Range(0, getNumTensors()); + TensorId::Range getTensorIds() const { + return TensorId::fromZeroUpTo(TensorId::Unchecked(getNumTensors())); } bool isOutputTensor(TensorId tid) const { - return hasOutput && tid == getNumTensors() - 1; + return hasOutput && *tid == getNumTensors() - 1; } bool isSparseOutput(TensorId tid) const { @@ -270,7 +270,7 @@ } bool isValidLevel(TensorId tid, Level lvl) const { - return tid < lvlTypes.size() && lvl < lvlTypes[tid].size(); + return *tid < lvlTypes.size() && lvl < lvlTypes[*tid].size(); } /// Prepares loop for iterating over `tensor[lvl]`, under the assumption @@ -331,9 +331,9 @@ /// TODO: why not do this computation when we first store the reassoc, /// instead of doing it every time we look it up? SmallVector getCollapseReassociation(TensorId tid, Level dstLvl) { - assert(tid < getNumTensors() && "Invalid TensorId"); + assert(*tid < getNumTensors() && "Invalid TensorId"); assert(collapseReassoc.size() == getNumTensors()); - if (const auto reassoc = collapseReassoc[tid]) { + if (const auto reassoc = collapseReassoc[*tid]) { // TODO: store the dstLvlRank in the LoopEmitter so that we can // check `dstLvl < dstLvlRank` at the top; and only here need to // assert that `reassoc.size() == dstLvlRank`. 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -88,10 +88,10 @@ std::pair LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd, TensorId tid, Level lvl) { - assert(isSparseSlices[tid]); - Value slice = tensors[tid]; - Value offset = sliceOffsets[tid][lvl]; - Value stride = sliceStrides[tid][lvl]; + assert(isSparseSlices[*tid]); + Value slice = tensors[*tid]; + Value offset = sliceOffsets[*tid][lvl]; + Value stride = sliceStrides[*tid][lvl]; auto enc = getSparseTensorEncoding(slice.getType()); const auto [newCrd, crdRem] = @@ -109,7 +109,7 @@ // Second, coord_in_slice < length auto ltLength = builder.create(loc, arith::CmpIPredicate::ult, - newCrd, lvlSizes[tid][lvl]); + newCrd, lvlSizes[*tid][lvl]); conds.push_back(ltLength); // Third, rem == 0 (skip the check if stride is known to be 1). @@ -134,11 +134,11 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl, Value crd) { - Value pos = lvl == 0 ? constantIndex(builder, loc, 0) : posits[tid][lvl - 1]; - Value mul = builder.create(loc, highs[tid][lvl], pos); - if (isSparseSlices[tid]) - crd = toSliceCrd(builder, loc, crd, sliceOffsets[tid][lvl], - sliceStrides[tid][lvl], tensors[tid], lvl); + Value pos = lvl == 0 ? 
constantIndex(builder, loc, 0) : posits[*tid][lvl - 1]; + Value mul = builder.create(loc, highs[*tid][lvl], pos); + if (isSparseSlices[*tid]) + crd = toSliceCrd(builder, loc, crd, sliceOffsets[*tid][lvl], + sliceStrides[*tid][lvl], tensors[*tid], lvl); Value add = builder.create(loc, mul, crd); return add; } @@ -146,7 +146,7 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc, TensorId tid, Level lvl, Value pLo, Value pHi) { - const auto coordinates = coordinatesBuffers[tid][lvl]; + const auto coordinates = coordinatesBuffers[*tid][lvl]; const auto sameCrd = genIndexLoad(builder, loc, coordinates, pLo); auto whileOp = builder.create( loc, builder.getIndexType(), pLo, @@ -193,15 +193,15 @@ for (unsigned i = 0; i < reassocSize; i++) { const Level srcLvl = reassoc[i]; // A load on the coordinates array yields the coordinate. - const Value mem = coordinatesBuffers[tid][srcLvl]; + const Value mem = coordinatesBuffers[*tid][srcLvl]; /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. - const Value pos = posits[tid][dstLvl]; + const Value pos = posits[*tid][dstLvl]; const Value off = genIndexLoad(builder, loc, mem, pos); // Linearized the coordinates within the same collapse reassociation. crd = builder.create(loc, crd, off); if (i != reassocSize - 1) { crd = builder.create(loc, crd, - this->lvlSizes[tid][reassoc[i + 1]]); + this->lvlSizes[*tid][reassoc[i + 1]]); } } return crd; @@ -249,7 +249,7 @@ // Initialize nested types of `TensorId`-indexed fields. for (const TensorId tid : getTensorIds()) { - const Value t = tensors[tid]; + const Value t = tensors[*tid]; // a scalar or 0-dimension tensors if (isZeroRankedTensorOrScalar(t.getType())) continue; @@ -260,10 +260,10 @@ // TODO: Supports more kinds of sparse tensors. // FIXME: We should instead lower reshape operations on sparse tensors to // view change. 
- collapseReassoc[tid] = reshape.getReassociation(); + collapseReassoc[*tid] = reshape.getReassociation(); rtp = reshape.getSrcType(); // Overwrites the tensor to the source tensor of reshape operations. - tensors[tid] = reshape.getSrc(); + tensors[*tid] = reshape.getSrc(); } const SparseTensorType stt(rtp); const Level lvlRank = stt.getLvlRank(); @@ -271,29 +271,29 @@ // it based on lvl size. if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) { const auto enc = stt.getEncoding(); - isSparseSlices[tid] = enc.isSlice(); + isSparseSlices[*tid] = enc.isSlice(); for (auto lvlTp : enc.getDimLevelType()) - lvlTypes[tid].push_back(lvlTp); + lvlTypes[*tid].push_back(lvlTp); } else { - lvlTypes[tid].assign(lvlRank, DimLevelType::Dense); + lvlTypes[*tid].assign(lvlRank, DimLevelType::Dense); } // Initialize using empty value. - lvlSizes[tid].assign(lvlRank, Value()); - highs[tid].assign(lvlRank, Value()); - segHi[tid].assign(lvlRank, Value()); - posits[tid].assign(lvlRank, Value()); - coords[tid].assign(lvlRank, Value()); - positionsBuffers[tid].assign(lvlRank, Value()); - coordinatesBuffers[tid].assign(lvlRank, Value()); - sliceOffsets[tid].assign(lvlRank, Value()); - sliceStrides[tid].assign(lvlRank, Value()); + lvlSizes[*tid].assign(lvlRank, Value()); + highs[*tid].assign(lvlRank, Value()); + segHi[*tid].assign(lvlRank, Value()); + posits[*tid].assign(lvlRank, Value()); + coords[*tid].assign(lvlRank, Value()); + positionsBuffers[*tid].assign(lvlRank, Value()); + coordinatesBuffers[*tid].assign(lvlRank, Value()); + sliceOffsets[*tid].assign(lvlRank, Value()); + sliceStrides[*tid].assign(lvlRank, Value()); - dependentLvlMap[tid].assign(lvlRank, - std::vector>()); + dependentLvlMap[*tid].assign(lvlRank, + std::vector>()); if (dimGetter) for (Level l = 0; l < lvlRank; l++) - dependentLvlMap[tid][l] = dimGetter(tid, l); + dependentLvlMap[*tid][l] = dimGetter(tid, l); } // Construct the inverse of the `topSort` from the sparsifier. 
@@ -301,7 +301,7 @@ // used in loop emitter. // FIXME: This map should be maintained outside loop emitter. for (LoopOrd n = 0; n < numLoops; n++) - loopIdToOrd[topSort[n]] = n; + loopIdToOrd[*(topSort[n])] = n; } void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc, @@ -313,7 +313,7 @@ // * get/compute the level-size, which is also used as the upper-bound // on positions. for (const TensorId t : getTensorIds()) { - const Value tensor = tensors[t]; + const Value tensor = tensors[*t]; const auto rtp = tensor.getType().dyn_cast(); if (!rtp) // Skips only scalar, zero ranked tensor still need to be bufferized and @@ -329,18 +329,18 @@ // Scan all levels of current tensor. for (Level l = 0; l < lvlRank; l++) { // This should be called only once at beginning. - assert(!positionsBuffers[t][l] && !coordinatesBuffers[t][l] && - !highs[t][l]); - const auto lvlTp = lvlTypes[t][l]; + assert(!positionsBuffers[*t][l] && !coordinatesBuffers[*t][l] && + !highs[*t][l]); + const auto lvlTp = lvlTypes[*t][l]; // Handle sparse storage schemes. if (isCompressedDLT(lvlTp)) { // Generate sparse primitives to obtain positions and coordinates. - positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l); - coordinatesBuffers[t][l] = + positionsBuffers[*t][l] = genToPositions(builder, loc, tensor, l); + coordinatesBuffers[*t][l] = genToCoordinates(builder, loc, tensor, l, cooStart); } else if (isSingletonDLT(lvlTp)) { // Singleton level, fetch coordinates. - coordinatesBuffers[t][l] = + coordinatesBuffers[*t][l] = genToCoordinates(builder, loc, tensor, l, cooStart); } else { // Dense level, nothing to fetch. @@ -353,10 +353,10 @@ Value lvlSz = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, toOrigDim(enc, l)); // Find upper bound in current dimension. 
- highs[t][l] = lvlSizes[t][l] = lvlSz; - if (isSparseSlices[t]) { - sliceOffsets[t][l] = genSliceOffset(builder, loc, tensors[t], l); - sliceStrides[t][l] = genSliceStride(builder, loc, tensors[t], l); + highs[*t][l] = lvlSizes[*t][l] = lvlSz; + if (isSparseSlices[*t]) { + sliceOffsets[*t][l] = genSliceOffset(builder, loc, tensors[*t], l); + sliceStrides[*t][l] = genSliceStride(builder, loc, tensors[*t], l); } } @@ -382,11 +382,11 @@ if (isOutput && updater) denseVal = updater(builder, loc, denseVal, tensor); - valBuffer[t] = denseVal; + valBuffer[*t] = denseVal; } else { // Annotated sparse tensors. // We also need the value buffer for all-dense annotated "sparse" tensors. - valBuffer[t] = genToValues(builder, loc, tensor); + valBuffer[*t] = genToValues(builder, loc, tensor); } // NOTE: we can also prepare for 0 lvl here in advance, this will hoist // some loop preparation from tensor iteration, but will also (undesirably) @@ -452,8 +452,8 @@ // TODO: this check for validity of the (t,l) pairs should be // checked/enforced at the callsites, if possible. assert(isValidLevel(t, l)); - assert(!coords[t][l]); // We cannot re-enter the same level - const auto lvlTp = lvlTypes[t][l]; + assert(!coords[*t][l]); // We cannot re-enter the same level + const auto lvlTp = lvlTypes[*t][l]; const bool isSparse = isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp); // Must be a recognizable level-type. assert(isSparse || isDenseDLT(lvlTp)); @@ -474,14 +474,14 @@ const Level srcLvl = reassoc.front(); const Value step = constantIndex(builder, loc, 1); /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. - const Value lo = isSparseInput ? posits[tid][srcLvl] // current position + const Value lo = isSparseInput ? 
posits[*tid][srcLvl] // current position : loopSeqStack.back(); // universal index - const Value hi = highs[tid][srcLvl]; + const Value hi = highs[*tid][srcLvl]; Operation *loop = nullptr; Value iv; if (isParallel) { - assert(collapseReassoc[tid] == nullptr); + assert(collapseReassoc[*tid] == nullptr); scf::ParallelOp parOp = builder.create(loc, lo, hi, step, reduc); builder.setInsertionPointToStart(parOp.getBody()); @@ -513,18 +513,18 @@ Value crd; if (isSparseInput) { - assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[*tid].getType())); // For COO, the position is the same across consecutive levels. /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. - llvm::for_each(reassoc, - [this, tid, iv](Level srcLvl) { posits[tid][srcLvl] = iv; }); + llvm::for_each( + reassoc, [this, tid, iv](Level srcLvl) { posits[*tid][srcLvl] = iv; }); crd = genSparseCrd(builder, loc, tid, dstLvl); } else { // Dense tensor, the coordinate is the inducation variable. crd = iv; } - if (isSparseSlices[tid] && isSparseInput) { + if (isSparseSlices[*tid] && isSparseInput) { // For sparse level slices, we need to filter out invalid coordinates that // are not included in the slice. SmallVector types; @@ -553,7 +553,7 @@ } assert(crd); - coords[tid][srcLvl] = crd; + coords[*tid][srcLvl] = crd; // NOTE: we can also prepare for next level here in advance // Push the loop into stack loopStack.emplace_back(ArrayRef(tid), ArrayRef(srcLvl), loop, @@ -568,17 +568,17 @@ OpBuilder &builder, Location loc, TensorId tid, Level lvl, AffineExpr affine, MutableArrayRef reduc) { assert(isValidLevel(tid, lvl)); - assert(!affine.isa() && !isDenseDLT(lvlTypes[tid][lvl])); + assert(!affine.isa() && !isDenseDLT(lvlTypes[*tid][lvl])); // We can not re-enter the same level. 
- assert(!coords[tid][lvl]); + assert(!coords[*tid][lvl]); // TODO: We should instead use a whileOp for filter loop to allow early // break when exceeding (for ordered levels). // TODO: There are many other potiential opportunities that we might apply in // the future. E.g., we could use binary search to locate positions. const Value step = constantIndex(builder, loc, 1); - const Value pLo = posits[tid][lvl]; - const Value pHi = highs[tid][lvl]; + const Value pLo = posits[*tid][lvl]; + const Value pHi = highs[*tid][lvl]; scf::ForOp forOp = builder.create(loc, pLo, pHi, step, reduc); // In-place update on the reduction variable vector. @@ -589,11 +589,11 @@ builder.setInsertionPointToStart(forOp.getBody()); // The induction variable gives the position. const Value pos = forOp.getInductionVar(); - posits[tid][lvl] = pos; + posits[*tid][lvl] = pos; // Generating a load on the coordinates array yields the crd. - const Value mem = coordinatesBuffers[tid][lvl]; + const Value mem = coordinatesBuffers[*tid][lvl]; const Value crd = genIndexLoad(builder, loc, mem, pos); - coords[tid][lvl] = crd; + coords[*tid][lvl] = crd; // Generate an if-condition to filter out coordinates that are not // equal to the result of the affine expression. @@ -633,10 +633,10 @@ void LoopEmitter::genDenseAffineAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl, AffineExpr lvlExpr) { - assert(isDenseDLT(lvlTypes[tid][lvl])); + assert(isDenseDLT(lvlTypes[*tid][lvl])); // For dense levels, the level-coordinate also serves as the position. Value lvlCrd = genAffine(builder, loc, lvlExpr); - posits[tid][lvl] = genAddress(builder, loc, tid, lvl, lvlCrd); + posits[*tid][lvl] = genAddress(builder, loc, tid, lvl, lvlCrd); } Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls( @@ -648,17 +648,17 @@ // Construct the while-loop with a parameter for each coordinate. 
const Type indexType = builder.getIndexType(); for (auto [tid, lvl] : llvm::zip(tids, lvls)) { - const auto lvlTp = lvlTypes[tid][lvl]; + const auto lvlTp = lvlTypes[*tid][lvl]; if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) { const auto reassoc = getCollapseReassociation(tid, lvl); for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { - if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) { + if (!isUniqueDLT(lvlTypes[*tid][reassoc[i]])) { // This is the segment high for each non-unique levels. types.push_back(indexType); operands.push_back(constantIndex(builder, loc, 0)); } } - const auto pos = posits[tid][reassoc.front()]; + const auto pos = posits[*tid][reassoc.front()]; assert(pos); types.push_back(indexType); operands.push_back(pos); @@ -688,19 +688,19 @@ unsigned o = 0; for (auto [t, lvl] : llvm::zip(tids, lvls)) { const TensorId tid = t; // Why `t` can not be captured by lambda? - const auto lvlTp = lvlTypes[tid][lvl]; + const auto lvlTp = lvlTypes[*tid][lvl]; if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) { const auto reassoc = getCollapseReassociation(tid, lvl); - assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[*tid].getType())); for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { - if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) { + if (!isUniqueDLT(lvlTypes[*tid][reassoc[i]])) { // Links the SSA chain for segHi. - segHi[tid][reassoc[i]] = after->getArgument(o++); + segHi[*tid][reassoc[i]] = after->getArgument(o++); } } Value op1 = before->getArgument(o); // We used the first level bound as the bound the collapsed set of levels. - Value op2 = highs[tid][reassoc.front()]; + Value op2 = highs[*tid][reassoc.front()]; Value opc = builder.create(loc, arith::CmpIPredicate::ult, op1, op2); cond = cond ? builder.create(loc, cond, opc) : opc; @@ -709,7 +709,7 @@ // For COO, the position is the same across consecutive levels. 
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. llvm::for_each(reassoc, [this, tid, pos](Level srcLvl) { - posits[tid][srcLvl] = pos; + posits[*tid][srcLvl] = pos; }); } } @@ -722,15 +722,15 @@ unsigned i = 0; for (auto [tid, lvl] : llvm::zip(tids, lvls)) { // Prepares for next level. - const auto lvlTp = lvlTypes[tid][lvl]; + const auto lvlTp = lvlTypes[*tid][lvl]; if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) { - coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl); - if (isSparseSlices[tid]) { + coords[*tid][lvl] = genSparseCrd(builder, loc, tid, lvl); + if (isSparseSlices[*tid]) { auto [trans, pred] = - genSliceLegitPredicate(builder, loc, coords[tid][lvl], tid, lvl); + genSliceLegitPredicate(builder, loc, coords[*tid][lvl], tid, lvl); slicesPreds.emplace_back(pred, i); // Updates to the relative coordinate to the slice. - coords[tid][lvl] = trans; + coords[*tid][lvl] = trans; } i++; } @@ -773,9 +773,9 @@ // Finds the minimum coordinate if (!needsUniv) { for (auto [tid, lvl] : llvm::zip(tids, lvls)) { - const auto lvlTp = lvlTypes[tid][lvl]; + const auto lvlTp = lvlTypes[*tid][lvl]; if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) { - const auto crd = coords[tid][lvl]; + const auto crd = coords[*tid][lvl]; if (min) { Value cmp = builder.create( loc, arith::CmpIPredicate::ult, crd, min); @@ -798,7 +798,7 @@ for (auto [tid, dstLvl] : llvm::zip(tids, lvls)) { const auto reassoc = getCollapseReassociation(tid, dstLvl); - assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[*tid].getType())); // TODO: Refactors this into smaller functions. 
// NOTE: For all the collapsed level (except for the last one, that is why // the loop ends with `reassoc.size() - 1`), as each iteration is advanced @@ -815,9 +815,9 @@ // the first iteration does not invalidate segHi[0] and segHi[1] for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { const Level srcLvl = reassoc[i]; - if (!isUniqueDLT(lvlTypes[tid][srcLvl])) { - const Value pos = posits[tid][srcLvl]; - const auto oldSegHi = segHi[tid][srcLvl]; + if (!isUniqueDLT(lvlTypes[*tid][srcLvl])) { + const Value pos = posits[*tid][srcLvl]; + const auto oldSegHi = segHi[*tid][srcLvl]; assert(oldSegHi); Value newSegHi = builder.create( loc, arith::CmpIPredicate::uge, pos, oldSegHi); @@ -826,20 +826,20 @@ { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(ifNewSegHi.thenBlock()); - builder.create(loc, - genSegmentHigh(builder, loc, tid, srcLvl, - pos, highs[tid][srcLvl])); + builder.create( + loc, genSegmentHigh(builder, loc, tid, srcLvl, pos, + highs[*tid][srcLvl])); // Else, resues the same segment high. 
builder.setInsertionPointToStart(ifNewSegHi.elseBlock()); builder.create(loc, oldSegHi); } - highs[tid][srcLvl + 1] = segHi[tid][srcLvl] = ifNewSegHi.getResult(0); + highs[*tid][srcLvl + 1] = segHi[*tid][srcLvl] = ifNewSegHi.getResult(0); } }; const auto srcLvl = reassoc.back(); - if (!isUniqueDLT(lvlTypes[tid][srcLvl])) { - segHi[tid][srcLvl] = genSegmentHigh( - builder, loc, tid, srcLvl, posits[tid][srcLvl], highs[tid][srcLvl]); + if (!isUniqueDLT(lvlTypes[*tid][srcLvl])) { + segHi[*tid][srcLvl] = genSegmentHigh( + builder, loc, tid, srcLvl, posits[*tid][srcLvl], highs[*tid][srcLvl]); } } @@ -858,7 +858,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc, TensorId tid, Level dstLvl) { assert(isValidLevel(tid, dstLvl)); - const auto lvlTp = lvlTypes[tid][dstLvl]; + const auto lvlTp = lvlTypes[*tid][dstLvl]; if (isDenseDLT(lvlTp)) return; @@ -868,22 +868,22 @@ for (const Level srcLvl : getCollapseReassociation(tid, dstLvl)) { // Either the first level, or the previous level has been set. /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header. - assert(srcLvl == 0 || posits[tid][srcLvl - 1]); + assert(srcLvl == 0 || posits[*tid][srcLvl - 1]); if (!isCompressedDLT(lvlTp) && !isSingletonDLT(lvlTp)) continue; if (isCompressedDLT(lvlTp)) { - const Value mem = positionsBuffers[tid][srcLvl]; + const Value mem = positionsBuffers[*tid][srcLvl]; - const Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1]; - posits[tid][srcLvl] = genIndexLoad(builder, loc, mem, pLo); + const Value pLo = srcLvl == 0 ? c0 : posits[*tid][srcLvl - 1]; + posits[*tid][srcLvl] = genIndexLoad(builder, loc, mem, pLo); const Value pHi = builder.create(loc, pLo, c1); - highs[tid][srcLvl] = genIndexLoad(builder, loc, mem, pHi); + highs[*tid][srcLvl] = genIndexLoad(builder, loc, mem, pHi); return; } if (isSingletonDLT(lvlTp)) { - const Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1]; - posits[tid][srcLvl] = pLo; + const Value pLo = srcLvl == 0 ? 
c0 : posits[*tid][srcLvl - 1]; + posits[*tid][srcLvl] = pLo; // If we are coiterating non-unique levels, then use pHi=segHi; // otherwise use pHi=pLo+1. @@ -891,9 +891,9 @@ // guarantee that segHi is defined: because we only generate segHi // whenever coiterating, in order to improve code quality for the // non-coiterating cases. - const auto parentSegHi = segHi[tid][srcLvl - 1]; - highs[tid][srcLvl] = - (!isUniqueDLT(lvlTypes[tid][srcLvl - 1]) && parentSegHi) + const auto parentSegHi = segHi[*tid][srcLvl - 1]; + highs[*tid][srcLvl] = + (!isUniqueDLT(lvlTypes[*tid][srcLvl - 1]) && parentSegHi) ? parentSegHi : builder.create(loc, pLo, c1); return; @@ -912,17 +912,17 @@ // but may be needed for linearized codegen. assert(tids.size() == lvls.size()); for (auto [tid, lvl] : llvm::zip(tids, lvls)) { - if (isDenseDLT(lvlTypes[tid][lvl])) { - auto enc = getSparseTensorEncoding(tensors[tid].getType()); + if (isDenseDLT(lvlTypes[*tid][lvl])) { + auto enc = getSparseTensorEncoding(tensors[*tid].getType()); if (enc && !isSparseOutput(tid)) { - bool validPos = lvl == 0 || posits[tid][lvl - 1]; + bool validPos = lvl == 0 || posits[*tid][lvl - 1]; if (!validPos) { // We might not find the pos for the sparse output tensor as it is // unconditionally required by the sparsification. assert(isOutputTensor(tid)); continue; } - posits[tid][lvl] = + posits[*tid][lvl] = genAddress(builder, loc, tid, lvl, loopStack.back().iv); // NOTE: we can also prepare for next lvl here in advance } @@ -1002,11 +1002,11 @@ // finish the iteration on a sparse tensor for (auto [tid, lvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) { // Reset to null. - coords[tid][lvl] = Value(); - posits[tid][lvl] = Value(); + coords[*tid][lvl] = Value(); + posits[*tid][lvl] = Value(); // Dense level, high is fixed. 
- if (!isDenseDLT(lvlTypes[tid][lvl])) - highs[tid][lvl] = Value(); + if (!isDenseDLT(lvlTypes[*tid][lvl])) + highs[*tid][lvl] = Value(); } } @@ -1025,26 +1025,26 @@ SmallVector operands; Value one = constantIndex(builder, loc, 1); for (auto [tid, dstLvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) { - const auto lvlTp = lvlTypes[tid][dstLvl]; + const auto lvlTp = lvlTypes[*tid][dstLvl]; if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) { const auto reassoc = getCollapseReassociation(tid, dstLvl); - assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType())); + assert(reassoc.size() == 1 || isUniqueCOOType(tensors[*tid].getType())); for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) { const Level srcLvl = reassoc[i]; - if (!isUniqueDLT(lvlTypes[tid][srcLvl])) { - operands.push_back(segHi[tid][srcLvl]); + if (!isUniqueDLT(lvlTypes[*tid][srcLvl])) { + operands.push_back(segHi[*tid][srcLvl]); o++; } } - const Value crd = coords[tid][dstLvl]; - const Value pos = posits[tid][dstLvl]; + const Value crd = coords[*tid][dstLvl]; + const Value pos = posits[*tid][dstLvl]; Value cmp = builder.create(loc, arith::CmpIPredicate::eq, crd, iv); // If the loop contains a coiteration with non-unique level, we fast // forward all the duplicated coords by setting the position to the // segment high. - Value add = !isUniqueDLT(lvlTypes[tid][reassoc.back()]) - ? segHi[tid][reassoc.back()] + Value add = !isUniqueDLT(lvlTypes[*tid][reassoc.back()]) + ? segHi[*tid][reassoc.back()] : builder.create(loc, pos, one); operands.push_back(builder.create(loc, cmp, add, pos)); @@ -1056,12 +1056,12 @@ // FIXME(wrengr): define a helper function to capture this idiom! const TensorId newTid = tid; llvm::for_each(reassoc, [this, newTid, newPos](Level srcLvl) { - posits[newTid][srcLvl] = newPos; + posits[*newTid][srcLvl] = newPos; }); // The coordinate is invalid now. - coords[tid][dstLvl] = nullptr; + coords[*tid][dstLvl] = nullptr; // The segment high is invalid now. 
- segHi[tid][dstLvl] = nullptr; + segHi[*tid][dstLvl] = nullptr; // highs remains unchanged. } } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -943,7 +943,7 @@ // one for loop? // FIXME(wrengr): what is this "ld" supposed to be really? const Level ld = op.getOrder() ? op.getOrder()->getDimPosition(l) : l; - const SmallVector tids{0}; + const SmallVector tids{TensorId::Unchecked(0)}; loopEmitter.enterNewLoopSeq(rewriter, loc, tids, ld); // Note that reduc will be taken care of by loop emitter and get updated // in place. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -131,7 +131,7 @@ LoopId ldx, bool &isAtLoop) { switch (a.getKind()) { case AffineExprKind::DimId: { - const LoopId i = a.cast().getPosition(); + const LoopId i = LoopId::Unchecked(a.cast().getPosition()); if (i == ldx) { isAtLoop = true; // Must be invariant if we are at the given loop. @@ -221,7 +221,7 @@ /// compound affine sparse level, and it will be incremented by one when /// used. static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, - DimLevelType dlt, loop_id::Iterator &filterLdx, + DimLevelType dlt, LoopId::Iterator &filterLdx, bool setLvlFormat = true) { switch (a.getKind()) { case AffineExprKind::DimId: { @@ -406,8 +406,8 @@ /// supports affine addition index expression. 
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) { bool annotated = false; - const loop_id::Range filterLdxRange = env.merger().getFilterLoopIds(); - loop_id::Iterator filterLdx = filterLdxRange.begin(); + const LoopId::Range filterLdxRange = env.merger().getFilterLoopIds(); + LoopId::Iterator filterLdx = filterLdxRange.begin(); for (OpOperand &t : env.op()->getOpOperands()) { const TensorId tid = env.makeTensorId(t.getOperandNumber()); const auto map = env.op().getMatchingIndexingMap(&t); @@ -459,10 +459,10 @@ std::vector parIt; // parallel iterator with 0 degree std::vector filterIt; // filter loop with 0 degree for (const LoopId i : env.merger().getLoopIds()) { - if (inDegree[i] == 0) { + if (inDegree[*i] == 0) { if (env.merger().isFilterLoop(i)) filterIt.push_back(i); - else if (linalg::isReductionIterator(iteratorTypes[i])) + else if (linalg::isReductionIterator(iteratorTypes[*i])) redIt.push_back(i); else parIt.push_back(i); @@ -495,10 +495,10 @@ it.pop_back(); // Update in-degree, and push 0-degree node into worklist. for (const LoopId dst : env.merger().getLoopIds()) { - if (adjM[src][dst] && --inDegree[dst] == 0) { + if (adjM[*src][*dst] && --inDegree[*dst] == 0) { if (env.merger().isFilterLoop(dst)) filterIt.push_back(dst); - else if (linalg::isReductionIterator(iteratorTypes[dst])) + else if (linalg::isReductionIterator(iteratorTypes[*dst])) redIt.push_back(dst); else parIt.push_back(dst); @@ -525,7 +525,7 @@ if (!a && !b) { // Recursion leaf. 
assert(fidx && tidx); - const LoopId f = *fidx, t = *tidx; + const unsigned f = **fidx, t = **tidx; if (!adjM[f][t]) { adjM[f][t] = true; inDegree[t]++; @@ -537,7 +537,7 @@ switch (toExpand.getKind()) { case AffineExprKind::DimId: { const std::optional idx{ - toExpand.cast().getPosition()}; + LoopId::Unchecked(toExpand.cast().getPosition())}; if (toExpand == a) addAffineOrderings(adjM, inDegree, AffineExpr(), b, idx, tidx); else // toExpand == b @@ -582,7 +582,7 @@ finder.setPickedIterType(utils::IteratorType::parallel); finder.walkPostOrder(fa); fa = finder.getDimExpr(); - fldx = finder.getDimExpr().getPosition(); + fldx = LoopId::Unchecked(finder.getDimExpr().getPosition()); } if (!ta.isa()) { // Heuristic: we prefer reduction loop for rhs to reduce the chance @@ -590,7 +590,7 @@ finder.setPickedIterType(utils::IteratorType::reduction); finder.walkPostOrder(ta); ta = finder.getDimExpr(); - tldx = finder.getDimExpr().getPosition(); + tldx = LoopId::Unchecked(finder.getDimExpr().getPosition()); } } } @@ -698,9 +698,9 @@ const LoopId tldx = env.makeLoopId(texp.getPosition()); // d_x > d_y - if (!adjM[fldx][tldx]) { - adjM[fldx][tldx] = true; - inDegree[tldx]++; + if (!adjM[*fldx][*tldx]) { + adjM[*fldx][*tldx] = true; + inDegree[*tldx]++; } AffineDimCollector fCollector; @@ -713,18 +713,18 @@ const LoopId f = env.makeLoopId(fd.getPosition()); if (f == fldx) continue; - if (!adjM[f][fldx]) { - adjM[f][fldx] = true; - inDegree[fldx]++; + if (!adjM[*f][*fldx]) { + adjM[*f][*fldx] = true; + inDegree[*fldx]++; } } for (auto td : tCollector.dims) { const LoopId t = env.makeLoopId(td.getPosition()); if (t == tldx) continue; - if (!adjM[t][tldx]) { - adjM[t][tldx] = true; - inDegree[tldx]++; + if (!adjM[*t][*tldx]) { + adjM[*t][*tldx] = true; + inDegree[*tldx]++; } } // Since we only support affine addition, the order between two dim @@ -745,9 +745,9 @@ const LoopId t = env.makeLoopId(td.getPosition()); if (t == tldx) // skip d_y continue; - if (!adjM[f][t]) { - adjM[f][t] 
= true; - inDegree[t]++; + if (!adjM[*f][*t]) { + adjM[*f][*t] = true; + inDegree[*t]++; } } } @@ -794,8 +794,8 @@ if (isCompressedDLT(dltI) || isSingletonDLT(dltI)) { for (const LoopId j : env.merger().getLoopIds()) if (isUndefDLT(env.dlt(tid, j))) { - adjM[i][j] = true; - inDegree[j]++; + adjM[*i][*j] = true; + inDegree[*j]++; } } else { assert(isDenseDLT(dltI) || isUndefDLT(dltI)); @@ -888,7 +888,7 @@ const auto stt = getSparseTensorType(t->get()); if (stt.hasEncoding()) { // For sparse tensors we only push the last-level's position onto `args`. - const auto pos = env.emitter().getPosits()[tid].back(); + const auto pos = env.emitter().getPosits()[*tid].back(); assert(pos); args.push_back(pos); } else { @@ -901,7 +901,7 @@ args.push_back(lvlCrd); } } - return env.emitter().getValBuffer()[tid]; + return env.emitter().getValBuffer()[*tid]; } /// Generates insertion code to implement dynamic tensor load. @@ -1027,7 +1027,7 @@ // Load during insertion. linalg::GenericOp op = env.op(); - OpOperand *t = &op->getOpOperand(env.exp(exp).tensor); + OpOperand *t = &op->getOpOperand(*(env.exp(exp).tensor)); if (env.isSparseOutput(t)) { if (env.isCustomReduc()) return genInsertionLoadReduce(env, builder, t); @@ -1160,9 +1160,9 @@ return; if (env.exp(exp).kind == TensorExp::Kind::kTensor) { // Inspect tensor indices. - bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId; + bool isAtLoop = !ldx.isValid(); linalg::GenericOp op = env.op(); - OpOperand &t = op->getOpOperand(env.exp(exp).tensor); + OpOperand &t = op->getOpOperand(*(env.exp(exp).tensor)); const auto map = op.getMatchingIndexingMap(&t); const auto stt = getSparseTensorType(t.get()); const Level lvlRank = stt.getLvlRank(); @@ -1311,7 +1311,7 @@ // tids/lvls must only have one value because filter loops only // corresponding to the one and only sparse tensor level. 
assert(isSparse && tids.size() == 1 && lvls.size() == 1); - OpOperand *t = &op->getOpOperand(tid); + OpOperand *t = &op->getOpOperand(*tid); auto enc = getSparseTensorEncoding(t->get().getType()); // Retrieves the affine expression for the filter loop. // FIXME: `toOrigDim` is deprecated. @@ -1409,7 +1409,7 @@ Value clause; if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) { assert(lvl.has_value()); - const Value crd = env.emitter().getCoords()[tid][*lvl]; + const Value crd = env.emitter().getCoords()[*tid][*lvl]; const Value lvar = env.getLoopVar(ldx); clause = builder.create(loc, arith::CmpIPredicate::eq, crd, lvar); @@ -1509,8 +1509,8 @@ Level startLvl) { // TODO: Handle affine expression on output tensor. linalg::GenericOp op = env.op(); - assert(tid < op.getNumDpsInputs()); - OpOperand *input = op.getDpsInputOperands()[tid]; + assert(*tid < op.getNumDpsInputs()); + OpOperand *input = op.getDpsInputOperands()[*tid]; const auto lvlExprs = op.getMatchingIndexingMap(input).getResults(); const auto enc = getSparseTensorEncoding(input->get().getType()); if (enc) { @@ -1537,7 +1537,7 @@ // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two // levels can be determined before loops. const TensorId hi = env.makeTensorId(env.op().getNumDpsInputs()); - for (const TensorId tid : tensor_id::Range(0, hi)) + for (const TensorId tid : TensorId::fromZeroUpTo(hi)) genConstantDenseAddressFromLevel(env, rewriter, tid, 0); } @@ -1556,7 +1556,7 @@ env.merger().foreachTensorLoopId( li, [&, ldx](TensorLoopId b, TensorId tid, std::optional lvl, DimLevelType dlt, bool isIdxReduc) { - if (simple[b]) { + if (simple[*b]) { if (isIdxReduc) { tids.push_back(tid); lvls.push_back(*lvl); @@ -1586,10 +1586,10 @@ } else { assert(isUndefDLT(dlt)); linalg::GenericOp op = env.op(); - if (tid >= op.getNumDpsInputs()) + if (*tid >= op.getNumDpsInputs()) // We only handle affine expression on input tensors (for now). 
return; - OpOperand *operand = &op->getOpOperand(tid); + OpOperand *operand = &op->getOpOperand(*tid); const auto stt = getSparseTensorType(operand->get()); // Non-annotated dense tensors requires no special handling. if (!stt.hasEncoding()) @@ -1726,8 +1726,7 @@ // Construct iteration lattices for current loop index, with L0 at top. const LoopId idx = env.topSortAt(at); - const LoopId ldx = at == 0 ? ::mlir::sparse_tensor::detail::kInvalidId - : env.topSortAt(at - 1); + const LoopId ldx = at == 0 ? LoopId::Invalid() : env.topSortAt(at - 1); const LatSetId lts = env.merger().optimizeSet(env.merger().buildLattices(exp, idx)); @@ -1917,8 +1916,8 @@ auto dstTp = RankedTensorType::get(srcTp.getShape(), srcTp.getElementType(), dstEnc); auto convert = rewriter.create(tval.getLoc(), dstTp, tval); - rewriter.updateRootInPlace(env.op(), - [&]() { env.op()->setOperand(tid, convert); }); + rewriter.updateRootInPlace( + env.op(), [&]() { env.op()->setOperand(*tid, convert); }); rewriter.setInsertionPointAfter(env.op()); rewriter.create(tval.getLoc(), convert); return success(); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -98,22 +98,25 @@ // Constructors. 
//===----------------------------------------------------------------------===// -TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v, - Operation *o) - : kind(k), val(v), op(o) { +constexpr TensorExp::TensorExp(TensorId t) + : kind(TensorExp::Kind::kTensor), tensor(t), val(), op(nullptr) {} + +constexpr TensorExp::TensorExp(LoopId i) + : kind(TensorExp::Kind::kLoopVar), loop(i), val(), op(nullptr) {} + +constexpr TensorExp::TensorExp(Value v) + : kind(TensorExp::Kind::kInvariant), children(), val(v), op(nullptr) {} + +constexpr TensorExp::TensorExp(Kind k, ExprId e0, ExprId e1, Value v, + Operation *o) + : kind(k), children(e0, e1), val(v), op(o) { switch (kind) { // Leaf. case TensorExp::Kind::kTensor: - assert(x != detail::kInvalidId && !y.isValid() && !v && !o); - tensor = x; - return; case TensorExp::Kind::kInvariant: - assert(x == detail::kInvalidId && !y.isValid() && v && !o); - return; case TensorExp::Kind::kLoopVar: - assert(x != detail::kInvalidId && !y.isValid() && !v && !o); - loop = x; - return; + assert(false && "expected non-leaf expression kind"); + break; // Unary operations. 
case TensorExp::Kind::kAbsF: case TensorExp::Kind::kAbsC: @@ -135,9 +138,7 @@ case TensorExp::Kind::kNegI: case TensorExp::Kind::kCIm: case TensorExp::Kind::kCRe: - assert(x != detail::kInvalidId && !y.isValid() && !v && !o); - children.e0 = ExprId{x}; - children.e1 = y; + assert(e0.isValid() && !e1.isValid() && !v && !o); return; case TensorExp::Kind::kTruncF: case TensorExp::Kind::kExtF: @@ -150,22 +151,16 @@ case TensorExp::Kind::kCastIdx: case TensorExp::Kind::kTruncI: case TensorExp::Kind::kBitCast: - assert(x != detail::kInvalidId && !y.isValid() && v && !o); - children.e0 = ExprId{x}; - children.e1 = y; + assert(e0.isValid() && !e1.isValid() && v && !o); return; case TensorExp::Kind::kBinaryBranch: case TensorExp::Kind::kSelect: - assert(x != detail::kInvalidId && !y.isValid() && !v && o); - children.e0 = ExprId{x}; - children.e1 = y; + assert(e0.isValid() && !e1.isValid() && !v && o); return; case TensorExp::Kind::kUnary: - // No assertion on y can be made, as the branching paths involve both + // No assertion on `e1` can be made, as the branching paths involve both // a unary (`mapSet`) and binary (`disjSet`) pathway. - assert(x != detail::kInvalidId && !v && o); - children.e0 = ExprId{x}; - children.e1 = y; + assert(e0.isValid() && !v && o); return; // Binary operations. 
case TensorExp::Kind::kMulF: @@ -187,15 +182,11 @@ case TensorExp::Kind::kShrS: case TensorExp::Kind::kShrU: case TensorExp::Kind::kShlI: - assert(x != detail::kInvalidId && y.isValid() && !v && !o); - children.e0 = ExprId{x}; - children.e1 = y; + assert(e0.isValid() && e1.isValid() && !v && !o); return; case TensorExp::Kind::kBinary: case TensorExp::Kind::kReduce: - assert(x != detail::kInvalidId && y.isValid() && !v && o); - children.e0 = ExprId{x}; - children.e1 = y; + assert(e0.isValid() && e1.isValid() && !v && o); return; } llvm_unreachable("unexpected kind"); @@ -205,19 +196,21 @@ unsigned numFilterLoops) : outTensor(numInputOutputTensors - 1), syntheticTensor(numInputOutputTensors), - numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops), - numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false), - lvlTypes(numTensors, - std::vector(numLoops, DimLevelType::Undef)), - loopToLvl(numTensors, - std::vector>(numLoops, std::nullopt)), - lvlToLoop(numTensors, - std::vector>(numLoops, std::nullopt)), - loopToDependencies(numLoops, std::vector>( - numTensors, std::nullopt)), - levelToDependentIdx(numTensors, std::vector>( - numLoops, std::vector())), - loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {} + factory(numInputOutputTensors + 1, numNativeLoops + numFilterLoops), + numNativeLoops(numNativeLoops), hasSparseOut(false), + lvlTypes(getNumTensors(), + std::vector(getNumLoops(), DimLevelType::Undef)), + loopToLvl(getNumTensors(), + std::vector>(getNumLoops(), std::nullopt)), + lvlToLoop(getNumTensors(), std::vector>( + getNumLoops(), std::nullopt)), + loopToDependencies(getNumLoops(), std::vector>( + getNumTensors(), std::nullopt)), + levelToDependentIdx(getNumTensors(), + std::vector>( + getNumLoops(), std::vector())), + loopBounds(getNumLoops(), + std::make_pair(TensorId{getNumTensors()}, getNumLoops())) {} //===----------------------------------------------------------------------===// // Lattice methods. 
@@ -226,51 +219,48 @@ ExprId Merger::addTensorExp(TensorId t) { assert(isValidTensorId(t)); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(TensorExp::Kind::kTensor, t, ExprId(), Value(), - nullptr); + tensorExps.emplace_back(t); return eNew; } ExprId Merger::addLoopVarExp(LoopId i) { assert(isValidLoopId(i)); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(TensorExp::Kind::kLoopVar, i, ExprId(), Value(), - nullptr); + tensorExps.emplace_back(i); return eNew; } ExprId Merger::addInvariantExp(Value v) { const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(TensorExp::Kind::kInvariant, detail::kInvalidId, - ExprId(), v, nullptr); + tensorExps.emplace_back(v); return eNew; } ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op) { assert(k > TensorExp::Kind::kLoopVar); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(k, e0.value, e1, Value(), op); + tensorExps.emplace_back(k, e0, e1, Value(), op); return eNew; } ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op) { assert(k > TensorExp::Kind::kLoopVar); const ExprId eNew(tensorExps.size()); - tensorExps.emplace_back(k, e.value, ExprId(), v, op); + tensorExps.emplace_back(k, e, ExprId(), v, op); return eNew; } LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) { const LatPointId pNew(latPoints.size()); - const unsigned size = numLoops * numTensors; + const unsigned size = getNumLoops() * getNumTensors(); const TensorLoopId b = makeTensorLoopId(t, i); latPoints.emplace_back(size, e); - latPoints[pNew.value].bits.set(b); + latPoints[pNew.value].bits.set(b.value); return pNew; } LatPointId Merger::addLat(const BitVector &bits, ExprId e) { - assert(bits.size() == numLoops * numTensors); + assert(bits.size() == getNumLoops() * getNumTensors()); const LatPointId pNew(latPoints.size()); latPoints.emplace_back(bits, e); return pNew; @@ -438,8 +428,8 @@ const BitVector &bitsj = lat(j).bits; assert(bitsi.size() == 
bitsj.size()); if (bitsi.count() > bitsj.count()) { - for (const TensorLoopId b : tensor_loop_id::Range(bitsj)) - if (!bitsi[b]) + for (const TensorLoopId b : TensorLoopId::Range(bitsj)) + if (!bitsi[b.value]) return false; return true; } @@ -591,7 +581,7 @@ } bool Merger::hasAnySparse(const BitVector &bits) const { - for (const TensorLoopId b : tensor_loop_id::Range(bits)) { + for (const TensorLoopId b : TensorLoopId::Range(bits)) { const auto dlt = getDimLevelType(b); if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) return true; @@ -601,8 +591,8 @@ bool Merger::hasSparseIdxReduction(const BitVector &bits) const { // TODO: return false on dense levels. - for (unsigned b = 0, be = bits.size(); b < be; b++) - if (bits[b] && isLvlWithNonTrivialIdxExp(b)) + for (const TensorLoopId b : TensorLoopId::Range(bits)) + if (isLvlWithNonTrivialIdxExp(b)) return true; return false; } @@ -719,13 +709,13 @@ llvm::dbgs() << "synthetic_"; else if (expr.tensor == outTensor) llvm::dbgs() << "output_"; - llvm::dbgs() << "tensor_" << expr.tensor; + llvm::dbgs() << "tensor_" << expr.tensor.value; break; case TensorExp::Kind::kInvariant: llvm::dbgs() << "invariant"; break; case TensorExp::Kind::kLoopVar: - llvm::dbgs() << "loopvar_" << expr.loop; + llvm::dbgs() << "loopvar_" << expr.loop.value; break; // Unary operations. case TensorExp::Kind::kAbsF: @@ -817,9 +807,9 @@ } void Merger::dumpBits(const BitVector &bits) const { - for (const TensorLoopId b : tensor_loop_id::Range(bits)) { - const TensorId t = tensor(b); - const LoopId i = loop(b); + for (const TensorLoopId b : TensorLoopId::Range(bits)) { + const auto t = tensor(b).value; + const auto i = loop(b).value; const auto dlt = lvlTypes[t][i]; if (isLvlWithNonTrivialIdxExp(b)) llvm::dbgs() << " DEP_" << t << "_" << i; @@ -1126,7 +1116,7 @@ // argument is considered a tensor, indexed by the implicit loop // bounds. This includes rank-0 tensor arguments. 
if (arg.getOwner()->getParentOp() == op) { - OpOperand &t = op->getOpOperand(tid); + OpOperand &t = op->getOpOperand(tid.value); if (!op.isScalar(&t)) return addTensorExp(tid); v = t.get(); // get scalar value diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -203,7 +203,7 @@ BitVector loopsToBits(const std::vector> &loops) { BitVector testBits = BitVector(merger.getNumTensors(), false); for (auto [loop, tensor] : loops) - testBits.set(merger.makeTensorLoopId(tensor, loop)); + testBits.set(*(merger.makeTensorLoopId(tensor, loop))); return testBits; }