diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -28,8 +28,37 @@ /// Given the strides together with a linear index in the dimension /// space, returns the vector-space offsets in each dimension for a /// de-linearized index. -SmallVector delinearize(ArrayRef strides, - int64_t linearIndex); +SmallVector delinearize(ArrayRef strides, + int64_t linearIndex); + +/// Given a set of sizes, compute and return the strides (i.e. the number of +/// linear indices to skip along the (k-1) most minor dimensions to get the +/// next k-slice). +/// This is also the basis that one can use to linearize an n-D offset confined +/// to `[0 .. sizes]`. +SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes); + +/// Return a vector containing the elementwise product of `v1` and `v2`. +SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1, + ArrayRef<int64_t> v2); + +/// Compute and return the multi-dimensional integral ratio of `subShape` to +/// the trailing dimensions of `shape`. This represents how many times `subShape` +/// fits within `shape`. +/// If integral division is not possible, return None. +/// The trailing `subShape.size()` entries of both shapes are assumed (and +/// enforced) to only contain nonnegative values. +/// +/// Examples: +/// - shapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4} +/// - shapeRatio({3, 8}, {2, 5, 2}) returns None +/// - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16} +Optional<SmallVector<int64_t>> computeShapeRatio(ArrayRef<int64_t> shape, + ArrayRef<int64_t> subShape); + +/// Return the number of elements of basis (i.e. the max linear index). +/// Return `0` if `basis` is empty. +int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis); /// Apply the permutation defined by `permutation` to `inVec`. /// Element `i` in `inVec` is mapped to location `j = permutation[i]`. @@ -45,16 +74,15 @@ } /// Helper that returns a subset of `arrayAttr` as a vector of int64_t. -SmallVector getI64SubArray(ArrayAttr arrayAttr, - unsigned dropFront = 0, - unsigned dropBack = 0); +SmallVector getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0, + unsigned dropBack = 0); /// Computes and returns linearized affine expression w.r.t. `basis`. mlir::AffineExpr getLinearAffineExpr(ArrayRef basis, mlir::Builder &b); /// Given the strides in the dimension space, returns the affine expressions for /// vector-space offsets in each dimension for a de-linearized index. -SmallVector +SmallVector getDelinearizedAffineExpr(ArrayRef strides, mlir::Builder &b); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -111,7 +111,7 @@ } using NativeShapeFnType = - std::function>(Operation *op)>; + std::function>(Operation *op)>; /// Function that returns the shape of the vector to unroll to for a given /// operation. The unrolling is aborted if the function returns `llvm::None`. NativeShapeFnType nativeShape = nullptr; @@ -122,8 +122,8 @@ /// Set the native shape to use for unrolling.
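// ---------------------------------------------------------------------------
// Illustrative standalone sketch (not part of this patch): the arithmetic
// behind computeStrides / delinearize / computeMaxLinearIndex and the
// computeShapeRatio examples documented above, written with plain std::
// containers instead of the MLIR SmallVector/ArrayRef/Optional types. All
// names below are local to the sketch.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

// Row-major (suffix-product) strides: strides[r] = product of sizes[r+1..n).
static std::vector<int64_t> strides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> res(sizes.size(), 1);
  for (int64_t r = static_cast<int64_t>(res.size()) - 2; r >= 0; --r)
    res[r] = res[r + 1] * sizes[r + 1];
  return res;
}

// Expand a linear index into per-dimension offsets w.r.t. the given strides.
static std::vector<int64_t> delinearized(const std::vector<int64_t> &strides,
                                         int64_t linearIndex) {
  std::vector<int64_t> offsets(strides.size());
  for (size_t r = 0; r < strides.size(); ++r) {
    offsets[r] = linearIndex / strides[r];
    linearIndex %= strides[r];
  }
  return offsets;
}

// Trailing-dimension integral ratio, mirroring the computeShapeRatio doc.
static std::optional<std::vector<int64_t>>
shapeRatio(const std::vector<int64_t> &shape,
           const std::vector<int64_t> &subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;
  size_t lead = shape.size() - subShape.size();
  // Leading dimensions of `shape` are copied through unchanged.
  std::vector<int64_t> ratio(shape.begin(), shape.begin() + lead);
  // Trailing dimensions must divide evenly.
  for (size_t i = 0; i < subShape.size(); ++i) {
    if (shape[lead + i] % subShape[i] != 0)
      return std::nullopt;
    ratio.push_back(shape[lead + i] / subShape[i]);
  }
  return ratio;
}

int main() {
  // sizes {3, 5, 8} -> strides {40, 8, 1}; 120 elements in total.
  std::vector<int64_t> sizes = {3, 5, 8};
  std::vector<int64_t> st = strides(sizes);
  assert((st == std::vector<int64_t>{40, 8, 1}));
  assert(std::accumulate(sizes.begin(), sizes.end(), int64_t(1),
                         std::multiplies<int64_t>()) == 120);
  // Linear index 59 = 1*40 + 2*8 + 3*1 delinearizes to offsets {1, 2, 3}.
  assert((delinearized(st, 59) == std::vector<int64_t>{1, 2, 3}));
  // The documented shape-ratio examples.
  assert((shapeRatio({3, 4, 5, 8}, {2, 5, 2}) ==
          std::vector<int64_t>{3, 2, 1, 4}));
  assert(!shapeRatio({3, 8}, {2, 5, 2}));
  assert((shapeRatio({1, 2, 10, 32}, {2, 5, 2}) ==
          std::vector<int64_t>{1, 1, 2, 16}));
  return 0;
}
// ---------------------------------------------------------------------------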
UnrollVectorOptions &setNativeShape(ArrayRef shape) { - SmallVector tsShape(shape.begin(), shape.end()); - nativeShape = [=](Operation *) -> Optional> { + SmallVector tsShape(shape.begin(), shape.end()); + nativeShape = [=](Operation *) -> Optional> { return tsShape; }; return *this; diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -36,43 +36,6 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim); } // namespace vector -/// Return the number of elements of basis, `0` if empty. -int64_t computeMaxLinearIndex(ArrayRef basis); - -/// Given the shape and sizes of a vector, returns the corresponding -/// strides for each dimension. -/// TODO: needs better doc of how it is used. -SmallVector computeStrides(ArrayRef shape, - ArrayRef sizes); - -/// Given the target sizes of a vector, together with vector-space offsets, -/// returns the element-space offsets for each dimension. -SmallVector -computeElementOffsetsFromVectorSliceOffsets(ArrayRef sizes, - ArrayRef vectorOffsets); - -/// Computes and returns the multi-dimensional ratio of `superShape` to -/// `subShape`. This is calculated by performing a traversal from minor to major -/// dimensions (i.e. in reverse shape order). If integral division is not -/// possible, returns None. -/// The ArrayRefs are assumed (and enforced) to only contain > 1 values. -/// This constraint comes from the fact that they are meant to be used with -/// VectorTypes, for which the property holds by construction. -/// -/// Examples: -/// - shapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4} -/// - shapeRatio({3, 4, 4, 8}, {2, 5, 2}) returns None -/// - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16} -Optional> shapeRatio(ArrayRef superShape, - ArrayRef subShape); - -/// Computes and returns the multi-dimensional ratio of the shapes of -/// `superVector` to `subVector`. If integral division is not possible, returns -/// None. -/// Assumes and enforces that the VectorTypes have the same elemental type. -Optional> shapeRatio(VectorType superVectorType, - VectorType subVectorType); - /// Constructs a permutation map of invariant memref indices to vector /// dimension. 
/// diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -80,8 +80,7 @@ Value result = rewriter.create( loc, DenseElementsAttr::get( vecType, IntegerAttr::get(vecType.getElementType(), 0))); - SmallVector ones(shape.size(), 1); - SmallVector strides = computeStrides(shape, ones); + SmallVector strides = computeStrides(shape); for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) { SmallVector positions = delinearize(strides, linearIndex); SmallVector operands; diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -79,8 +79,7 @@ Value result = rewriter.create( loc, DenseElementsAttr::get( vecType, FloatAttr::get(vecType.getElementType(), 0.0))); - SmallVector ones(shape.size(), 1); - SmallVector strides = computeStrides(shape, ones); + SmallVector strides = computeStrides(shape); for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { SmallVector positions = delinearize(strides, linearIndex); SmallVector operands; diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -127,15 +127,13 @@ // Iterate over all outer dimensions of the compute shape vector type. auto iterationDims = ArrayRef(expandedShape).drop_back(); - int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims); - - SmallVector ones(iterationDims.size(), 1); - auto strides = computeStrides(iterationDims, ones); + int64_t maxIndex = computeMaxLinearIndex(iterationDims); + auto strides = computeStrides(iterationDims); // Compute results for each one dimensional vector. 
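// ---------------------------------------------------------------------------
// Illustrative standalone sketch (not part of this patch): the
// computeStrides + delinearize traversal that the Math conversions above use
// to visit every position of an n-D iteration space exactly once. Here the
// visitor simply records the positions it sees; all names are local to the
// sketch.
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

static void forEachPosition(
    const std::vector<int64_t> &shape,
    const std::function<void(const std::vector<int64_t> &)> &callback) {
  // Row-major strides of `shape`.
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t r = static_cast<int64_t>(shape.size()) - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * shape[r + 1];
  int64_t numElements = 1;
  for (int64_t s : shape)
    numElements *= s;
  // Delinearize each linear index into an n-D position and visit it.
  for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
    std::vector<int64_t> position(shape.size());
    int64_t remaining = linearIndex;
    for (size_t r = 0; r < strides.size(); ++r) {
      position[r] = remaining / strides[r];
      remaining %= strides[r];
    }
    callback(position);
  }
}

int main() {
  // Visit all 6 positions of a {2, 3} shape in row-major order.
  std::vector<std::vector<int64_t>> visited;
  forEachPosition({2, 3}, [&](const std::vector<int64_t> &p) {
    visited.push_back(p);
  });
  assert(visited.size() == 6);
  assert((visited.front() == std::vector<int64_t>{0, 0}));
  assert((visited.back() == std::vector<int64_t>{1, 2}));
  return 0;
}
// ---------------------------------------------------------------------------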
- SmallVector results(maxLinearIndex); + SmallVector results(maxIndex); - for (int64_t i = 0; i < maxLinearIndex; ++i) { + for (int64_t i = 0; i < maxIndex; ++i) { auto offsets = delinearize(strides, i); SmallVector extracted(expandedOperands.size()); @@ -152,7 +150,7 @@ Value result = builder.create( resultExpandedType, builder.getZeroAttr(resultExpandedType)); - for (int64_t i = 0; i < maxLinearIndex; ++i) + for (int64_t i = 0; i < maxIndex; ++i) result = builder.create(results[i], result, delinearize(strides, i)); diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -12,6 +12,54 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include <numeric> + +using namespace mlir; + +SmallVector<int64_t> mlir::computeStrides(ArrayRef<int64_t> sizes) { + SmallVector<int64_t> strides(sizes.size(), 1); + for (int64_t r = strides.size() - 2; r >= 0; --r) + strides[r] = strides[r + 1] * sizes[r + 1]; + return strides; +} + +SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1, + ArrayRef<int64_t> v2) { + SmallVector<int64_t> result; + for (auto it : llvm::zip(v1, v2)) + result.push_back(std::get<0>(it) * std::get<1>(it)); + return result; +} + +Optional<SmallVector<int64_t>> +mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) { + if (shape.size() < subShape.size()) + return None; + assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) && + "shape must be nonnegative"); + assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) && + "subShape must be nonnegative"); + + // Starting from the end, compute the integer divisors. + std::vector<int64_t> result; + result.reserve(shape.size()); + for (auto [superSize, subSize] : + llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) { + // If integral division does not occur, return and let the caller decide. + if (superSize % subSize != 0) + return None; + result.push_back(superSize / subSize); + } + // At this point we computed the ratio (in reverse) for the common + // size. Fill with the remaining entries from the super-vector shape (still in + // reverse). + int commonSize = subShape.size(); + std::copy(shape.rbegin() + commonSize, shape.rend(), + std::back_inserter(result)); + // Reverse again to get it back in the proper order and return.
+ return SmallVector{result.rbegin(), result.rend()}; +} + int64_t mlir::linearize(ArrayRef offsets, ArrayRef basis) { assert(offsets.size() == basis.size()); int64_t linearIndex = 0; @@ -20,10 +68,10 @@ return linearIndex; } -llvm::SmallVector mlir::delinearize(ArrayRef sliceStrides, - int64_t index) { +llvm::SmallVector mlir::delinearize(ArrayRef sliceStrides, + int64_t index) { int64_t rank = sliceStrides.size(); - SmallVector vectorOffsets(rank); + SmallVector vectorOffsets(rank); for (int64_t r = 0; r < rank; ++r) { assert(sliceStrides[r] > 0); vectorOffsets[r] = index / sliceStrides[r]; @@ -32,12 +80,19 @@ return vectorOffsets; } -llvm::SmallVector mlir::getI64SubArray(ArrayAttr arrayAttr, - unsigned dropFront, - unsigned dropBack) { +int64_t mlir::computeMaxLinearIndex(ArrayRef basis) { + if (basis.empty()) + return 0; + return std::accumulate(basis.begin(), basis.end(), 1, + std::multiplies()); +} + +llvm::SmallVector mlir::getI64SubArray(ArrayAttr arrayAttr, + unsigned dropFront, + unsigned dropBack) { assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); auto range = arrayAttr.getAsRange(); - SmallVector res; + SmallVector res; res.reserve(arrayAttr.size() - dropFront - dropBack); for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; it != eit; ++it) @@ -54,11 +109,11 @@ return resultExpr; } -llvm::SmallVector +llvm::SmallVector mlir::getDelinearizedAffineExpr(mlir::ArrayRef strides, Builder &b) { AffineExpr resultExpr = b.getAffineDimExpr(0); int64_t rank = strides.size(); - SmallVector vectorOffsets(rank); + SmallVector vectorOffsets(rank); vectorOffsets[0] = resultExpr.floorDiv(strides[0]); resultExpr = resultExpr % strides[0]; for (unsigned i = 1; i < rank; i++) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -54,9 +54,9 @@ } // Helper to construct iterator types with one index removed. -static SmallVector adjustIter(ArrayAttr iteratorTypes, - int64_t index) { - SmallVector results; +static SmallVector adjustIter(ArrayAttr iteratorTypes, + int64_t index) { + SmallVector results; for (const auto &it : llvm::enumerate(iteratorTypes)) { int64_t idx = it.index(); if (idx == index) @@ -70,7 +70,7 @@ static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter) { auto *ctx = rewriter.getContext(); - SmallVector results; + SmallVector results; for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { int64_t idx = map.getDimPosition(i); if (idx == index) @@ -140,7 +140,7 @@ } template -static SmallVector extractVector(ArrayAttr arrayAttr) { +static SmallVector extractVector(ArrayAttr arrayAttr) { return llvm::to_vector<4>(llvm::map_range( arrayAttr.getAsRange(), [](IntegerAttr attr) { return static_cast(attr.getInt()); })); @@ -399,7 +399,7 @@ VectorType resType = op.getResultType(); // Set up convenience transposition table. - SmallVector transp; + SmallVector transp; for (auto attr : op.getTransp()) transp.push_back(attr.cast().getInt()); @@ -430,12 +430,11 @@ // in vector form to improve performance. Therefore, we prune those // dimensions from the shape/transpose data structures used to generate the // extract/insert ops. 
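// ---------------------------------------------------------------------------
// Illustrative standalone sketch (not part of this patch): the traversal used
// by the transpose lowering above -- walk every element through a linear
// index, delinearize it into source indices, and permute those indices to get
// the destination position. Shown here for a plain 2-D row-major buffer; all
// names are local to the sketch.
#include <cassert>
#include <cstdint>
#include <vector>

// Transpose a rows x cols row-major buffer into a cols x rows one.
static std::vector<int64_t> transpose2D(const std::vector<int64_t> &src,
                                        int64_t rows, int64_t cols) {
  std::vector<int64_t> dst(src.size());
  for (int64_t linearIdx = 0; linearIdx < rows * cols; ++linearIdx) {
    // Delinearize with the source strides {cols, 1}.
    int64_t i = linearIdx / cols;
    int64_t j = linearIdx % cols;
    // The permuted (swapped) indices select the destination slot.
    dst[j * rows + i] = src[linearIdx];
  }
  return dst;
}

int main() {
  // 2x3 input {{0, 1, 2}, {3, 4, 5}} becomes 3x2 {{0, 3}, {1, 4}, {2, 5}}.
  std::vector<int64_t> src = {0, 1, 2, 3, 4, 5};
  std::vector<int64_t> dst = transpose2D(src, /*rows=*/2, /*cols=*/3);
  assert((dst == std::vector<int64_t>{0, 3, 1, 4, 2, 5}));
  return 0;
}
// ---------------------------------------------------------------------------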
- SmallVector prunedTransp; + SmallVector prunedTransp; pruneNonTransposedDims(transp, prunedTransp); size_t numPrunedDims = transp.size() - prunedTransp.size(); auto prunedInShape = inputType.getShape().drop_back(numPrunedDims); - SmallVector ones(prunedInShape.size(), 1); - auto prunedInStrides = computeStrides(prunedInShape, ones); + auto prunedInStrides = computeStrides(prunedInShape); // Generates the extract/insert operations for every scalar/vector element // of the leftmost transposed dimensions. We traverse every transpose @@ -448,7 +447,7 @@ for (int64_t linearIdx = 0; linearIdx < numTransposedElements; ++linearIdx) { auto extractIdxs = delinearize(prunedInStrides, linearIdx); - SmallVector insertIdxs(extractIdxs); + SmallVector insertIdxs(extractIdxs); applyPermutationToVector(insertIdxs, prunedTransp); Value extractOp = rewriter.create(loc, input, extractIdxs); @@ -488,7 +487,7 @@ if (srcType.getRank() != 2) return rewriter.notifyMatchFailure(op, "Not a 2D transpose"); - SmallVector transp; + SmallVector transp; for (auto attr : op.getTransp()) transp.push_back(attr.cast().getInt()); if (transp[0] != 1 && transp[1] != 0) @@ -685,8 +684,8 @@ bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex(); newLhs = rewriter.create(loc, newLhs, lhsTranspose); newRhs = rewriter.create(loc, newRhs, rhsTranspose); - SmallVector lhsOffsets(lhsReductionDims.size(), 0); - SmallVector rhsOffsets(rhsReductionDims.size(), 0); + SmallVector lhsOffsets(lhsReductionDims.size(), 0); + SmallVector rhsOffsets(rhsReductionDims.size(), 0); newLhs = rewriter.create( loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets)); newRhs = rewriter.create( @@ -752,7 +751,7 @@ if (rank == 1) { // Express constant 1-D case in explicit vector form: // [T,..,T,F,..,F]. - SmallVector values(dstType.getDimSize(0)); + SmallVector values(dstType.getDimSize(0)); for (int64_t d = 0; d < trueDim; d++) values[d] = true; rewriter.replaceOpWithNewOp( @@ -762,7 +761,7 @@ VectorType lowType = VectorType::get(dstType.getShape().drop_front(), eltType); - SmallVector newDimSizes; + SmallVector newDimSizes; for (int64_t r = 1; r < rank; r++) newDimSizes.push_back(dimSizes[r].cast().getInt()); Value trueVal = rewriter.create( @@ -931,8 +930,8 @@ // x[0,1,0] = y[0,2] // etc., incrementing the two index vectors "row-major" // within the source and result shape. 
- SmallVector srcIdx(srcRank); - SmallVector resIdx(resRank); + SmallVector srcIdx(srcRank); + SmallVector resIdx(resRank); Value result = rewriter.create( loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); for (int64_t i = 0; i < numElts; i++) { @@ -948,7 +947,7 @@ } private: - static void incIdx(SmallVector &idx, VectorType tp, int64_t r) { + static void incIdx(SmallVector &idx, VectorType tp, int64_t r) { assert(0 <= r && r < tp.getRank()); if (++idx[r] == tp.getDimSize(r)) { idx[r] = 0; @@ -1039,7 +1038,7 @@ LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { - SmallVector maps = + SmallVector maps = llvm::to_vector<4>(contractOp.getIndexingMapsArray()); Value lhs = contractOp.getLhs(); Value rhs = contractOp.getRhs(); @@ -1169,7 +1168,7 @@ LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { - SmallVector maps = + SmallVector maps = llvm::to_vector<4>(contractOp.getIndexingMapsArray()); Value lhs = contractOp.getLhs(); Value rhs = contractOp.getRhs(); @@ -1234,7 +1233,7 @@ for (auto &m : maps) m = compressDims(m, unusedDimsBitVector); // Compute the combined iterators. - SmallVector iterators; + SmallVector iterators; for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) { if (!unusedDimsBitVector.test(i)) iterators.push_back(contractOp.getIteratorTypes().getValue()[i]); @@ -1328,7 +1327,7 @@ // Make sure all operands are transpose/constant ops and collect their // transposition maps. - SmallVector transposeMaps; + SmallVector transposeMaps; transposeMaps.reserve(op->getNumOperands()); // Record the initial type before transposition. We'll use its shape later. // Any type will do here as we will check all transpose maps are the same. @@ -1350,7 +1349,7 @@ if (!llvm::all_equal(transposeMaps)) return rewriter.notifyMatchFailure(op, "different transpose map"); - SmallVector srcValues; + SmallVector srcValues; srcValues.reserve(op->getNumOperands()); // If there are constant operands, we need to insert inverse transposes for @@ -1724,7 +1723,7 @@ auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; AffineExpr m, n, k; bindDims(rewriter.getContext(), m, n, k); - SmallVector maps = op.getIndexingMapsArray(); + SmallVector maps = op.getIndexingMapsArray(); // // In the following we wish to make the reduction dimension innermost so we // can load vectors and just fmul + reduce into a scalar. @@ -1940,7 +1939,7 @@ VectorType rhsType = op.getRhsType(); VectorType resType = op.getResultType().cast(); // Find the iterator type index and result index. - SmallVector iMap = op.getIndexingMapsArray(); + SmallVector iMap = op.getIndexingMapsArray(); int64_t iterIndex = -1; int64_t dimSize = -1; if (lhsIndex >= 0) { @@ -2011,7 +2010,7 @@ bool isInt = resType.isa(); // Use iterator index 0. int64_t iterIndex = 0; - SmallVector iMap = op.getIndexingMapsArray(); + SmallVector iMap = op.getIndexingMapsArray(); Optional lookupLhs = getResultIndex(iMap[0], iterIndex); Optional lookupRhs = getResultIndex(iMap[1], iterIndex); if (!lookupLhs.has_value()) @@ -2087,7 +2086,7 @@ if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) return failure(); - SmallVector broadcastedDims; + SmallVector broadcastedDims; // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. // We let the 0-d corner case pass-through as it is supported. 
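// ---------------------------------------------------------------------------
// Illustrative standalone sketch (not part of this patch): the trailing-
// dimension rescaling performed above when a strided-slice extract/insert is
// moved across a vector.bitcast that changes the number of elements by an
// integral ratio. Offsets and sizes expressed in the finer element
// granularity translate to the coarser one only when they are divisible by
// that ratio; otherwise the rewrite bails. Names are local to the sketch.
#include <cassert>
#include <cstdint>
#include <optional>

// Rescale a trailing offset (or size) by `ratio`, or fail if not divisible.
static std::optional<int64_t> rescaleTrailing(int64_t value, int64_t ratio) {
  if (value % ratio != 0)
    return std::nullopt;
  return value / ratio;
}

int main() {
  // E.g. 4 narrow elements per wide element: offset 12 becomes offset 3.
  assert(rescaleTrailing(12, 4) == 3);
  // Offset 10 is not aligned to the wide elements, so the pattern must bail.
  assert(!rescaleTrailing(10, 4));
  return 0;
}
// ---------------------------------------------------------------------------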
@@ -2106,8 +2105,8 @@ // If there is broadcasting involved then we first load the unbroadcasted // vector, and then broadcast it with `vector.broadcast`. ArrayRef vectorShape = read.getVectorType().getShape(); - SmallVector unbroadcastedVectorShape(vectorShape.begin(), - vectorShape.end()); + SmallVector unbroadcastedVectorShape(vectorShape.begin(), + vectorShape.end()); for (unsigned i : broadcastedDims) unbroadcastedVectorShape[i] = 1; VectorType unbroadcastedVectorType = VectorType::get( @@ -2286,7 +2285,7 @@ }; // Returns the values in `arrayAttr` as an integer vector. -static SmallVector getIntValueVector(ArrayAttr arrayAttr) { +static SmallVector getIntValueVector(ArrayAttr arrayAttr) { return llvm::to_vector<4>( llvm::map_range(arrayAttr.getAsRange(), [](IntegerAttr attr) { return attr.getInt(); })); @@ -2410,7 +2409,7 @@ // dimension's offset given we are extracting from less elements now. ArrayAttr newOffsets = extractOp.getOffsets(); if (newOffsets.size() == rank) { - SmallVector offsets = getIntValueVector(newOffsets); + SmallVector offsets = getIntValueVector(newOffsets); if (offsets.back() % expandRatio != 0) return failure(); offsets.back() = offsets.back() / expandRatio; @@ -2420,14 +2419,14 @@ // Similarly for sizes. ArrayAttr newSizes = extractOp.getSizes(); if (newSizes.size() == rank) { - SmallVector sizes = getIntValueVector(newSizes); + SmallVector sizes = getIntValueVector(newSizes); if (sizes.back() % expandRatio != 0) return failure(); sizes.back() = sizes.back() / expandRatio; newSizes = rewriter.getI64ArrayAttr(sizes); } - SmallVector dims = + SmallVector dims = llvm::to_vector<4>(extractOp.getType().cast().getShape()); dims.back() = dims.back() / expandRatio; VectorType newExtractType = @@ -2500,13 +2499,13 @@ ArrayAttr newOffsets = insertOp.getOffsets(); assert(newOffsets.size() == rank); - SmallVector offsets = getIntValueVector(newOffsets); + SmallVector offsets = getIntValueVector(newOffsets); if (offsets.back() % shrinkRatio != 0) return failure(); offsets.back() = offsets.back() / shrinkRatio; newOffsets = rewriter.getI64ArrayAttr(offsets); - SmallVector srcDims = + SmallVector srcDims = llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); srcDims.back() = srcDims.back() / shrinkRatio; VectorType newCastSrcType = @@ -2515,7 +2514,7 @@ auto newCastSrcOp = rewriter.create( bitcastOp.getLoc(), newCastSrcType, insertOp.getSource()); - SmallVector dstDims = + SmallVector dstDims = llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); dstDims.back() = dstDims.back() / shrinkRatio; VectorType newCastDstType = diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -27,24 +27,19 @@ /// During unrolling from `originalShape` to `targetShape` return the offset for /// the slice `index`. 
-static SmallVector getVectorOffset(ArrayRef originalShape, - ArrayRef targetShape, - int64_t index) { - SmallVector dstSliceStrides = - computeStrides(originalShape, targetShape); - SmallVector vectorOffsets = delinearize(dstSliceStrides, index); - SmallVector elementOffsets = - computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets); - return elementOffsets; +static SmallVector getVectorOffset(ArrayRef ratioStrides, + int64_t index, + ArrayRef targetShape) { + return computeElementwiseMul(delinearize(ratioStrides, index), targetShape); } -/// A functor that accomplishes the same thing as `getVectorOffset` but allows -/// for reordering the traversal of the dimensions. The order of traversal is -/// given in "for loop order" (outer to inner). +/// A functor that accomplishes the same thing as `getVectorOffset` but +/// allows for reordering the traversal of the dimensions. The order of +/// traversal is given in "for loop order" (outer to inner). namespace { class DecomposeShapeIterator { private: - SmallVector vectorShape; + SmallVector vectorShape; SmallVector loopOrder; SmallVector sliceStrides; int64_t maxIndexVal{1}; @@ -56,15 +51,15 @@ : vectorShape(targetShape.begin(), targetShape.end()), loopOrder(loopOrder.begin(), loopOrder.end()), sliceStrides(originalShape.size()) { - assert(originalShape.size() == targetShape.size()); - assert(loopOrder.size() == targetShape.size()); + assert(originalShape.size() >= targetShape.size()); + assert(loopOrder.size() == originalShape.size()); // Compute the count for each dimension. - SmallVector sliceDimCounts(originalShape.size()); - for (unsigned r = 0; r < originalShape.size(); ++r) { - sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]); - maxIndexVal *= sliceDimCounts[r]; - } + auto maybeShapeRatio = computeShapeRatio(originalShape, targetShape); + assert(maybeShapeRatio && "Shape does not evenly divide"); + // Pad `sliceDimCounts` with leading 1s so that all sizes match. + SmallVector sliceDimCounts = *maybeShapeRatio; + maxIndexVal = computeMaxLinearIndex(sliceDimCounts); // Reversing "loop order" gives dimensions from fastest varying to slowest // varying (smallest stride to largest stride). @@ -95,7 +90,7 @@ SmallVector getVectorOffset(int64_t index) const { SmallVector vectorOffsets = delinearize(index); SmallVector elementOffsets = - computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets); + computeElementwiseMul(vectorShape, vectorOffsets); return elementOffsets; } }; @@ -139,7 +134,7 @@ /// Return the target shape for unrolling for the given `op`. Return llvm::None /// if the op shouldn't be or cannot be unrolled. 
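// ---------------------------------------------------------------------------
// Illustrative standalone sketch (not part of this patch): the reworked
// getVectorOffset recipe above. The shape ratio of originalShape to
// targetShape gives the tile grid; the strides of that grid delinearize a
// slice index into grid coordinates; multiplying elementwise by targetShape
// yields the element offsets of that slice. Plain std:: types, names local
// to the sketch.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> tileOffsets(const std::vector<int64_t> &original,
                                        const std::vector<int64_t> &target,
                                        int64_t sliceIndex) {
  size_t rank = original.size();
  // Tile grid (assumes target divides original evenly) and its strides.
  std::vector<int64_t> ratio(rank), strides(rank, 1);
  for (size_t r = 0; r < rank; ++r)
    ratio[r] = original[r] / target[r];
  for (int64_t r = static_cast<int64_t>(rank) - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * ratio[r + 1];
  // Delinearize the slice index into grid coordinates, then scale.
  std::vector<int64_t> offsets(rank);
  for (size_t r = 0; r < rank; ++r) {
    offsets[r] = (sliceIndex / strides[r]) * target[r];
    sliceIndex %= strides[r];
  }
  return offsets;
}

int main() {
  // Unrolling a 4x6 vector into 2x3 tiles yields a 2x2 grid of slices:
  // slice 0 starts at {0, 0}, slice 1 at {0, 3}, slice 2 at {2, 0},
  // slice 3 at {2, 3}.
  assert((tileOffsets({4, 6}, {2, 3}, 0) == std::vector<int64_t>{0, 0}));
  assert((tileOffsets({4, 6}, {2, 3}, 3) == std::vector<int64_t>{2, 3}));
  return 0;
}
// ---------------------------------------------------------------------------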
-static Optional> +static Optional> getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { if (options.filterConstraint && failed(options.filterConstraint(op))) return llvm::None; @@ -152,10 +147,10 @@ auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); if (!maybeUnrollShape) return llvm::None; - Optional> targetShape = options.nativeShape(op); + Optional> targetShape = options.nativeShape(op); if (!targetShape) return llvm::None; - auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape); + auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape); if (!maybeShapeRatio || llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) return llvm::None; @@ -197,7 +192,7 @@ if (!targetShape) return failure(); auto sourceVectorType = readOp.getVectorType(); - SmallVector strides(targetShape->size(), 1); + SmallVector strides(targetShape->size(), 1); Location loc = readOp.getLoc(); ArrayRef originalSize = readOp.getVectorType().getShape(); @@ -206,17 +201,16 @@ loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); auto targetType = VectorType::get(*targetShape, sourceVectorType.getElementType()); - SmallVector originalIndices(readOp.getIndices().begin(), - readOp.getIndices().end()); + SmallVector originalIndices(readOp.getIndices().begin(), + readOp.getIndices().end()); SmallVector loopOrder = getUnrollOrder(originalSize.size(), readOp, options); DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, loopOrder); for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) { - SmallVector elementOffsets = - indexToOffsets.getVectorOffset(i); - SmallVector indices = + SmallVector elementOffsets = indexToOffsets.getVectorOffset(i); + SmallVector indices = sliceTransferIndices(elementOffsets, originalIndices, readOp.getPermutationMap(), loc, rewriter); auto slicedRead = rewriter.create( @@ -255,11 +249,11 @@ if (!targetShape) return failure(); auto sourceVectorType = writeOp.getVectorType(); - SmallVector strides(targetShape->size(), 1); + SmallVector strides(targetShape->size(), 1); Location loc = writeOp.getLoc(); ArrayRef originalSize = sourceVectorType.getShape(); - SmallVector originalIndices(writeOp.getIndices().begin(), - writeOp.getIndices().end()); + SmallVector originalIndices(writeOp.getIndices().begin(), + writeOp.getIndices().end()); SmallVector loopOrder = getUnrollOrder(originalSize.size(), writeOp, options); @@ -267,11 +261,10 @@ loopOrder); Value resultTensor; for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) { - SmallVector elementOffsets = - indexToOffsets.getVectorOffset(i); + SmallVector elementOffsets = indexToOffsets.getVectorOffset(i); Value slicedVector = rewriter.create( loc, writeOp.getVector(), elementOffsets, *targetShape, strides); - SmallVector indices = + SmallVector indices = sliceTransferIndices(elementOffsets, originalIndices, writeOp.getPermutationMap(), loc, rewriter); Operation *slicedWrite = rewriter.create( @@ -321,7 +314,7 @@ if (!targetShape) return failure(); auto dstVecType = contractOp.getResultType().cast(); - SmallVector originalSize = *contractOp.getShapeForUnroll(); + SmallVector originalSize = *contractOp.getShapeForUnroll(); Location loc = contractOp.getLoc(); unsigned accIndex = vector::ContractionOp::getAccOperandIndex(); @@ -337,16 +330,16 @@ loopOrder); const int64_t sliceCount = indexToOffsets.maxIndex(); for (int64_t i = 0; i < sliceCount; i++) { - SmallVector offsets = indexToOffsets.getVectorOffset(i); - SmallVector 
slicesOperands(contractOp.getNumOperands()); + SmallVector offsets = indexToOffsets.getVectorOffset(i); + SmallVector slicesOperands(contractOp.getNumOperands()); - // Helper to coompute the new shape of each operand and extract the slice. + // Helper to compute the new shape of each operand and extract the slice. auto extractOperand = [&](unsigned index, Value operand, AffineMap permutationMap, ArrayRef operandOffets) { SmallVector operandShape = applyPermutationMap( permutationMap, ArrayRef(*targetShape)); - SmallVector operandStrides(operandOffets.size(), 1); + SmallVector operandStrides(operandOffets.size(), 1); slicesOperands[index] = rewriter.create( loc, operand, operandOffets, operandShape, operandStrides); }; @@ -420,12 +413,12 @@ LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { - Optional> targetShape = + Optional> targetShape = getTargetShape(options, reductionOp); if (!targetShape) return failure(); - SmallVector originalSize = *reductionOp.getShapeForUnroll(); - SmallVector ratio = *shapeRatio(originalSize, *targetShape); + SmallVector originalSize = *reductionOp.getShapeForUnroll(); + SmallVector ratio = *computeShapeRatio(originalSize, *targetShape); llvm::MapVector< SmallVector, Value, llvm::DenseMap, unsigned, OffsetMapInfo>> @@ -433,12 +426,16 @@ // Compute shape ratio of 'shape' and 'sizes'. int64_t sliceCount = computeMaxLinearIndex(ratio); Location loc = reductionOp.getLoc(); + + // Stride of the ratios, this gives us the offsets of sliceCount in a basis + // of multiples of the targetShape. + auto ratioStrides = computeStrides(ratio); for (int64_t i = 0; i < sliceCount; i++) { - SmallVector offsets = - getVectorOffset(originalSize, *targetShape, i); + SmallVector offsets = + getVectorOffset(ratioStrides, i, *targetShape); SmallVector operands; - SmallVector operandStrides(offsets.size(), 1); + SmallVector operandStrides(offsets.size(), 1); Value slicedOperand = rewriter.create( loc, reductionOp.getSource(), offsets, *targetShape, operandStrides); operands.push_back(slicedOperand); @@ -451,7 +448,7 @@ } } Value acc; - SmallVector accStrides(destOffset.size(), 1); + SmallVector accStrides(destOffset.size(), 1); // If a version of the accumulator has already been computed, use it // otherwise extract the first version from the original operand. auto accIt = accCache.find(destOffset); @@ -500,21 +497,25 @@ if (!targetShape) return failure(); auto dstVecType = op->getResult(0).getType().cast(); - SmallVector originalSize = + SmallVector originalSize = *cast(op).getShapeForUnroll(); - SmallVector ratio = *shapeRatio(originalSize, *targetShape); + SmallVector ratio = *computeShapeRatio(originalSize, *targetShape); int64_t sliceCount = computeMaxLinearIndex(ratio); Location loc = op->getLoc(); // Prepare the result vector. Value result = rewriter.create( loc, dstVecType, rewriter.getZeroAttr(dstVecType)); - SmallVector strides(targetShape->size(), 1); + SmallVector strides(targetShape->size(), 1); VectorType newVecType = VectorType::get(*targetShape, dstVecType.getElementType()); + + // Stride of the ratios, this gives us the offsets of sliceCount in a basis + // of multiples of the targetShape. 
+ auto ratioStrides = computeStrides(ratio); for (int64_t i = 0; i < sliceCount; i++) { - SmallVector offsets = - getVectorOffset(originalSize, *targetShape, i); - SmallVector extractOperands; + SmallVector offsets = + getVectorOffset(ratioStrides, i, *targetShape); + SmallVector extractOperands; for (OpOperand &operand : op->getOpOperands()) { auto vecType = operand.get().getType().template dyn_cast(); if (!vecType) { @@ -547,19 +548,24 @@ LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, PatternRewriter &rewriter) const override { - Optional> targetShape = + Optional> targetShape = getTargetShape(options, reductionOp); if (!targetShape) return failure(); SmallVector originalSize = *reductionOp.getShapeForUnroll(); - int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0]; + auto ratio = *computeShapeRatio(originalSize, *targetShape); + int64_t sliceCount = ratio[0]; // Create unrolled vector reduction. Location loc = reductionOp.getLoc(); Value accumulator = nullptr; - for (int64_t i = 0; i < ratio; ++i) { + + // Stride of the ratios, this gives us the offsets of sliceCount in a basis + // of multiples of the targetShape. + auto ratioStrides = computeStrides(ratio); + for (int64_t i = 0; i < sliceCount; ++i) { SmallVector offsets = - getVectorOffset(originalSize, *targetShape, i); + getVectorOffset(ratioStrides, i, *targetShape); SmallVector strides(offsets.size(), 1); Value slicedOperand = rewriter.create( loc, reductionOp.getVector(), offsets, *targetShape, strides); @@ -600,21 +606,25 @@ if (!targetShape) return failure(); auto originalVectorType = tranposeOp.getResultType(); - SmallVector strides(targetShape->size(), 1); + SmallVector strides(targetShape->size(), 1); Location loc = tranposeOp.getLoc(); ArrayRef originalSize = originalVectorType.getShape(); - SmallVector ratio = *shapeRatio(originalSize, *targetShape); + SmallVector ratio = *computeShapeRatio(originalSize, *targetShape); int64_t sliceCount = computeMaxLinearIndex(ratio); // Prepare the result vector; Value result = rewriter.create( loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); SmallVector permutation; tranposeOp.getTransp(permutation); + + // Stride of the ratios, this gives us the offsets of sliceCount in a basis + // of multiples of the targetShape. + auto ratioStrides = computeStrides(ratio); for (int64_t i = 0; i < sliceCount; i++) { - SmallVector elementOffsets = - getVectorOffset(originalSize, *targetShape, i); - SmallVector permutedOffsets(elementOffsets.size()); - SmallVector permutedShape(elementOffsets.size()); + SmallVector elementOffsets = + getVectorOffset(ratioStrides, i, *targetShape); + SmallVector permutedOffsets(elementOffsets.size()); + SmallVector permutedShape(elementOffsets.size()); // Compute the source offsets and shape. 
for (auto &indices : llvm::enumerate(permutation)) { permutedOffsets[indices.value()] = elementOffsets[indices.index()]; diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IntegerSet.h" @@ -25,7 +26,6 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/MathExtras.h" -#include #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" @@ -43,78 +43,6 @@ llvm_unreachable("Expected MemRefType or TensorType"); } -/// Return the number of elements of basis, `0` if empty. -int64_t mlir::computeMaxLinearIndex(ArrayRef basis) { - if (basis.empty()) - return 0; - return std::accumulate(basis.begin(), basis.end(), 1, - std::multiplies()); -} - -SmallVector mlir::computeStrides(ArrayRef shape, - ArrayRef sizes) { - int64_t rank = shape.size(); - // Compute the count for each dimension. - SmallVector sliceDimCounts(rank); - for (int64_t r = 0; r < rank; ++r) - sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]); - // Use that to compute the slice stride for each dimension. - SmallVector sliceStrides(rank); - sliceStrides[rank - 1] = 1; - for (int64_t r = rank - 2; r >= 0; --r) - sliceStrides[r] = sliceStrides[r + 1] * sliceDimCounts[r + 1]; - return sliceStrides; -} - -SmallVector mlir::computeElementOffsetsFromVectorSliceOffsets( - ArrayRef sizes, ArrayRef vectorOffsets) { - SmallVector result; - for (auto it : llvm::zip(vectorOffsets, sizes)) - result.push_back(std::get<0>(it) * std::get<1>(it)); - return result; -} - -Optional> mlir::shapeRatio(ArrayRef superShape, - ArrayRef subShape) { - if (superShape.size() < subShape.size()) { - return None; - } - - // Starting from the end, compute the integer divisors. - std::vector result; - result.reserve(superShape.size()); - for (auto [superSize, subSize] : - llvm::zip(llvm::reverse(superShape), llvm::reverse(subShape))) { - assert(superSize > 0 && "superSize must be > 0"); - assert(subSize > 0 && "subSize must be > 0"); - - // If integral division does not occur, return and let the caller decide. - if (superSize % subSize != 0) - return None; - result.push_back(superSize / subSize); - } - - // At this point we computed the ratio (in reverse) for the common - // size. Fill with the remaining entries from the super-vector shape (still in - // reverse). - int commonSize = subShape.size(); - std::copy(superShape.rbegin() + commonSize, superShape.rend(), - std::back_inserter(result)); - - assert(result.size() == superShape.size() && - "super to sub shape ratio is not of the same size as the super rank"); - - // Reverse again to get it back in the proper order and return. - return SmallVector{result.rbegin(), result.rend()}; -} - -Optional> mlir::shapeRatio(VectorType superVectorType, - VectorType subVectorType) { - assert(superVectorType.getElementType() == subVectorType.getElementType() && - "vector types must be of the same elemental type"); - return shapeRatio(superVectorType.getShape(), subVectorType.getShape()); -} - /// Constructs a permutation map from memref indices to vector dimension. 
/// /// The implementation uses the knowledge of the mapping of enclosing loop to @@ -144,8 +72,8 @@ return AffineMap(); MLIRContext *context = enclosingLoopToVectorDim.begin()->getFirst()->getContext(); - SmallVector perm(enclosingLoopToVectorDim.size(), - getAffineConstantExpr(0, context)); + SmallVector perm(enclosingLoopToVectorDim.size(), + getAffineConstantExpr(0, context)); for (auto kvp : enclosingLoopToVectorDim) { assert(kvp.second < perm.size()); @@ -252,7 +180,8 @@ } // Get the ratio. - auto ratio = shapeRatio(superVectorType, subVectorType); + auto ratio = + computeShapeRatio(superVectorType.getShape(), subVectorType.getShape()); // Sanity check. assert((ratio || !mustDivide) && diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp --- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp +++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" @@ -126,7 +127,8 @@ // purpose of this test. If we need to test more intricate behavior in the // future we can always extend. auto superVectorType = opInst->getResult(0).getType().cast(); - auto ratio = shapeRatio(superVectorType, subVectorType); + auto ratio = + computeShapeRatio(superVectorType.getShape(), subVectorType.getShape()); if (!ratio) { opInst->emitRemark("NOT MATCHED"); } else { diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -72,11 +72,11 @@ private: // Return the target shape based on op type. - static Optional> getShape(Operation *op) { + static Optional> getShape(Operation *op) { if (isa(op)) - return SmallVector(2, 2); + return SmallVector(2, 2); if (isa(op)) - return SmallVector(3, 2); + return SmallVector(3, 2); // For transfer ops, just propagate the shape coming from // InsertStridedSlices/ExtractStridedSlices. if (auto readOp = dyn_cast(op)) { @@ -90,15 +90,15 @@ return llvm::None; dstVec = vecType; } - return SmallVector(dstVec.getShape().begin(), - dstVec.getShape().end()); + return SmallVector(dstVec.getShape().begin(), + dstVec.getShape().end()); } if (auto writeOp = dyn_cast(op)) { auto insert = writeOp.getVector().getDefiningOp(); if (!insert) return llvm::None; ArrayRef shape = insert.getSourceVectorType().getShape(); - return SmallVector(shape.begin(), shape.end()); + return SmallVector(shape.begin(), shape.end()); } return llvm::None; } @@ -314,10 +314,10 @@ if (unrollBasedOnType) { UnrollVectorOptions::NativeShapeFnType nativeShapeFn = - [](Operation *op) -> Optional> { + [](Operation *op) -> Optional> { vector::ContractionOp contractOp = cast(op); - SmallVector nativeShape( - contractOp.getIteratorTypes().size(), 4); + SmallVector nativeShape(contractOp.getIteratorTypes().size(), + 4); Type lhsType = contractOp.getLhsType().getElementType(); nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 
4 : 2; return nativeShape; @@ -339,12 +339,11 @@ } populateVectorUnrollPatterns(patterns, opts); } else { - auto nativeShapeFn = - [](Operation *op) -> Optional> { + auto nativeShapeFn = [](Operation *op) -> Optional> { auto contractOp = dyn_cast(op); if (!contractOp) return None; - return SmallVector(contractOp.getIteratorTypes().size(), 2); + return SmallVector(contractOp.getIteratorTypes().size(), 2); }; populateVectorUnrollPatterns(patterns, UnrollVectorOptions() diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -61,7 +61,6 @@ LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const; void replaceLinalgTransformationFilter(PatternRewriter &rewriter, Operation *op) const; - bool hasReplacementFilter(Operation *op) const; LinalgTransformationFilter &addFilter(const FilterFunction &f) { if (f) @@ -100,15 +99,6 @@ : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), replacement(replacement), matchByDefault(false) {} -LinalgTransformationFilter::LinalgTransformationFilter( - const FilterFunction &f, ArrayRef matchDisjunction, - Optional replacement) - : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), - replacement(replacement), matchByDefault(false) { - if (f) - filters.push_back(f); -} - LogicalResult LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter, Operation *op) const { @@ -150,13 +140,6 @@ op->removeAttr(rewriter.getStringAttr(kLinalgTransformMarker)); } -bool LinalgTransformationFilter::hasReplacementFilter(Operation *op) const { - if (!replacement) - return false; - auto attr = op->getAttr(kLinalgTransformMarker).dyn_cast(); - return attr && attr == *replacement; -} - /// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using /// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while /// using a `filter` to avoid recursive application.
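// ---------------------------------------------------------------------------
// Illustrative standalone sketch (not part of this patch): the index scatter
// used when unrolling vector.transpose earlier in this patch, i.e.
// permuted[permutation[j]] = values[j] for every position j. Plain std::
// types; names are local to the sketch.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> scatterByPermutation(
    const std::vector<int64_t> &values,
    const std::vector<int64_t> &permutation) {
  std::vector<int64_t> permuted(values.size());
  for (size_t j = 0; j < permutation.size(); ++j)
    permuted[permutation[j]] = values[j];
  return permuted;
}

int main() {
  // With permutation {1, 0, 2}, offsets {2, 3, 0} are scattered to {3, 2, 0}.
  assert((scatterByPermutation({2, 3, 0}, {1, 0, 2}) ==
          std::vector<int64_t>{3, 2, 0}));
  return 0;
}
// ---------------------------------------------------------------------------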