diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -178,6 +178,23 @@ // Permutation utils. //===----------------------------------------------------------------------===// +template +SmallVector applyPermutation(ArrayRef input, + ArrayRef permutation) { + assert(input.size() == permutation.size() && + "expected input rank to equal permutation rank"); + auto permutationRange = llvm::map_range( + llvm::seq(0, input.size()), + [&](int64_t idx) -> T { return input[permutation[idx]]; }); + return llvm::to_vector(permutationRange); +} + +template +SmallVector applyPermutation(const SmallVectorImpl &input, + ArrayRef permutation) { + return applyPermutation(ArrayRef(input), permutation); +} + /// Apply the permutation defined by `permutation` to `inVec`. /// Element `i` in `inVec` is mapped to location `j = permutation[i]`. /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation @@ -186,10 +203,7 @@ template void applyPermutationToVector(SmallVector &inVec, ArrayRef permutation) { - SmallVector auxVec(inVec.size()); - for (const auto &en : enumerate(permutation)) - auxVec[en.index()] = inVec[en.value()]; - inVec = auxVec; + inVec = applyPermutation(inVec, permutation); } /// Helper method to apply to inverse a permutation. @@ -212,6 +226,127 @@ SmallVector getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0, unsigned dropBack = 0); +//===----------------------------------------------------------------------===// +// Utilities for decomposing larger shapes +//===----------------------------------------------------------------------===// + +namespace detail { +/// Encapsulates the set of parameters that are used to make tile offset +/// calculations in the TileOffsetRangeIterator. 
+class TileOffsetRangeImpl { +public: + TileOffsetRangeImpl(ArrayRef shape, ArrayRef tileShape, + ArrayRef loopOrder); + + int64_t getMaxIndexVal() const { return maxIndexVal; } + + SmallVector getStaticTileOffsets(int64_t linearIndex) const; + + SmallVector getDynamicTileOffsets(AffineExpr linearIndex) const; + + template + SmallVector getTileOffsets(T linearIndex) const { + if constexpr (std::is_same_v) + return getStaticTileOffsets(linearIndex); + else + return getDynamicTileOffsets(linearIndex); + } + +private: + SmallVector tileShape; + SmallVector inverseLoopOrder; + SmallVector sliceStrides; + int64_t maxIndexVal; +}; + +/// The STL-style iterator implementation for StaticTileOffsetRange. +template +class TileOffsetRangeIterator { +public: + using ParamsTy = TileOffsetRangeImpl; + using ValueTy = SmallVector; + using iterator_category = std::random_access_iterator_tag; + + void operator++() { incrementIndex(1); } + TileOffsetRangeIterator operator++(int) { + const auto copy = *this; + ++*this; + return copy; + } + + TileOffsetRangeIterator(const TileOffsetRangeImpl ¶ms, ElementType index) + : params(params), index(index) {} + + bool operator==(const TileOffsetRangeIterator &other) const { + return index == other.index; + } + bool operator!=(const TileOffsetRangeIterator &other) const { + return index != other.index; + } + + ValueTy operator*() const { return params.getTileOffsets(index); } + void operator+=(int64_t offset) { incrementIndex(offset); } + +private: + void incrementIndex(int64_t offset) { index = index + offset; } + const ParamsTy params; + int64_t index; +}; +} // namespace detail + +/// A range-style iterator that allows for iterating over the offsets of all +/// potential tiles of size `tileShape` within the larger shape `shape`, using +/// an ordering specified by `loopOrder`. 
The `loopOrder` specifies the order of +/// unrolling by numbering the dimensions in order from "outer most for loop" +/// (slowest changing) to "inner most for loop" (fastest changing). +/// +/// For example, for `shape = {10, 20, 30}`, `tileShape = {5, 10, 15}`, and +/// `loopOrder={2, 0, 1}`, iterating over this range will yield offsets: +/// +/// ``` {0, 0, 0}, {0, 10, 0}, {5, 0, 0}, {5, 10, 0}, {0, 0, 15}, {0, 10, +/// 15}, {5, 0, 15}, {5, 10, 15} ``` +/// +/// This is useful in contexts where a vector computation over a larger shape +/// needs to be unrolled to a set of operations on subsets of the original +/// operands, such as during the "vector unrolling" transformations. +/// +/// The rank of `tileShape` must be less-than-or-equal-to the rank of `shape`. +/// If the rank of `tileShape` is smaller than `shape`, then `tileShape` +/// elements correspond to the trailing dimensions of `shape`, and the leading +/// dimensions are considered untiled and `tileShape` is effectively prepended +/// with the leading dims of `shape`. +class StaticTileOffsetRange { +public: + using IteratorTy = detail::TileOffsetRangeIterator; + using ParamsTy = detail::TileOffsetRangeImpl; + + StaticTileOffsetRange(ArrayRef shape, ArrayRef tileShape, + ArrayRef loopOrder) + : params(shape, tileShape, loopOrder), beginValue(params, 0), + pastEndValue(params, params.getMaxIndexVal()) { + assert(shape.size() >= tileShape.size()); + assert(loopOrder.size() == shape.size()); + } + + /// Create the range with identity loop order. + StaticTileOffsetRange(ArrayRef shape, ArrayRef tileShape) + : params(shape, tileShape, + llvm::to_vector(llvm::seq(0, shape.size()))), + beginValue(params, 0), pastEndValue(params, params.getMaxIndexVal()) { + assert(shape.size() >= tileShape.size()); + } + + IteratorTy begin() const { return beginValue; } + IteratorTy end() const { return pastEndValue; } + + /// Returns the total number of tiles that fit in the larger shape.
+ size_t size() const { return params.getMaxIndexVal(); } + +private: + const ParamsTy params; + IteratorTy beginValue; + IteratorTy pastEndValue; +}; } // namespace mlir #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -17,6 +17,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include #include @@ -251,6 +252,8 @@ AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context); AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context); AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context); +SmallVector getAffineConstantExprs(ArrayRef constants, + MLIRContext *context); AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs); diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -166,9 +166,8 @@ AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef offsets, ArrayRef basis) { - SmallVector basisExprs = llvm::to_vector(llvm::map_range( - basis, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); })); - return linearize(ctx, offsets, basisExprs); + + return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx)); } SmallVector mlir::delinearize(AffineExpr linearIndex, @@ -181,9 +180,7 @@ SmallVector mlir::delinearize(AffineExpr linearIndex, ArrayRef strides) { MLIRContext *ctx = linearIndex.getContext(); - SmallVector basisExprs = llvm::to_vector(llvm::map_range( - strides, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); })); - return delinearize(linearIndex, ArrayRef{basisExprs}); + return delinearize(linearIndex, getAffineConstantExprs(strides, ctx)); } 
//===----------------------------------------------------------------------===// @@ -246,3 +243,52 @@ res.push_back((*it).getValue().getSExtValue()); return res; } + +//===----------------------------------------------------------------------===// +// TileOffsetRange +//===----------------------------------------------------------------------===// + +mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl( + ArrayRef shape, ArrayRef tileShape, + ArrayRef loopOrder) + : tileShape(tileShape), + inverseLoopOrder(invertPermutationVector(loopOrder)), + sliceStrides(shape.size()) { + // Compute the count for each dimension. + auto maybeShapeRatio = mlir::computeShapeRatio(shape, tileShape); + assert(maybeShapeRatio && + "target shape does not evenly divide the original shape"); + assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() && + "expected loop order to be a permutation of rank equal to outer " + "shape"); + + // Pad `sliceDimCounts` with leading 1s so that all sizes match. + SmallVector sliceDimCounts = *maybeShapeRatio; + maxIndexVal = mlir::computeMaxLinearIndex(sliceDimCounts); + + // Reversing "loop order" gives dimensions from fastest varying to + // slowest varying (smallest stride to largest stride). 
+ int64_t accum = 1; + for (auto idx : llvm::reverse(loopOrder)) { + sliceStrides[idx] = accum; + accum *= sliceDimCounts[idx]; + } + mlir::applyPermutationToVector(sliceStrides, loopOrder); +} + +SmallVector mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets( + int64_t linearIndex) const { + SmallVector tileCoords = applyPermutation( + delinearize(linearIndex, sliceStrides), inverseLoopOrder); + return computeElementwiseMul(tileCoords, tileShape); +} + +SmallVector +mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets( + AffineExpr linearIndex) const { + MLIRContext *ctx = linearIndex.getContext(); + SmallVector tileCoords = applyPermutation( + delinearize(linearIndex, sliceStrides), inverseLoopOrder); + return mlir::computeElementwiseMul(tileCoords, + getAffineConstantExprs(tileShape, ctx)); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -26,77 +26,6 @@ using namespace mlir; using namespace mlir::vector; -/// During unrolling from `originalShape` to `targetShape` return the offset for -/// the slice `index`. -static SmallVector getVectorOffset(ArrayRef ratioStrides, - int64_t index, - ArrayRef targetShape) { - return computeElementwiseMul(delinearize(index, ratioStrides), targetShape); -} - -/// A functor that accomplishes the same thing as `getVectorOffset` but -/// allows for reordering the traversal of the dimensions. The order of -/// traversal is given in "for loop order" (outer to inner). 
-namespace { -class DecomposeShapeIterator { -private: - SmallVector vectorShape; - SmallVector loopOrder; - SmallVector sliceStrides; - int64_t maxIndexVal{1}; - -public: - DecomposeShapeIterator(ArrayRef originalShape, - ArrayRef targetShape, - ArrayRef loopOrder) - : vectorShape(targetShape.begin(), targetShape.end()), - loopOrder(loopOrder.begin(), loopOrder.end()), - sliceStrides(originalShape.size()) { - assert(originalShape.size() >= targetShape.size()); - assert(loopOrder.size() == originalShape.size()); - - // Compute the count for each dimension. - auto maybeShapeRatio = computeShapeRatio(originalShape, targetShape); - assert(maybeShapeRatio && "Shape does not evenly divide"); - // Pad `sliceDimCounts` with leading 1s so that all sizes match. - SmallVector sliceDimCounts = *maybeShapeRatio; - maxIndexVal = computeMaxLinearIndex(sliceDimCounts); - - // Reversing "loop order" gives dimensions from fastest varying to slowest - // varying (smallest stride to largest stride). - int64_t accum = 1; - for (auto idx : llvm::reverse(loopOrder)) { - sliceStrides[idx] = accum; - accum *= sliceDimCounts[idx]; - } - } - - // Turn the linear index into a d-tuple based on units of vectors of size - // `vectorShape`. The linear index is assumed to represent traversal of the - // dimensions based on `order`. - SmallVector delinearize(int64_t index) const { - // Traverse in for loop order (largest stride to smallest stride). - SmallVector vectorOffsets(sliceStrides.size()); - for (auto idx : loopOrder) { - vectorOffsets[idx] = index / sliceStrides[idx]; - index %= sliceStrides[idx]; - } - return vectorOffsets; - } - - int64_t maxIndex() const { return maxIndexVal; } - - /// Return the offset within d-tuple based on the ordering given by - /// `loopOrder`. 
- SmallVector getVectorOffset(int64_t index) const { - SmallVector vectorOffsets = delinearize(index); - SmallVector elementOffsets = - computeElementwiseMul(vectorShape, vectorOffsets); - return elementOffsets; - } -}; -} // namespace - /// Compute the indices of the slice `index` for a tranfer op. static SmallVector sliceTransferIndices(ArrayRef elementOffsets, ArrayRef indices, @@ -206,13 +135,10 @@ VectorType::get(*targetShape, sourceVectorType.getElementType()); SmallVector originalIndices(readOp.getIndices().begin(), readOp.getIndices().end()); - SmallVector loopOrder = getUnrollOrder(originalSize.size(), readOp, options); - DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, - loopOrder); - for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) { - SmallVector elementOffsets = indexToOffsets.getVectorOffset(i); + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { SmallVector indices = sliceTransferIndices(elementOffsets, originalIndices, readOp.getPermutationMap(), loc, rewriter); @@ -257,14 +183,11 @@ ArrayRef originalSize = sourceVectorType.getShape(); SmallVector originalIndices(writeOp.getIndices().begin(), writeOp.getIndices().end()); - SmallVector loopOrder = getUnrollOrder(originalSize.size(), writeOp, options); - DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, - loopOrder); Value resultTensor; - for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) { - SmallVector elementOffsets = indexToOffsets.getVectorOffset(i); + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { Value slicedVector = rewriter.create( loc, writeOp.getVector(), elementOffsets, *targetShape, strides); SmallVector indices = @@ -329,11 +252,9 @@ SmallVector loopOrder = getUnrollOrder( contractOp.getIteratorTypes().size(), contractOp, options); - DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, - loopOrder); - const int64_t sliceCount = 
indexToOffsets.maxIndex(); - for (int64_t i = 0; i < sliceCount; i++) { - SmallVector offsets = indexToOffsets.getVectorOffset(i); + + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { SmallVector slicesOperands(contractOp.getNumOperands()); // Helper to compute the new shape of each operand and extract the slice. @@ -413,22 +334,16 @@ if (!targetShape) return failure(); SmallVector originalSize = *reductionOp.getShapeForUnroll(); - SmallVector ratio = *computeShapeRatio(originalSize, *targetShape); llvm::MapVector< SmallVector, Value, llvm::DenseMap, unsigned, OffsetMapInfo>> accCache; - // Compute shape ratio of 'shape' and 'sizes'. - int64_t sliceCount = computeMaxLinearIndex(ratio); Location loc = reductionOp.getLoc(); // Stride of the ratios, this gives us the offsets of sliceCount in a basis // of multiples of the targetShape. - auto ratioStrides = computeStrides(ratio); - for (int64_t i = 0; i < sliceCount; i++) { - SmallVector offsets = - getVectorOffset(ratioStrides, i, *targetShape); - + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector operands; SmallVector operandStrides(offsets.size(), 1); Value slicedOperand = rewriter.create( @@ -494,8 +409,6 @@ auto dstVecType = op->getResult(0).getType().cast(); SmallVector originalSize = *cast(op).getShapeForUnroll(); - SmallVector ratio = *computeShapeRatio(originalSize, *targetShape); - int64_t sliceCount = computeMaxLinearIndex(ratio); Location loc = op->getLoc(); // Prepare the result vector. Value result = rewriter.create( @@ -504,12 +417,9 @@ VectorType newVecType = VectorType::get(*targetShape, dstVecType.getElementType()); - // Stride of the ratios, this gives us the offsets of sliceCount in a basis - // of multiples of the targetShape. 
- auto ratioStrides = computeStrides(ratio); - for (int64_t i = 0; i < sliceCount; i++) { - SmallVector offsets = - getVectorOffset(ratioStrides, i, *targetShape); + // Create the unrolled computation. + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector extractOperands; for (OpOperand &operand : op->getOpOperands()) { auto vecType = operand.get().getType().template dyn_cast(); @@ -548,19 +458,12 @@ if (!targetShape) return failure(); SmallVector originalSize = *reductionOp.getShapeForUnroll(); - auto ratio = *computeShapeRatio(originalSize, *targetShape); - int64_t sliceCount = ratio[0]; // Create unrolled vector reduction. Location loc = reductionOp.getLoc(); Value accumulator = nullptr; - - // Stride of the ratios, this gives us the offsets of sliceCount in a basis - // of multiples of the targetShape. - auto ratioStrides = computeStrides(ratio); - for (int64_t i = 0; i < sliceCount; ++i) { - SmallVector offsets = - getVectorOffset(ratioStrides, i, *targetShape); + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector strides(offsets.size(), 1); Value slicedOperand = rewriter.create( loc, reductionOp.getVector(), offsets, *targetShape, strides); @@ -604,20 +507,16 @@ SmallVector strides(targetShape->size(), 1); Location loc = transposeOp.getLoc(); ArrayRef originalSize = originalVectorType.getShape(); - SmallVector ratio = *computeShapeRatio(originalSize, *targetShape); - int64_t sliceCount = computeMaxLinearIndex(ratio); + // Prepare the result vector; Value result = rewriter.create( loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); SmallVector permutation; transposeOp.getTransp(permutation); - // Stride of the ratios, this gives us the offsets of sliceCount in a basis - // of multiples of the targetShape. 
- auto ratioStrides = computeStrides(ratio); - for (int64_t i = 0; i < sliceCount; i++) { - SmallVector elementOffsets = - getVectorOffset(ratioStrides, i, *targetShape); + // Unroll the computation. + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, *targetShape)) { SmallVector permutedOffsets(elementOffsets.size()); SmallVector permutedShape(elementOffsets.size()); // Compute the source offsets and shape. @@ -668,13 +567,11 @@ SmallVector loopOrder = getUnrollOrder(originalSize.size(), gatherOp, options); - DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, - loopOrder); - for (int64_t i = 0, e = indexToOffsets.maxIndex(); i < e; ++i) { + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { // To get the unrolled gather, extract the same slice based on the // decomposed shape from each of the index, mask, and pass-through // vectors. - SmallVector elementOffsets = indexToOffsets.getVectorOffset(i); Value indexSubVec = rewriter.create( loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides); Value maskSubVec = rewriter.create( diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -533,6 +533,14 @@ return uniquer.get(assignCtx, constant); } +SmallVector +mlir::getAffineConstantExprs(ArrayRef constants, + MLIRContext *context) { + return llvm::to_vector(llvm::map_range(constants, [&](int64_t constant) { + return getAffineConstantExpr(constant, context); + })); +} + /// Simplify add expression. Return nullptr if it can't be simplified. static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { auto lhsConst = lhs.dyn_cast();