diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -314,6 +314,10 @@ /// Get the total number of tensors (including the output-tensor and /// synthetic-tensor). constexpr unsigned getNumTensors() const { return numTensors; } + /// Get the range of all tensor identifiers. + constexpr tensor_id::Range getTensorIds() const { + return tensor_id::Range(0, numTensors); + } /// Get the total number of loops (native loops + filter loops). constexpr unsigned getNumLoops() const { return numLoops; } @@ -323,9 +327,17 @@ constexpr unsigned getNumFilterLoops() const { return numLoops - numNativeLoops; } - /// Get the identifier of the first filter-loop. - constexpr LoopId getStartingFilterLoopId() const { - return getNumNativeLoops(); + /// Get the range of all loop identifiers. + constexpr loop_id::Range getLoopIds() const { + return loop_id::Range(0, numLoops); + } + /// Get the range of native-loop identifiers. + constexpr loop_id::Range getNativeLoopIds() const { + return loop_id::Range(0, numNativeLoops); + } + /// Get the range of filter-loop identifiers. + constexpr loop_id::Range getFilterLoopIds() const { + return loop_id::Range(numNativeLoops, numLoops); } /// Returns true if `b` is the `i`th loop of the output tensor. diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/MergerNewtypes.h @@ -29,6 +29,9 @@ #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGERNEWTYPES_H_ #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGERNEWTYPES_H_ +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/iterator.h" + #include #include @@ -53,6 +56,58 @@ /// `GenericOpSparsifier::matchAndRewrite`. using TensorId = unsigned; +// NOTE: We use this namespace to simulate having turned `TensorId` +// into a newtype, so that we can split the patch for adding the iterators +// from the patch for actually making it a newtype. +namespace tensor_id { +class Iterator; +class Range; +} // namespace tensor_id + +/// An iterator for `TensorId`. We define this as a separate class because +/// it wouldn't be generally safe/meaningful to define `TensorId::operator++`. +/// The ctor is private for similar reasons, so client code should create +/// iterators via `tensor_id::Range` instead. +class tensor_id::Iterator final + : public llvm::iterator_facade_base { + friend class tensor_id::Range; + explicit constexpr Iterator(TensorId tid) : tid(tid) {} + +public: + using llvm::iterator_facade_base::operator++; + Iterator &operator++() { + ++tid; + return *this; + } + const TensorId *operator->() const { return &tid; } + constexpr TensorId operator*() const { return tid; } + constexpr bool operator==(Iterator rhs) const { return tid == rhs.tid; } + constexpr bool operator!=(Iterator rhs) const { return !(*this == rhs); } + +private: + TensorId tid; +}; +static_assert(std::is_trivially_copyable_v && + std::is_trivially_destructible_v); + +/// An iterator range for `TensorId`. +class tensor_id::Range final { +public: + explicit constexpr Range(TensorId lo, TensorId hi) + : begin_(lo <= hi ? lo : hi), end_(hi) {} + constexpr Iterator begin() const { return begin_; } + constexpr Iterator end() const { return end_; } + constexpr bool empty() const { return begin_ == end_; } + +private: + Iterator begin_; + Iterator end_; +}; +static_assert(std::is_trivially_copyable_v && + std::is_trivially_destructible_v); + //===----------------------------------------------------------------------===// /// Loop identifiers. /// @@ -78,6 +133,61 @@ /// are invariant identifiers). using LoopId = unsigned; +// NOTE: We use this namespace to simulate having turned `LoopId` into +// a newtype, so that we can split the patch for adding the iterators from +// the patch for actually making it a newtype. +namespace loop_id { +class Iterator; +class Range; +} // namespace loop_id + +/// An iterator for `LoopId`. We define this as a separate class because +/// it wouldn't be generally safe/meaningful to define `LoopId::operator++`. +/// The ctor is private for similar reasons, so client code should create +/// iterators via `loop_id::Range` instead. +class loop_id::Iterator final + : public llvm::iterator_facade_base { + friend class loop_id::Range; + explicit constexpr Iterator(LoopId i) : loop(i) {} + +public: + using llvm::iterator_facade_base::operator++; + Iterator &operator++() { + ++loop; + return *this; + } + const LoopId *operator->() const { return &loop; } + constexpr LoopId operator*() const { return loop; } + constexpr bool operator==(Iterator rhs) const { return loop == rhs.loop; } + constexpr bool operator!=(Iterator rhs) const { return !(*this == rhs); } + +private: + LoopId loop; +}; +static_assert(std::is_trivially_copyable_v && + std::is_trivially_destructible_v); + +/// An iterator range for `LoopId`. +class loop_id::Range final { +public: + explicit constexpr Range(LoopId lo, LoopId hi) + : begin_((lo != detail::kInvalidId && hi != detail::kInvalidId) + ? (lo <= hi ? lo : hi) + : detail::kInvalidId), + end_(lo != detail::kInvalidId ? hi : detail::kInvalidId) {} + constexpr Iterator begin() const { return begin_; } + constexpr Iterator end() const { return end_; } + constexpr bool empty() const { return begin_ == end_; } + +private: + Iterator begin_; + Iterator end_; +}; +static_assert(std::is_trivially_copyable_v && + std::is_trivially_destructible_v); + //===----------------------------------------------------------------------===// /// A compressed representation of `std::pair`. /// The compression scheme is such that this also serves as an index @@ -85,6 +195,46 @@ /// just the implementation for a set of `TensorLoopId` values). using TensorLoopId = unsigned; +// NOTE: We use this namespace to simulate having turned `TensorLoopId` into +// a newtype, so that we can split the patch for adding the iterators from +// the patch for actually making it a newtype. +namespace tensor_loop_id { +class Iterator; +class Range; +} // namespace tensor_loop_id + +/// An iterator of the `TensorLoopId`s which are included/set in a `BitVector`. +class tensor_loop_id::Iterator final + : public llvm::iterator_adaptor_base< + Iterator, llvm::BitVector::const_set_bits_iterator, + // Since `const_set_bits_iterator` doesn't define its own + // `std::iterator_traits`, we must manually do so here. + /*iterator_category=*/std::forward_iterator_tag, + /*value_type=*/unsigned, + /*difference_type=*/std::ptrdiff_t, // irrelevant, since not + // random-access + /*pointer=*/const unsigned *, + /*reference=*/const unsigned &> { + friend class tensor_loop_id::Range; + explicit Iterator(llvm::BitVector::const_set_bits_iterator I) + : Iterator::iterator_adaptor_base(I) {} + +public: + TensorLoopId operator*() const { return TensorLoopId{*this->I}; } +}; + +/// An iterator range for the `TensorLoopId`s which are included/set +/// in a `BitVector`. +class tensor_loop_id::Range final { +public: + explicit Range(const llvm::BitVector &bits) : set_bits(bits.set_bits()) {} + Iterator begin() const { return Iterator{set_bits.begin()}; } + Iterator end() const { return Iterator{set_bits.end()}; } + +private: + llvm::iterator_range set_bits; +}; + //===----------------------------------------------------------------------===// /// `TensorExp` identifiers. These are allocated by `Merger::addExp`, /// and serve as unique identifiers for the corresponding `TensorExp` object. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h @@ -257,6 +257,10 @@ unsigned getNumTensors() const { return tensors.size(); } + tensor_id::Range getTensorIds() const { + return tensor_id::Range(0, getNumTensors()); + } + bool isOutputTensor(TensorId tid) const { return hasOutput && tid == getNumTensors() - 1; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -248,7 +248,7 @@ numTensors, std::vector>>()); // Initialize nested types of `TensorId`-indexed fields. - for (TensorId tid = 0; tid < numTensors; tid++) { + for (const TensorId tid : getTensorIds()) { const Value t = tensors[tid]; // a scalar or 0-dimension tensors if (isZeroRankedTensorOrScalar(t.getType())) @@ -312,7 +312,7 @@ // * get the positions and coordinates buffers // * get/compute the level-size, which is also used as the upper-bound // on positions. - for (TensorId t = 0, numTensors = getNumTensors(); t < numTensors; t++) { + for (const TensorId t : getTensorIds()) { const Value tensor = tensors[t]; const auto rtp = tensor.getType().dyn_cast(); if (!rtp) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -221,7 +221,7 @@ /// compound affine sparse level, and it will be incremented by one when /// used. static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a, - DimLevelType dlt, LoopId &filterLdx, + DimLevelType dlt, loop_id::Iterator &filterLdx, bool setLvlFormat = true) { switch (a.getKind()) { case AffineExprKind::DimId: { @@ -237,9 +237,9 @@ case AffineExprKind::Mul: case AffineExprKind::Constant: { if (!isDenseDLT(dlt) && setLvlFormat) { - assert(isUndefDLT(merger.getDimLevelType(tid, filterLdx))); + assert(isUndefDLT(merger.getDimLevelType(tid, *filterLdx))); // Use a filter loop for sparse affine expression. - merger.setLevelAndType(tid, filterLdx, lvl, dlt); + merger.setLevelAndType(tid, *filterLdx, lvl, dlt); ++filterLdx; } @@ -406,8 +406,8 @@ /// supports affine addition index expression. static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) { bool annotated = false; - // `filterLdx` may be mutated by `findAffine`. - LoopId filterLdx = env.merger().getStartingFilterLoopId(); + const loop_id::Range filterLdxRange = env.merger().getFilterLoopIds(); + loop_id::Iterator filterLdx = filterLdxRange.begin(); for (OpOperand &t : env.op()->getOpOperands()) { const TensorId tid = env.makeTensorId(t.getOperandNumber()); const auto map = env.op().getMatchingIndexingMap(&t); @@ -440,7 +440,7 @@ } } } - assert(filterLdx == env.merger().getNumLoops()); + assert(filterLdx == filterLdxRange.end()); return annotated; } @@ -458,8 +458,7 @@ std::vector redIt; // reduce iterator with 0 degree std::vector parIt; // parallel iterator with 0 degree std::vector filterIt; // filter loop with 0 degree - const LoopId numLoops = env.merger().getNumLoops(); - for (LoopId i = 0; i < numLoops; i++) { + for (const LoopId i : env.merger().getLoopIds()) { if (inDegree[i] == 0) { if (env.merger().isFilterLoop(i)) filterIt.push_back(i); @@ -495,7 +494,7 @@ env.topSortPushBack(src); it.pop_back(); // Update in-degree, and push 0-degree node into worklist. - for (LoopId dst = 0; dst < numLoops; dst++) { + for (const LoopId dst : env.merger().getLoopIds()) { if (adjM[src][dst] && --inDegree[dst] == 0) { if (env.merger().isFilterLoop(dst)) filterIt.push_back(dst); @@ -506,7 +505,7 @@ } } } - return env.topSortSize() == numLoops; + return env.topSortSize() == env.merger().getNumLoops(); } /// Helper method to add all constraints from the indices in one affine @@ -790,10 +789,10 @@ // TODO: Do we really need this? if (includesUndef(mask)) { const TensorId tid = env.makeTensorId(t.getOperandNumber()); - for (LoopId i = 0; i < numLoops; i++) { + for (const LoopId i : env.merger().getLoopIds()) { const auto dltI = env.dlt(tid, i); if (isCompressedDLT(dltI) || isSingletonDLT(dltI)) { - for (LoopId j = 0; j < numLoops; j++) + for (const LoopId j : env.merger().getLoopIds()) if (isUndefDLT(env.dlt(tid, j))) { adjM[i][j] = true; inDegree[j]++; @@ -1537,7 +1536,8 @@ // starting from the first level as they do not depend on any thing. // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two // levels can be determined before loops. - for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++) + const TensorId hi = env.makeTensorId(env.op().getNumDpsInputs()); + for (const TensorId tid : tensor_id::Range(0, hi)) genConstantDenseAddressFromLevel(env, rewriter, tid, 0); } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -438,8 +438,8 @@ const BitVector &bitsj = lat(j).bits; assert(bitsi.size() == bitsj.size()); if (bitsi.count() > bitsj.count()) { - for (TensorLoopId b = 0, be = bitsj.size(); b < be; b++) - if (bitsj[b] && !bitsi[b]) + for (const TensorLoopId b : tensor_loop_id::Range(bitsj)) + if (!bitsi[b]) return false; return true; } @@ -591,12 +591,11 @@ } bool Merger::hasAnySparse(const BitVector &bits) const { - for (TensorLoopId b = 0, be = bits.size(); b < be; b++) - if (bits[b]) { - const auto dlt = getDimLevelType(b); - if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) - return true; - } + for (const TensorLoopId b : tensor_loop_id::Range(bits)) { + const auto dlt = getDimLevelType(b); + if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) + return true; + } return false; } @@ -818,16 +817,14 @@ } void Merger::dumpBits(const BitVector &bits) const { - for (TensorLoopId b = 0, be = bits.size(); b < be; b++) { - if (bits[b]) { - const TensorId t = tensor(b); - const LoopId i = loop(b); - const auto dlt = lvlTypes[t][i]; - if (isLvlWithNonTrivialIdxExp(b)) - llvm::dbgs() << " DEP_" << t << "_" << i; - else - llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt); - } + for (const TensorLoopId b : tensor_loop_id::Range(bits)) { + const TensorId t = tensor(b); + const LoopId i = loop(b); + const auto dlt = lvlTypes[t][i]; + if (isLvlWithNonTrivialIdxExp(b)) + llvm::dbgs() << " DEP_" << t << "_" << i; + else + llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt); } }