diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -18,7 +18,9 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator.h" #include +#include namespace mlir { class ArrayAttr; @@ -195,6 +197,23 @@ // Permutation utils. //===----------------------------------------------------------------------===// +template +SmallVector applyPermutation(ArrayRef input, + ArrayRef permutation) { + assert(input.size() == permutation.size() && + "expected input rank to equal permutation rank"); + auto permutationRange = llvm::map_range( + llvm::seq(0, input.size()), + [&](int64_t idx) -> T { return input[permutation[idx]]; }); + return llvm::to_vector(permutationRange); +} +/// Same as above (`result[i] = input[permutation[i]]`), for SmallVectorImpl. +template +SmallVector applyPermutation(const SmallVectorImpl &input, + ArrayRef permutation) { + return applyPermutation(ArrayRef(input), permutation); +} + /// Apply the permutation defined by `permutation` to `inVec`. /// Element `i` in `inVec` is mapped to location `j = permutation[i]`. /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation @@ -203,10 +222,7 @@ template void applyPermutationToVector(SmallVector &inVec, ArrayRef permutation) { - SmallVector auxVec(inVec.size()); - for (const auto &en : enumerate(permutation)) - auxVec[en.index()] = inVec[en.value()]; - inVec = auxVec; + inVec = applyPermutation(inVec, permutation); } /// Helper method to apply to inverse a permutation.
@@ -239,6 +255,138 @@ computeLinearIndex(OpFoldResult sourceOffset, ArrayRef strides, ArrayRef indices); +//===----------------------------------------------------------------------===// +// Utilities for decomposing larger shapes +//===----------------------------------------------------------------------===// + +namespace detail { +/// Encapsulates the set of parameters that are used to make tile offset +/// calculations in the TileOffsetRangeIterator. +class TileOffsetRangeImpl { +public: + TileOffsetRangeImpl(ArrayRef shape, ArrayRef tileShape, + ArrayRef loopOrder); + /// Returns the number of tiles, i.e. one past the largest linear index. + int64_t getMaxLinearIndex() const { return maxLinearIndex; } + /// Returns the element offsets of the `linearIndex`-th tile in loop order. + SmallVector getStaticTileOffsets(int64_t linearIndex) const; + /// Same as `getStaticTileOffsets`, but builds the offsets as AffineExprs. + SmallVector getDynamicTileOffsets(AffineExpr linearIndex) const; + /// Dispatches to the int64_t or the AffineExpr variant based on `T`. + template + SmallVector getTileOffsets(T linearIndex) const { + if constexpr (std::is_same_v) + return getStaticTileOffsets(linearIndex); + else + return getDynamicTileOffsets(linearIndex); + } + +private: + /// The tile shape, left-padded with 1s to the rank of the outer shape that + /// is provided to the constructor. + SmallVector tileShape; + /// The inverse permutation to the `loopOrder` permutation provided in the + /// constructor. + SmallVector inverseLoopOrder; + /// The strides for the basis 'div(shape, tileShape)' permuted by `loopOrder`. + SmallVector sliceStrides; + /// One past the largest linear index (i.e. the number of tiles) in the + /// iteration space given by basis 'div(shape, tileShape)'. + int64_t maxLinearIndex; +}; + +/// The STL-style iterator implementation for StaticTileOffsetRange.
+template <typename ElementType> +class TileOffsetRangeIterator + : public llvm::iterator_facade_base<TileOffsetRangeIterator<ElementType>, + std::forward_iterator_tag, + SmallVector<ElementType>> { +public: + TileOffsetRangeIterator(const TileOffsetRangeImpl &params, ElementType index) + : params(params), index(index) {} + /// Prefix increment; returns *this as the iterator requirements demand. + TileOffsetRangeIterator &operator++() { incrementIndex(1); return *this; } + TileOffsetRangeIterator operator++(int) { + auto copy = *this; + ++*this; + return copy; + } + + bool operator==(const TileOffsetRangeIterator &other) const { + return index == other.index; + } + bool operator!=(const TileOffsetRangeIterator &other) const { + return index != other.index; + } + + SmallVector<ElementType> operator*() const { + return params.getTileOffsets(index); + } + void operator+=(int64_t offset) { incrementIndex(offset); } + +private: + void incrementIndex(int64_t offset) { index = index + offset; } + TileOffsetRangeImpl params; // Non-const: keeps the iterator assignable. + int64_t index; +}; +} // namespace detail + +/// A range-style iterator that allows for iterating over the offsets of all +/// potential tiles of size `tileShape` within the larger shape `shape`, using +/// an ordering specified by `loopOrder`. The `loopOrder` specifies the order of +/// unrolling by numbering the dimensions in order from "outer most for loop" +/// (slowest changing) to "inner most for loop" (fastest changing). +/// +/// For example, for `shape = {10, 20, 30}`, `tileShape = {5, 10, 15}`, and +/// `loopOrder={2, 0, 1}`, iterating over this range will yield offsets: +/// +/// ``` +/// {0, 0, 0}, {0, 10, 0}, {5, 0, 0}, {5, 10, 0}, {0, 0, 15}, +/// {0, 10, 15}, {5, 0, 15}, {5, 10, 15} +/// ``` +/// +/// This is useful in contexts where a vector computation over a larger shape +/// needs to be unrolled to a set of operations on subsets of the original +/// operands, such as during the "vector unrolling" transformations.
+/// +/// The rank of `tileShape` must be less-than-or-equal-to the rank of `shape`. +/// If the rank of `tileShape` is smaller than `shape`, then `tileShape` +/// elements correspond to the trailing dimensions of `shape`, and `tileShape` +/// is effectively left-padded with 1s, so each index along the leading +/// dimensions of `shape` is visited as a separate tile. +class StaticTileOffsetRange { +public: + using IteratorTy = detail::TileOffsetRangeIterator; + using ParamsTy = detail::TileOffsetRangeImpl; + /// Create the range, visiting the tiles in the given `loopOrder`. + StaticTileOffsetRange(ArrayRef shape, ArrayRef tileShape, + ArrayRef loopOrder) + : params(shape, tileShape, loopOrder), beginValue(params, 0), + pastEndValue(params, params.getMaxLinearIndex()) { + assert(shape.size() >= tileShape.size()); + assert(loopOrder.size() == shape.size()); + } + + /// Create the range with identity loop order. + StaticTileOffsetRange(ArrayRef shape, ArrayRef tileShape) + : params(shape, tileShape, + llvm::to_vector(llvm::seq(0, shape.size()))), + beginValue(params, 0), + pastEndValue(params, params.getMaxLinearIndex()) { + assert(shape.size() >= tileShape.size()); + } + + IteratorTy begin() const { return beginValue; } + IteratorTy end() const { return pastEndValue; } + + /// Returns the total number of tiles that fit in the larger shape.
+ size_t size() const { return params.getMaxLinearIndex(); } + +private: + const ParamsTy params; + IteratorTy beginValue; + IteratorTy pastEndValue; +}; } // namespace mlir #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -17,6 +17,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include #include @@ -250,6 +251,8 @@ AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context); AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context); AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context); +SmallVector getAffineConstantExprs(ArrayRef constants, + MLIRContext *context); AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs); diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -181,9 +181,8 @@ AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef offsets, ArrayRef basis) { - SmallVector basisExprs = llvm::to_vector(llvm::map_range( - basis, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); })); - return linearize(ctx, offsets, basisExprs); + // Lift the integer basis to constant affine expressions. + return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx)); } SmallVector mlir::delinearize(AffineExpr linearIndex, @@ -196,9 +195,7 @@ SmallVector mlir::delinearize(AffineExpr linearIndex, ArrayRef strides) { MLIRContext *ctx = linearIndex.getContext(); - SmallVector basisExprs = llvm::to_vector(llvm::map_range( - strides, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); })); - return delinearize(linearIndex, ArrayRef{basisExprs}); + return delinearize(linearIndex, getAffineConstantExprs(strides, ctx)); }
//===----------------------------------------------------------------------===// @@ -302,3 +299,56 @@ return {expr, values}; } + +//===----------------------------------------------------------------------===// +// TileOffsetRange +//===----------------------------------------------------------------------===// + +/// Left-pad `tileShape` with 1s so that it has `paddedSize` entries. +static SmallVector padTileShapeToSize(ArrayRef tileShape, + unsigned paddedSize) { + assert(tileShape.size() <= paddedSize && + "expected tileShape to <= paddedSize"); + if (tileShape.size() == paddedSize) + return to_vector(tileShape); + SmallVector result(paddedSize - tileShape.size(), 1); + llvm::append_range(result, tileShape); + return result; +} +/// Pads `tileShape` to the rank of `shape`; precomputes strides in loop order. +mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl( + ArrayRef shape, ArrayRef tileShape, + ArrayRef loopOrder) + : tileShape(padTileShapeToSize(tileShape, shape.size())), + inverseLoopOrder(invertPermutationVector(loopOrder)), + sliceStrides(shape.size()) { + // Divide the shape by the tile shape.
+ std::optional> shapeRatio = + mlir::computeShapeRatio(shape, tileShape); + assert(shapeRatio && shapeRatio->size() == shape.size() && + "target shape does not evenly divide the original shape"); + assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() && + "expected loop order to be a permutation of rank equal to outer " + "shape"); + // The tile count is order-independent; the strides follow `loopOrder`. + maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio); + mlir::applyPermutationToVector(*shapeRatio, loopOrder); + sliceStrides = mlir::computeStrides(*shapeRatio); +} +/// Delinearize in loop order, undo the permutation, then scale by tile size. +SmallVector mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets( + int64_t linearIndex) const { + SmallVector tileCoords = applyPermutation( + delinearize(linearIndex, sliceStrides), inverseLoopOrder); + return computeElementwiseMul(tileCoords, tileShape); +} +/// Same as `getStaticTileOffsets`, but builds the offsets as AffineExprs. +SmallVector +mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets( + AffineExpr linearIndex) const { + MLIRContext *ctx = linearIndex.getContext(); + SmallVector tileCoords = applyPermutation( + delinearize(linearIndex, sliceStrides), inverseLoopOrder); + return mlir::computeElementwiseMul(tileCoords, + getAffineConstantExprs(tileShape, ctx)); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -29,77 +29,6 @@ using namespace mlir; using namespace mlir::vector; -/// During unrolling from `originalShape` to `targetShape` return the offset for -/// the slice `index`. -static SmallVector getVectorOffset(ArrayRef ratioStrides, - int64_t index, - ArrayRef targetShape) { - return computeElementwiseMul(delinearize(index, ratioStrides), targetShape); -} - -/// A functor that accomplishes the same thing as `getVectorOffset` but -/// allows for reordering the traversal of the dimensions. The order of -/// traversal is given in "for loop order" (outer to inner).
-namespace { -class DecomposeShapeIterator { -private: - SmallVector vectorShape; - SmallVector loopOrder; - SmallVector sliceStrides; - int64_t maxIndexVal{1}; - -public: - DecomposeShapeIterator(ArrayRef originalShape, - ArrayRef targetShape, - ArrayRef loopOrder) - : vectorShape(targetShape.begin(), targetShape.end()), - loopOrder(loopOrder.begin(), loopOrder.end()), - sliceStrides(originalShape.size()) { - assert(originalShape.size() >= targetShape.size()); - assert(loopOrder.size() == originalShape.size()); - - // Compute the count for each dimension. - auto maybeShapeRatio = computeShapeRatio(originalShape, targetShape); - assert(maybeShapeRatio && "Shape does not evenly divide"); - // Pad `sliceDimCounts` with leading 1s so that all sizes match. - SmallVector sliceDimCounts = *maybeShapeRatio; - maxIndexVal = computeMaxLinearIndex(sliceDimCounts); - - // Reversing "loop order" gives dimensions from fastest varying to slowest - // varying (smallest stride to largest stride). - int64_t accum = 1; - for (auto idx : llvm::reverse(loopOrder)) { - sliceStrides[idx] = accum; - accum *= sliceDimCounts[idx]; - } - } - - // Turn the linear index into a d-tuple based on units of vectors of size - // `vectorShape`. The linear index is assumed to represent traversal of the - // dimensions based on `order`. - SmallVector delinearize(int64_t index) const { - // Traverse in for loop order (largest stride to smallest stride). - SmallVector vectorOffsets(sliceStrides.size()); - for (auto idx : loopOrder) { - vectorOffsets[idx] = index / sliceStrides[idx]; - index %= sliceStrides[idx]; - } - return vectorOffsets; - } - - int64_t maxIndex() const { return maxIndexVal; } - - /// Return the offset within d-tuple based on the ordering given by - /// `loopOrder`.
- SmallVector getVectorOffset(int64_t index) const { - SmallVector vectorOffsets = delinearize(index); - SmallVector elementOffsets = - computeElementwiseMul(vectorShape, vectorOffsets); - return elementOffsets; - } -}; -} // namespace - /// Compute the indices of the slice `index` for a tranfer op. static SmallVector sliceTransferIndices(ArrayRef elementOffsets, ArrayRef indices, @@ -232,13 +161,10 @@ VectorType::get(*targetShape, sourceVectorType.getElementType()); SmallVector originalIndices(readOp.getIndices().begin(), readOp.getIndices().end()); - SmallVector loopOrder = getUnrollOrder(originalSize.size(), readOp, options); - DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, - loopOrder); - for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) { - SmallVector elementOffsets = indexToOffsets.getVectorOffset(i); + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { SmallVector indices = sliceTransferIndices(elementOffsets, originalIndices, readOp.getPermutationMap(), loc, rewriter); @@ -283,14 +209,11 @@ ArrayRef originalSize = sourceVectorType.getShape(); SmallVector originalIndices(writeOp.getIndices().begin(), writeOp.getIndices().end()); - SmallVector loopOrder = getUnrollOrder(originalSize.size(), writeOp, options); - DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, - loopOrder); Value resultTensor; - for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) { - SmallVector elementOffsets = indexToOffsets.getVectorOffset(i); + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { Value slicedVector = rewriter.create( loc, writeOp.getVector(), elementOffsets, *targetShape, strides); SmallVector indices = @@ -355,11 +278,9 @@ SmallVector loopOrder = getUnrollOrder( contractOp.getIteratorTypes().size(), contractOp, options); - DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, - loopOrder); - const int64_t sliceCount =
indexToOffsets.maxIndex(); - for (int64_t i = 0; i < sliceCount; i++) { - SmallVector offsets = indexToOffsets.getVectorOffset(i); + // Create the unrolled computation. + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { SmallVector slicesOperands(contractOp.getNumOperands()); // Helper to compute the new shape of each operand and extract the slice. @@ -439,22 +360,16 @@ if (!targetShape) return failure(); SmallVector originalSize = *reductionOp.getShapeForUnroll(); - SmallVector ratio = *computeShapeRatio(originalSize, *targetShape); llvm::MapVector< SmallVector, Value, llvm::DenseMap, unsigned, OffsetMapInfo>> accCache; - // Compute shape ratio of 'shape' and 'sizes'. - int64_t sliceCount = computeMaxLinearIndex(ratio); Location loc = reductionOp.getLoc(); - // Stride of the ratios, this gives us the offsets of sliceCount in a basis - // of multiples of the targetShape. - auto ratioStrides = computeStrides(ratio); - for (int64_t i = 0; i < sliceCount; i++) { - SmallVector offsets = - getVectorOffset(ratioStrides, i, *targetShape); - + // Create the unrolled computation: visit every `targetShape`-sized tile + // of `originalSize` in row-major order. + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector operands; SmallVector operandStrides(offsets.size(), 1); Value slicedOperand = rewriter.create( @@ -520,8 +435,6 @@ auto dstVecType = cast(op->getResult(0).getType()); SmallVector originalSize = *cast(op).getShapeForUnroll(); - SmallVector ratio = *computeShapeRatio(originalSize, *targetShape); - int64_t sliceCount = computeMaxLinearIndex(ratio); Location loc = op->getLoc(); // Prepare the result vector. Value result = rewriter.create( @@ -530,12 +443,9 @@ VectorType newVecType = VectorType::get(*targetShape, dstVecType.getElementType()); - // Stride of the ratios, this gives us the offsets of sliceCount in a basis - // of multiples of the targetShape.
- auto ratioStrides = computeStrides(ratio); - for (int64_t i = 0; i < sliceCount; i++) { - SmallVector offsets = - getVectorOffset(ratioStrides, i, *targetShape); + // Create the unrolled computation. + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector extractOperands; for (OpOperand &operand : op->getOpOperands()) { auto vecType = dyn_cast(operand.get().getType()); @@ -574,19 +484,12 @@ if (!targetShape) return failure(); SmallVector originalSize = *reductionOp.getShapeForUnroll(); - auto ratio = *computeShapeRatio(originalSize, *targetShape); - int64_t sliceCount = ratio[0]; // Create unrolled vector reduction. Location loc = reductionOp.getLoc(); Value accumulator = nullptr; - - // Stride of the ratios, this gives us the offsets of sliceCount in a basis - // of multiples of the targetShape. - auto ratioStrides = computeStrides(ratio); - for (int64_t i = 0; i < sliceCount; ++i) { - SmallVector offsets = - getVectorOffset(ratioStrides, i, *targetShape); + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector strides(offsets.size(), 1); Value slicedOperand = rewriter.create( loc, reductionOp.getVector(), offsets, *targetShape, strides); @@ -630,20 +533,16 @@ SmallVector strides(targetShape->size(), 1); Location loc = transposeOp.getLoc(); ArrayRef originalSize = originalVectorType.getShape(); - SmallVector ratio = *computeShapeRatio(originalSize, *targetShape); - int64_t sliceCount = computeMaxLinearIndex(ratio); + // Prepare the result vector; Value result = rewriter.create( loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); SmallVector permutation; transposeOp.getTransp(permutation); - // Stride of the ratios, this gives us the offsets of sliceCount in a basis - // of multiples of the targetShape.
- auto ratioStrides = computeStrides(ratio); - for (int64_t i = 0; i < sliceCount; i++) { - SmallVector elementOffsets = - getVectorOffset(ratioStrides, i, *targetShape); + // Unroll the computation. + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector permutedOffsets(elementOffsets.size()); SmallVector permutedShape(elementOffsets.size()); // Compute the source offsets and shape. @@ -694,13 +593,11 @@ SmallVector loopOrder = getUnrollOrder(originalSize.size(), gatherOp, options); - DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, - loopOrder); - for (int64_t i = 0, e = indexToOffsets.maxIndex(); i < e; ++i) { + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { // To get the unrolled gather, extract the same slice based on the // decomposed shape from each of the index, mask, and pass-through // vectors. - SmallVector elementOffsets = indexToOffsets.getVectorOffset(i); Value indexSubVec = rewriter.create( loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides); Value maskSubVec = rewriter.create( diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -533,6 +533,14 @@ return uniquer.get(assignCtx, constant); } +SmallVector +mlir::getAffineConstantExprs(ArrayRef constants, + MLIRContext *context) { + return llvm::to_vector(llvm::map_range(constants, [&](int64_t constant) { + return getAffineConstantExpr(constant, context); + })); +} + /// Simplify add expression. Return nullptr if it can't be simplified.
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { auto lhsConst = lhs.dyn_cast(); diff --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt --- a/mlir/unittests/Dialect/Utils/CMakeLists.txt +++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(MLIRDialectUtilsTests StructuredOpsUtilsTest.cpp + IndexingUtilsTest.cpp ) target_link_libraries(MLIRDialectUtilsTests PRIVATE diff --git a/mlir/unittests/Dialect/Utils/IndexingUtilsTest.cpp b/mlir/unittests/Dialect/Utils/IndexingUtilsTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Dialect/Utils/IndexingUtilsTest.cpp @@ -0,0 +1,71 @@ +//===- IndexingUtilsTest.cpp - IndexingUtils unit tests -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "gtest/gtest.h" + +using namespace mlir; + +TEST(StaticTileOffsetRange, checkIteratorCanonicalOrder) { + // Tile <4x8> by <2x4> with canonical row-major order. + std::vector> expected = {{0, 0}, {0, 4}, {2, 0}, {2, 4}}; + for (auto [idx, tileOffset] : + llvm::enumerate(StaticTileOffsetRange({4, 8}, {2, 4}, {0, 1}))) + EXPECT_EQ(tileOffset, expected[idx]); + + // Check the constructor for default order and test use with zip iterator. + for (auto [tileOffset, tileOffsetDefault] : + llvm::zip(StaticTileOffsetRange({4, 8}, {2, 4}, {0, 1}), + StaticTileOffsetRange({4, 8}, {2, 4}))) + EXPECT_EQ(tileOffset, tileOffsetDefault); +} + +TEST(StaticTileOffsetRange, checkIteratorRowMajorOrder) { + // Tile <4x8> by <2x4> with column-major order (loop order {1, 0}).
+ std::vector> expected = {{0, 0}, {2, 0}, {0, 4}, {2, 4}}; + for (auto [idx, tileOffset] : + llvm::enumerate(StaticTileOffsetRange({4, 8}, {2, 4}, {1, 0}))) + EXPECT_EQ(tileOffset, expected[idx]); +} + +TEST(StaticTileOffsetRange, checkLeadingOneFill) { + // Tile <4x8> by <4>. A smaller tile shape gets right-aligned to the shape. + for (auto [idx, tileOffset] : + llvm::enumerate(StaticTileOffsetRange({4, 8}, {4}))) { + SmallVector expected = {static_cast(idx) / 2, + static_cast(idx) % 2 * 4}; + EXPECT_EQ(tileOffset, expected); + } + for (auto [idx, tileOffset] : + llvm::enumerate(StaticTileOffsetRange({1, 4, 8}, {4}, {2, 1, 0}))) { + SmallVector expected = {0, static_cast(idx) % 4, + (static_cast(idx) / 4) * 4}; + EXPECT_EQ(tileOffset, expected); + } +} + +TEST(StaticTileOffsetRange, checkIterator3DPermutation) { + // Tile <8x4x2> by <4x2x1> with permutation [1, 0, 2] + for (auto [idx, tileOffset] : llvm::enumerate( + StaticTileOffsetRange({8, 4, 2}, {4, 2, 1}, {1, 0, 2}))) { + SmallVector expected = {((static_cast(idx) / 2) % 2) * 4, + ((static_cast(idx) / 4) % 2) * 2, + static_cast(idx) % 2}; + EXPECT_EQ(tileOffset, expected); + } + + // Tile <10x20x30> by <5x10x15> with permutation [2, 0, 1] + for (auto [idx, tileOffset] : llvm::enumerate( + StaticTileOffsetRange({10, 20, 30}, {5, 10, 15}, {2, 0, 1}))) { + SmallVector expected = {((static_cast(idx) / 2) % 2) * 5, + (static_cast(idx) % 2) * 10, + (static_cast(idx) / 4) % 2 * 15}; + EXPECT_EQ(tileOffset, expected); + } +}