diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -23,35 +23,68 @@ namespace mlir { class ArrayAttr; -/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'. -int64_t linearize(ArrayRef offsets, ArrayRef basis); - -/// Given the strides together with a linear index in the dimension -/// space, returns the vector-space offsets in each dimension for a -/// de-linearized index. -SmallVector delinearize(ArrayRef strides, - int64_t linearIndex); +//===----------------------------------------------------------------------===// +// Utils that operate on static integer values. +//===----------------------------------------------------------------------===// -/// Given a set of sizes, compute and return the strides (i.e. the number of -/// linear incides to skip along the (k-1) most minor dimensions to get the next -/// k-slice). This is also the basis that one can use to linearize an n-D offset -/// confined to `[0 .. sizes]`. -SmallVector computeStrides(ArrayRef sizes); +/// Given a set of sizes, return the suffix product. +/// +/// When applied to slicing, this is the calculation needed to derive the +/// strides (i.e. the number of linear indices to skip along the (k-1) most +/// minor dimensions to get the next k-slice). +/// +/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`. +/// +/// Assuming `sizes` is `[s0, .. sn]`, return the vector +/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`. +/// +/// `sizes` elements are asserted to be non-negative. +/// +/// Return an empty vector if `sizes` is empty. +SmallVector computeSuffixProduct(ArrayRef sizes); +inline SmallVector computeStrides(ArrayRef sizes) { + return computeSuffixProduct(sizes); +} -/// Return a vector containing llvm::zip of v1 and v2 multiplied elementwise. 
+/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise. +/// +/// Return an empty vector if `v1` and `v2` are empty. SmallVector computeElementwiseMul(ArrayRef v1, ArrayRef v2); -/// Compute and return the multi-dimensional integral ratio of `subShape` to -/// the trailing dimensions of `shape`. This represents how many times -/// `subShape` fits within `shape`. -/// If integral division is not possible, return std::nullopt. +/// Return the number of elements of basis (i.e. the max linear index). +/// This is the product of all `basis` elements. +/// +/// `basis` elements are asserted to be non-negative. +/// +/// Return `0` if `basis` is empty. +int64_t computeMaxLinearIndex(ArrayRef basis); + +/// Return the linearized index of 'offsets' w.r.t. 'basis'. +/// +/// `basis` elements are asserted to be non-negative. +int64_t linearize(ArrayRef offsets, ArrayRef basis); + +/// Given the strides together with a linear index in the dimension space, +/// return the vector-space offsets in each dimension for a de-linearized index. +/// `strides` elements are asserted to be strictly positive. +/// +/// Let `li = linearIndex`, assuming `strides` are `[s0, .. sn]`, return the +/// vector of int64_t +/// `[li / s0, (li % s0) / s1, ..., (li % s0 % .. % sn-1) / sn]` +SmallVector delinearize(int64_t linearIndex, + ArrayRef strides); + +/// Return the multi-dimensional integral ratio of `subShape` to the trailing +/// dimensions of `shape`. This represents how many times `subShape` fits +/// within `shape`. If integral division is not possible, return std::nullopt. /// The trailing `subShape.size()` entries of both shapes are assumed (and -/// enforced) to only contain noonnegative values. +/// enforced) to only contain non-negative values. /// /// Examples: /// - shapeRatio({3, 5, 8}, {2, 5, 2}) returns {3, 2, 1}. 
-/// - shapeRatio({3, 8}, {2, 5, 2}) returns std::nullopt (subshape has higher +/// - shapeRatio({3, 8}, {2, 5, 2}) returns std::nullopt (subshape has +/// higher /// rank). /// - shapeRatio({42, 2, 10, 32}, {2, 5, 2}) returns {42, 1, 2, 16} which is /// derived as {42(leading shape dim), 2/2, 10/5, 32/2}. @@ -60,14 +93,96 @@ std::optional> computeShapeRatio(ArrayRef shape, ArrayRef subShape); +//===----------------------------------------------------------------------===// +// Utils that operate on AffineExpr. +//===----------------------------------------------------------------------===// + +/// Given a set of sizes, return the suffix product. +/// +/// When applied to slicing, this is the calculation needed to derive the +/// strides (i.e. the number of linear indices to skip along the (k-1) most +/// minor dimensions to get the next k-slice). +/// +/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`. +/// +/// Assuming `sizes` is `[s0, .. sn]`, return the vector +/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`. +/// +/// It is the caller's responsibility to pass proper AffineExpr kind that +/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide +/// by an AffineDimExpr). +/// +/// `sizes` elements are expected to bind to non-negative values. +/// +/// Return an empty vector if `sizes` is empty. +SmallVector computeSuffixProduct(ArrayRef sizes); +inline SmallVector computeStrides(ArrayRef sizes) { + return computeSuffixProduct(sizes); +} + +/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise. +/// +/// It is the caller's responsibility to pass proper AffineExpr kind that +/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide +/// by an AffineDimExpr). +/// +/// Return an empty vector if `v1` and `v2` are empty. +SmallVector computeElementwiseMul(ArrayRef v1, + ArrayRef v2); + /// Return the number of elements of basis (i.e. the max linear index). 
/// Return `0` if `basis` is empty. -int64_t computeMaxLinearIndex(ArrayRef basis); +/// +/// It is the caller's responsibility to pass proper AffineExpr kind that +/// result in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide +/// by an AffineDimExpr). +/// +/// `basis` elements are expected to bind to non-negative values. +/// +/// Return the `0` AffineConstantExpr if `basis` is empty. +AffineExpr computeMaxLinearIndex(MLIRContext *ctx, ArrayRef basis); + +/// Return the linearized index of 'offsets' w.r.t. 'basis'. +/// +/// Assuming `offsets` is `[o0, .. on]` and `basis` is `[b0, .. bn]`, return the +/// AffineExpr `o0 * b0 + .. + on * bn`. +/// +/// It is the caller's responsibility to pass proper AffineExpr kind that result +/// in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide by an +/// AffineDimExpr). +/// +/// `basis` elements are expected to bind to non-negative values. +AffineExpr linearize(MLIRContext *ctx, ArrayRef offsets, + ArrayRef basis); +AffineExpr linearize(MLIRContext *ctx, ArrayRef offsets, + ArrayRef basis); + +/// Given the strides together with a linear index in the dimension space, +/// return the vector-space offsets in each dimension for a de-linearized index. +/// +/// Let `li = linearIndex`, assuming `strides` are `[s0, .. sn]`, return the +/// vector of AffineExpr (with `/` denoting floordiv) +/// `[li / s0, (li % s0) / s1, ..., (li % s0 % .. % sn-1) / sn]` +/// +/// It is the caller's responsibility to pass proper AffineExpr kind that result +/// in valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide by an +/// AffineDimExpr). +/// +/// `strides` elements are expected to bind to strictly positive values. +SmallVector delinearize(AffineExpr linearIndex, + ArrayRef strides); +SmallVector delinearize(AffineExpr linearIndex, + ArrayRef strides); + +//===----------------------------------------------------------------------===// +// Permutation utils. 
+//===----------------------------------------------------------------------===// /// Apply the permutation defined by `permutation` to `inVec`. /// Element `i` in `inVec` is mapped to location `j = permutation[i]`. -/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector -/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`. +/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation +/// vector `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', +/// 'b']`. template void applyPermutationToVector(SmallVector &inVec, ArrayRef permutation) { @@ -83,18 +198,11 @@ /// Method to check if an interchange vector is a permutation. bool isPermutationVector(ArrayRef interchange); -/// Helper that returns a subset of `arrayAttr` as a vector of int64_t. +/// Helper to return a subset of `arrayAttr` as a vector of int64_t. +// TODO: Port everything relevant to DenseArrayAttr and drop this util. SmallVector getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0, unsigned dropBack = 0); -/// Computes and returns linearized affine expression w.r.t. `basis`. -mlir::AffineExpr getLinearAffineExpr(ArrayRef basis, mlir::Builder &b); - -/// Given the strides in the dimension space, returns the affine expressions for -/// vector-space offsets in each dimension for a de-linearized index. 
-SmallVector -getDelinearizedAffineExpr(ArrayRef strides, mlir::Builder &b); - } // namespace mlir #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -321,13 +321,6 @@ bindSymbols(ctx, exprs...); } -template -void bindSymbolsList(MLIRContext *ctx, SmallVectorImpl &exprs) { - int idx = 0; - for (AffineExprTy &e : exprs) - e = getAffineSymbolExpr(idx++, ctx); -} - } // namespace detail /// Bind a list of AffineExpr references to DimExpr at positions: @@ -337,6 +330,13 @@ detail::bindDims<0>(ctx, exprs...); } +template +void bindDimsList(MLIRContext *ctx, SmallVectorImpl &exprs) { + int idx = 0; + for (AffineExprTy &e : exprs) + e = getAffineDimExpr(idx++, ctx); +} + /// Bind a list of AffineExpr references to SymbolExpr at positions: /// [0 .. sizeof...(exprs)] template @@ -344,6 +344,13 @@ detail::bindSymbols<0>(ctx, exprs...); } +template +void bindSymbolsList(MLIRContext *ctx, SmallVectorImpl &exprs) { + int idx = 0; + for (AffineExprTy &e : exprs) + e = getAffineSymbolExpr(idx++, ctx); +} + } // namespace mlir namespace llvm { diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -103,7 +103,7 @@ loc, DenseElementsAttr::get(vecType, initValueAttr)); SmallVector strides = computeStrides(shape); for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) { - SmallVector positions = delinearize(strides, linearIndex); + SmallVector positions = delinearize(linearIndex, strides); SmallVector operands; for (Value input : op->getOperands()) operands.push_back( diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ 
b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -89,7 +89,7 @@ vecType, FloatAttr::get(vecType.getElementType(), 0.0))); SmallVector strides = computeStrides(shape); for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { - SmallVector positions = delinearize(strides, linearIndex); + SmallVector positions = delinearize(linearIndex, strides); SmallVector operands; for (auto input : op->getOperands()) operands.push_back( diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -134,7 +134,7 @@ SmallVector results(maxIndex); for (int64_t i = 0; i < maxIndex; ++i) { - auto offsets = delinearize(strides, i); + auto offsets = delinearize(i, strides); SmallVector extracted(expandedOperands.size()); for (const auto &tuple : llvm::enumerate(expandedOperands)) @@ -152,7 +152,7 @@ for (int64_t i = 0; i < maxIndex; ++i) result = builder.create(results[i], result, - delinearize(strides, i)); + delinearize(i, strides)); // Reshape back to the original vector shape. return builder.create( diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -75,7 +75,7 @@ SmallVector values(2 * sourceRank + 1); SmallVector symbols(2 * sourceRank + 1); - detail::bindSymbolsList(rewriter.getContext(), symbols); + bindSymbolsList(rewriter.getContext(), symbols); AffineExpr expr = symbols.front(); values[0] = ShapedType::isDynamic(sourceOffset) ? 
getAsOpFoldResult(newExtractStridedMetadata.getOffset()) @@ -262,10 +262,9 @@ auto sourceType = source.getType().cast(); auto [strides, offset] = getStridesAndOffset(sourceType); - OpFoldResult origStride = - ShapedType::isDynamic(strides[groupId]) - ? origStrides[groupId] - : builder.getIndexAttr(strides[groupId]); + OpFoldResult origStride = ShapedType::isDynamic(strides[groupId]) + ? origStrides[groupId] + : builder.getIndexAttr(strides[groupId]); // Apply the original stride to all the strides. int64_t doneStrideIdx = 0; diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallBitVector.h" @@ -54,24 +55,26 @@ memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl &sourceIndices) { + MLIRContext *ctx = rewriter.getContext(); for (SmallVector groups : expandShapeOp.getReassociationIndices()) { assert(!groups.empty() && "association indices groups cannot be empty"); - unsigned groupSize = groups.size(); - SmallVector suffixProduct(groupSize); - // Calculate suffix product of dimension sizes for all dimensions of expand - // shape op result. 
- suffixProduct[groupSize - 1] = 1; - for (unsigned i = groupSize - 1; i > 0; i--) - suffixProduct[i - 1] = - suffixProduct[i] * - expandShapeOp.getType().cast().getDimSize(groups[i]); - SmallVector dynamicIndices(groupSize); - for (unsigned i = 0; i < groupSize; i++) - dynamicIndices[i] = indices[groups[i]]; + int64_t groupSize = groups.size(); + // Construct the expression for the index value w.r.t to expand shape op // source corresponding the indices wrt to expand shape op result. - AffineExpr srcIndexExpr = getLinearAffineExpr(suffixProduct, rewriter); + SmallVector sizes(groupSize); + for (int64_t i = 0; i < groupSize; ++i) + sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]); + SmallVector suffixProduct = computeSuffixProduct(sizes); + SmallVector dims(groupSize); + bindDimsList(ctx, dims); + AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct); + + // Gather this group's indices and create the AffineApplyOp. + SmallVector dynamicIndices(groupSize); + for (int64_t i = 0; i < groupSize; i++) + dynamicIndices[i] = indices[groups[i]]; sourceIndices.push_back(rewriter.create( loc, AffineMap::get(/*numDims=*/groupSize, /*numSymbols=*/0, srcIndexExpr), @@ -98,35 +101,41 @@ memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl &sourceIndices) { - unsigned cnt = 0; + int64_t cnt = 0; SmallVector tmp(indices.size()); SmallVector dynamicIndices; for (SmallVector groups : collapseShapeOp.getReassociationIndices()) { assert(!groups.empty() && "association indices groups cannot be empty"); dynamicIndices.push_back(indices[cnt++]); - unsigned groupSize = groups.size(); - SmallVector suffixProduct(groupSize); + int64_t groupSize = groups.size(); + // Calculate suffix product for all collapse op source dimension sizes. 
- suffixProduct[groupSize - 1] = 1; - for (unsigned i = groupSize - 1; i > 0; i--) - suffixProduct[i - 1] = - suffixProduct[i] * collapseShapeOp.getSrcType().getDimSize(groups[i]); + SmallVector sizes(groupSize); + for (int64_t i = 0; i < groupSize; ++i) + sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]); + SmallVector suffixProduct = computeSuffixProduct(sizes); + // Derive the index values along all dimensions of the source corresponding // to the index wrt to collapsed shape op output. - SmallVector srcIndexExpr = - getDelinearizedAffineExpr(suffixProduct, rewriter); - for (unsigned i = 0; i < groupSize; i++) + auto d0 = rewriter.getAffineDimExpr(0); + SmallVector delinearizingExprs = + delinearize(d0, suffixProduct); + + // Construct the AffineApplyOp for each delinearizingExpr. + for (int64_t i = 0; i < groupSize; i++) sourceIndices.push_back(rewriter.create( - loc, AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, srcIndexExpr[i]), + loc, + AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, + delinearizingExprs[i]), dynamicIndices)); dynamicIndices.clear(); } if (collapseShapeOp.getReassociationIndices().empty()) { auto zeroAffineMap = rewriter.getConstantAffineMap(0); - unsigned srcRank = + int64_t srcRank = collapseShapeOp.getViewSource().getType().cast().getRank(); - for (unsigned i = 0; i < srcRank; i++) + for (int64_t i = 0; i < srcRank; i++) sourceIndices.push_back( rewriter.create(loc, zeroAffineMap, dynamicIndices)); } @@ -157,9 +166,9 @@ SmallVector useIndices; // Check if this is rank-reducing case. Then for every unit-dim size add a // zero to the indices. 
- unsigned resultDim = 0; + int64_t resultDim = 0; llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims(); - for (auto dim : llvm::seq(0, subViewOp.getSourceType().getRank())) { + for (auto dim : llvm::seq(0, subViewOp.getSourceType().getRank())) { if (unusedDims.test(dim)) useIndices.push_back(rewriter.create(loc, 0)); else @@ -171,7 +180,7 @@ for (auto index : llvm::seq(0, mixedOffsets.size())) { SmallVector dynamicOperands; AffineExpr expr = rewriter.getAffineDimExpr(0); - unsigned numSymbols = 0; + int64_t numSymbols = 0; dynamicOperands.push_back(useIndices[index]); // Multiply the stride; @@ -353,7 +362,7 @@ const SmallVector &indices, Location loc, PatternRewriter &rewriter) { SmallVector expandedIndices; - for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) + for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) expandedIndices.push_back( rewriter.create(loc, affineMap.getSubMap({i}), indices)); return expandedIndices; diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -11,27 +11,100 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/MLIRContext.h" +#include "llvm/ADT/STLExtras.h" #include #include using namespace mlir; -SmallVector mlir::computeStrides(ArrayRef sizes) { - SmallVector strides(sizes.size(), 1); +template +SmallVector computeSuffixProductImpl(ArrayRef sizes, + ExprType unit) { + if (sizes.empty()) + return {}; + SmallVector strides(sizes.size(), unit); for (int64_t r = strides.size() - 2; r >= 0; --r) strides[r] = strides[r + 1] * sizes[r + 1]; return strides; } -SmallVector mlir::computeElementwiseMul(ArrayRef v1, - ArrayRef v2) { - SmallVector result; - for (auto it : llvm::zip(v1, v2)) +template +SmallVector computeElementwiseMulImpl(ArrayRef v1, + ArrayRef v2) { + // Early exit if both are empty, let 
zip_equal fail if only 1 is empty. + if (v1.empty() && v2.empty()) + return {}; + SmallVector result; + for (auto it : llvm::zip_equal(v1, v2)) result.push_back(std::get<0>(it) * std::get<1>(it)); return result; } +template +ExprType linearizeImpl(ArrayRef offsets, ArrayRef basis, + ExprType zero) { + assert(offsets.size() == basis.size()); + ExprType linearIndex = zero; + for (unsigned idx = 0, e = basis.size(); idx < e; ++idx) + linearIndex = linearIndex + offsets[idx] * basis[idx]; + return linearIndex; +} + +template +SmallVector delinearizeImpl(ExprType linearIndex, + ArrayRef strides, + DivOpTy divOp) { + int64_t rank = strides.size(); + SmallVector offsets(rank); + for (int64_t r = 0; r < rank; ++r) { + offsets[r] = divOp(linearIndex, strides[r]); + linearIndex = linearIndex % strides[r]; + } + return offsets; +} + +//===----------------------------------------------------------------------===// +// Utils that operate on static integer values. +//===----------------------------------------------------------------------===// + +SmallVector mlir::computeSuffixProduct(ArrayRef sizes) { + assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) && + "sizes must be nonnegative"); + int64_t unit = 1; + return ::computeSuffixProductImpl(sizes, unit); +} + +SmallVector mlir::computeElementwiseMul(ArrayRef v1, + ArrayRef v2) { + return computeElementwiseMulImpl(v1, v2); +} + +int64_t mlir::computeMaxLinearIndex(ArrayRef basis) { + assert(llvm::all_of(basis, [](int64_t s) { return s >= 0; }) && + "basis must be nonnegative"); + if (basis.empty()) + return 0; + return std::accumulate(basis.begin(), basis.end(), int64_t{1}, + std::multiplies()); +} + +int64_t mlir::linearize(ArrayRef offsets, ArrayRef basis) { + assert(llvm::all_of(basis, [](int64_t s) { return s >= 0; }) && + "basis must be nonnegative"); + int64_t zero = 0; + return linearizeImpl(offsets, basis, zero); +} + +SmallVector mlir::delinearize(int64_t linearIndex, + ArrayRef strides) { + 
assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) && + "strides must be positive"); + return delinearizeImpl(linearIndex, strides, + [](int64_t e1, int64_t e2) { return e1 / e2; }); +} + std::optional> mlir::computeShapeRatio(ArrayRef shape, ArrayRef subShape) { if (shape.size() < subShape.size()) @@ -60,35 +133,68 @@ return SmallVector{result.rbegin(), result.rend()}; } -int64_t mlir::linearize(ArrayRef offsets, ArrayRef basis) { - assert(offsets.size() == basis.size()); - int64_t linearIndex = 0; - for (unsigned idx = 0, e = basis.size(); idx < e; ++idx) - linearIndex += offsets[idx] * basis[idx]; - return linearIndex; +//===----------------------------------------------------------------------===// +// Utils that operate on AffineExpr. +//===----------------------------------------------------------------------===// + +SmallVector mlir::computeSuffixProduct(ArrayRef sizes) { + if (sizes.empty()) + return {}; + MLIRContext *ctx = sizes.front().getContext(); + AffineExpr unit = getAffineConstantExpr(1, ctx); + return ::computeSuffixProductImpl(sizes, unit); } -llvm::SmallVector mlir::delinearize(ArrayRef sliceStrides, - int64_t index) { - int64_t rank = sliceStrides.size(); - SmallVector vectorOffsets(rank); - for (int64_t r = 0; r < rank; ++r) { - assert(sliceStrides[r] > 0); - vectorOffsets[r] = index / sliceStrides[r]; - index %= sliceStrides[r]; - } - return vectorOffsets; +SmallVector mlir::computeElementwiseMul(ArrayRef v1, + ArrayRef v2) { + return computeElementwiseMulImpl(v1, v2); } -int64_t mlir::computeMaxLinearIndex(ArrayRef basis) { +AffineExpr mlir::computeMaxLinearIndex(MLIRContext *ctx, + ArrayRef basis) { if (basis.empty()) - return 0; - return std::accumulate(basis.begin(), basis.end(), 1, - std::multiplies()); + return getAffineConstantExpr(0, ctx); + return std::accumulate(basis.begin(), basis.end(), + getAffineConstantExpr(1, ctx), + std::multiplies()); +} + +AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef offsets, + ArrayRef basis) { + AffineExpr 
zero = getAffineConstantExpr(0, ctx); + return linearizeImpl(offsets, basis, zero); } -llvm::SmallVector +AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef offsets, + ArrayRef basis) { + SmallVector basisExprs = llvm::to_vector(llvm::map_range( + basis, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); })); + return linearize(ctx, offsets, basisExprs); +} + +SmallVector mlir::delinearize(AffineExpr linearIndex, + ArrayRef strides) { + return delinearizeImpl( + linearIndex, strides, + [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); }); +} + +SmallVector mlir::delinearize(AffineExpr linearIndex, + ArrayRef strides) { + MLIRContext *ctx = linearIndex.getContext(); + SmallVector basisExprs = llvm::to_vector(llvm::map_range( + strides, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); })); + return delinearize(linearIndex, ArrayRef{basisExprs}); +} + +//===----------------------------------------------------------------------===// +// Permutation utils. +//===----------------------------------------------------------------------===// + +SmallVector mlir::invertPermutationVector(ArrayRef permutation) { + assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) && + "permutation must be non-negative"); SmallVector inversion(permutation.size()); for (const auto &pos : llvm::enumerate(permutation)) { inversion[pos.value()] = pos.index(); @@ -97,6 +203,8 @@ } bool mlir::isPermutationVector(ArrayRef interchange) { + assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) && + "permutation must be non-negative"); llvm::SmallDenseSet seenVals; for (auto val : interchange) { if (seenVals.count(val)) @@ -106,9 +214,9 @@ return seenVals.size() == interchange.size(); } -llvm::SmallVector mlir::getI64SubArray(ArrayAttr arrayAttr, - unsigned dropFront, - unsigned dropBack) { +SmallVector mlir::getI64SubArray(ArrayAttr arrayAttr, + unsigned dropFront, + unsigned dropBack) { assert(arrayAttr.size() > dropFront + dropBack && "Out of 
bounds"); auto range = arrayAttr.getAsRange(); SmallVector res; @@ -118,26 +226,3 @@ res.push_back((*it).getValue().getSExtValue()); return res; } - -mlir::AffineExpr mlir::getLinearAffineExpr(ArrayRef basis, - mlir::Builder &b) { - AffineExpr resultExpr = b.getAffineDimExpr(0); - resultExpr = resultExpr * basis[0]; - for (unsigned i = 1; i < basis.size(); i++) - resultExpr = resultExpr + b.getAffineDimExpr(i) * basis[i]; - return resultExpr; -} - -llvm::SmallVector -mlir::getDelinearizedAffineExpr(mlir::ArrayRef strides, Builder &b) { - AffineExpr resultExpr = b.getAffineDimExpr(0); - int64_t rank = strides.size(); - SmallVector vectorOffsets(rank); - vectorOffsets[0] = resultExpr.floorDiv(strides[0]); - resultExpr = resultExpr % strides[0]; - for (unsigned i = 1; i < rank; i++) { - vectorOffsets[i] = resultExpr.floorDiv(strides[i]); - resultExpr = resultExpr % strides[i]; - } - return vectorOffsets; -} diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1558,7 +1558,7 @@ getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank); } std::reverse(newStrides.begin(), newStrides.end()); - SmallVector newPosition = delinearize(newStrides, position); + SmallVector newPosition = delinearize(position, newStrides); // OpBuilder is only used as a helper to build an I64ArrayAttr. 
OpBuilder b(extractOp.getContext()); extractOp->setAttr(ExtractOp::getPositionAttrStrName(), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -457,7 +457,7 @@ for (int64_t linearIdx = 0; linearIdx < numTransposedElements; ++linearIdx) { - auto extractIdxs = delinearize(prunedInStrides, linearIdx); + auto extractIdxs = delinearize(linearIdx, prunedInStrides); SmallVector insertIdxs(extractIdxs); applyPermutationToVector(insertIdxs, prunedTransp); Value extractOp = @@ -588,8 +588,7 @@ loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { auto pos = rewriter.getI64ArrayAttr(d); - Value x = - rewriter.create(loc, op.getLhs(), pos); + Value x = rewriter.create(loc, op.getLhs(), pos); Value a = rewriter.create(loc, rhsType, x); Value r = nullptr; if (acc) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -31,7 +31,7 @@ static SmallVector getVectorOffset(ArrayRef ratioStrides, int64_t index, ArrayRef targetShape) { - return computeElementwiseMul(delinearize(ratioStrides, index), targetShape); + return computeElementwiseMul(delinearize(index, ratioStrides), targetShape); } /// A functor that accomplishes the same thing as `getVectorOffset` but