diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -29,14 +29,14 @@ }]; } -def ApplyRankReducingSubviewPatternsOp : +def ApplyRankReducingSubviewPatternsOp : TransformWithPatternsOp<"vector.apply_rank_reducing_subview_patterns"> { let description = [{ Apply opt-in vector transfer permutation patterns that include: - TransferReadDropUnitDimsPattern - TransferWriteDropUnitDimsPattern - - These patterns have the effect of rewriting a vector.transfer with unit + + These patterns have the effect of rewriting a vector.transfer with unit dimensions into a rank-reduced version thanks to subview operations. This is complemented by shape_cast folding patterns. }]; @@ -51,7 +51,7 @@ }]; } -def ApplyTransferPermutationPatternsOp : +def ApplyTransferPermutationPatternsOp : TransformWithPatternsOp<"vector.apply_transfer_permutation_patterns"> { let description = [{ Apply opt-in vector transfer permutation patterns that include: @@ -59,8 +59,8 @@ - TransferWritePermutationLowering - TransferOpReduceRank - TransferWriteNonPermutationLowering - - These patterns have the effect of rewriting a vector.transfer with an + + These patterns have the effect of rewriting a vector.transfer with an arbitrary permutation_map to a vector.transfer with a permutation_map that is a minor identity followed by a vector.transpose. @@ -100,7 +100,7 @@ // TODO: evolve lowering_strategy to proper enums. def LowerContractionOp : TransformWithPatternsOp<"vector.lower_contraction"> { let description = [{ - Indicates that the vector contraction-like operations nested under the + Indicates that the vector contraction-like operations nested under the isolated from above op `target` should be lowered to finer-grained vector primitives. 
@@ -164,7 +164,7 @@ def MaterializeMasksOp : TransformWithPatternsOp<"vector.materialize_masks"> { let description = [{ - Indicates that mask operations nested under the isolated from above op + Indicates that mask operations nested under the isolated from above op `target` should be lowered to fine-grained arithemtic operations. This is usually the last step that is run after bufferization as part of the @@ -185,7 +185,7 @@ def LowerMultiReductionOp : TransformWithPatternsOp<"vector.lower_multi_reduction"> { let description = [{ - Indicates that the vector multi_reduction-like operations nested under the + Indicates that the vector multi_reduction-like operations nested under the isolated from above op `target` should be lowered to finer-grained vector primitives. @@ -229,7 +229,7 @@ def LowerShapeCastOp : TransformWithPatternsOp<"vector.lower_shape_cast"> { let description = [{ - Indicates that the vector shape_cast operations nested under the + Indicates that the vector shape_cast operations nested under the isolated from above op `target` should be lowered to finer-grained vector primitives. @@ -249,7 +249,7 @@ def LowerTransferOp : TransformWithPatternsOp<"vector.lower_transfer"> { let description = [{ - Indicates that the vector transfer operations nested under the + Indicates that the vector transfer operations nested under the isolated from above op `target` should be lowered to finer-grained vector primitives. @@ -273,7 +273,7 @@ // TODO: evolve lowering_strategy to proper enums. def LowerTransposeOp : TransformWithPatternsOp<"vector.lower_transpose"> { let description = [{ - Indicates that the vector transpose-like operations nested under the + Indicates that the vector transpose-like operations nested under the isolated from above op `target` should be lowered to finer-grained vector primitives. 
@@ -303,7 +303,7 @@ def SplitTransferFullPartialOp : TransformWithPatternsOp<"vector.split_transfer_full_partial"> { let description = [{ - Indicates that the vector transfer operations nested under the + Indicates that the vector transfer operations nested under the isolated from above op `target` should be split to full and partial parts. This is usually a late step that is run after bufferization as part of the @@ -326,7 +326,7 @@ def TransferToScfOp : TransformWithPatternsOp<"vector.transfer_to_scf"> { let description = [{ - Indicates that the vector transfer operations nested under the + Indicates that the vector transfer operations nested under the isolated from above op `target` should be rewritten with scf.for loops over finer-grained vector primitives. @@ -351,4 +351,31 @@ }]; } +def UnrollOp : Op { + let description = [{ + Unrolls the `target` using the provided `unroll_shape` and `unroll_order`. + The `unroll_shape` size must be less-than-or-equal-to the size of the + target's canonical shape provided by `target.getShapeForUnroll()`. The + `unroll_order` controls the order in which the unrolling occurs by + specifying a permutation on the dimensions (in the order of slowest to + fastest varying).
+ }]; + let arguments = (ins TransformHandleTypeInterface:$target, + DenseI64ArrayAttr:$unroll_shape, + OptionalAttr:$unroll_order); + let results = (outs ); + let assemblyFormat = "$target attr-dict `:` type($target)"; + let extraClassDeclaration = [{ + DiagnosedSilenceableFailure applyToOne( + VectorUnrollOpInterface target, + transform::ApplyToEachResultList &transformResults, + transform::TransformState &state + ); + }]; +} + #endif // VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/Interfaces/VectorInterfaces.h" namespace mlir { class MLIRContext; @@ -106,6 +107,27 @@ /// optimizations. void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp); +/// Encapsulates the result of an unrolling transformation. The `unrolledOps` +/// are the unrolled ops, of the same operation type as the target of the +/// unrolling function. The `replacement` is what should be used to replace the +/// result of the unrolled operation. +struct UnrollResult { + SmallVector unrolledOps; + Value replacement; + UnrollResult(ArrayRef unrolledOps, Value replacement) + : unrolledOps(unrolledOps), replacement(replacement) {} +}; + +/// Unrolls the `target` using the provided target shape and unroll order. The +/// `targetShape` size must be less-than-or-equal-to the size of the target's +/// canonical shape provided by `target.getShapeForUnroll()`. The `unrollOrder` +/// controls the order in which the unrolling occurs by specifying a permutation +/// on the dimensions (in the order of slowest to fastest varying).
+FailureOr unroll(RewriterBase &rewriter, + VectorUnrollOpInterface target, + ArrayRef targetShape, + ArrayRef unrollOrder); + } // namespace vector } // namespace mlir diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -16,6 +16,8 @@ #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; @@ -83,7 +85,8 @@ // MaterializeMasksOp //===----------------------------------------------------------------------===// -void transform::MaterializeMasksOp::populatePatterns(RewritePatternSet &patterns) { +void transform::MaterializeMasksOp::populatePatterns( + RewritePatternSet &patterns) { populateVectorMaskMaterializationPatterns(patterns, /*force32BitVectorIndices=*/false); } @@ -170,6 +173,33 @@ populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); } +//===----------------------------------------------------------------------===// +// UnrollOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::UnrollOp::applyToOne( + VectorUnrollOpInterface target, + transform::ApplyToEachResultList &transformResults, + transform::TransformState &state) { + IRRewriter rewriter(target->getContext()); + rewriter.setInsertionPoint(target); + SmallVector unrollOrder = + getUnrollOrder() + ? 
llvm::to_vector(*getUnrollOrder()) + : llvm::to_vector(llvm::seq(0, getUnrollShape().size())); + FailureOr result = + vector::unroll(rewriter, target, getUnrollShape(), unrollOrder); + if (failed(result)) { + target->emitOpError("failed to apply unrolling transformation"); + return emitDefaultDefiniteFailure(target); + } + if (target->getNumResults() == 0) + rewriter.eraseOp(target); + else + rewriter.replaceOp(target, result->replacement); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -31,7 +32,7 @@ ArrayRef indices, AffineMap permutationMap, Location loc, - OpBuilder &builder) { + RewriterBase &builder) { MLIRContext *ctx = builder.getContext(); auto isBroadcast = [](AffineExpr expr) { if (auto constExpr = expr.dyn_cast()) @@ -39,7 +40,7 @@ return false; }; // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. 
- SmallVector slicedIndices(indices.begin(), indices.end()); + SmallVector slicedIndices(indices.begin(), indices.end()); for (const auto &dim : llvm::enumerate(permutationMap.getResults())) { if (isBroadcast(dim.value())) continue; @@ -47,170 +48,163 @@ auto expr = getAffineDimExpr(0, builder.getContext()) + getAffineConstantExpr(elementOffsets[dim.index()], ctx); auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); - slicedIndices[pos] = - builder.create(loc, map, indices[pos]); + slicedIndices[pos] = affine::makeComposedFoldedAffineApply( + builder, loc, map, OpFoldResult(indices[pos])); } - return slicedIndices; + return getValueOrCreateConstantIndexOp(builder, loc, slicedIndices); } // Clones `op` into a new operations that takes `operands` and returns // `resultTypes`. -static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, - Operation *op, +static Operation *cloneOpWithOperandsAndTypes(RewriterBase &builder, + Location loc, Operation *op, ArrayRef operands, ArrayRef resultTypes) { return builder.create(loc, op->getName().getIdentifier(), operands, resultTypes, op->getAttrs()); } +/// Returns true if `op` implements VectorUnrollOpInterface and `targetShape` +/// is valid for unrolling `op`. If `trivialUnrollValid` is true, then an unroll +/// that will result in just one operation (for example, if the shapes are +/// equal) is considered valid, otherwise it is invalid. +static bool isValidTargetUnrollShape(ArrayRef targetShape, + Operation *op, + bool trivialUnrollValid = true) { + auto unrollableVectorOp = dyn_cast(op); + if (!unrollableVectorOp) + return false; + auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); + if (!maybeUnrollShape) + return false; + auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, targetShape); + if (!maybeShapeRatio) + return false; + // Reject the trivial case if required. 
This is required in pattern rewrites + // where we want to abort when the unrolling is the identity transform. + if (!trivialUnrollValid && + llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) + return false; + return true; +} + /// Return the target shape for unrolling for the given `op`. Return /// std::nullopt if the op shouldn't be or cannot be unrolled. static std::optional> -getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { +getTargetShapeFromOptions(const vector::UnrollVectorOptions &options, + Operation *op) { if (options.filterConstraint && failed(options.filterConstraint(op))) return std::nullopt; assert(options.nativeShape && "vector unrolling expects the native shape or native" "shape call back function to be set"); - auto unrollableVectorOp = dyn_cast(op); - if (!unrollableVectorOp) - return std::nullopt; - auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); - if (!maybeUnrollShape) - return std::nullopt; std::optional> targetShape = options.nativeShape(op); - if (!targetShape) - return std::nullopt; - auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape); - if (!maybeShapeRatio || - llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) + if (!targetShape || + !isValidTargetUnrollShape(*targetShape, op, /*trivialUnrollValid=*/false)) return std::nullopt; return targetShape; } +/// Return the unroll order that should be used for `op`, or if no unroll order +/// is specified via the callback, use the identity order.
static SmallVector -getUnrollOrder(unsigned numLoops, Operation *op, - const vector::UnrollVectorOptions &options) { - SmallVector loopOrder = - llvm::to_vector(llvm::seq(0, static_cast(numLoops))); +getUnrollOrderFromOptions(Operation *op, + const vector::UnrollVectorOptions &options) { if (options.traversalOrderCallback != nullptr) { std::optional> order = options.traversalOrderCallback(op); - if (order) { - loopOrder = std::move(*order); - } + if (order) + return *order; } - return loopOrder; + auto unrollShape = dyn_cast(op).getShapeForUnroll(); + assert(unrollShape && "expected valid unroll shape"); + return llvm::to_vector( + llvm::seq(0, static_cast(unrollShape->size()))); } -namespace { - -struct UnrollTransferReadPattern - : public OpRewritePattern { - UnrollTransferReadPattern(MLIRContext *context, - const vector::UnrollVectorOptions &options, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - options(options) {} - - LogicalResult matchAndRewrite(vector::TransferReadOp readOp, - PatternRewriter &rewriter) const override { - // TODO: support 0-d corner case. 
- if (readOp.getTransferRank() == 0) - return failure(); - if (readOp.getMask()) - return failure(); - auto targetShape = getTargetShape(options, readOp); - if (!targetShape) - return failure(); - auto sourceVectorType = readOp.getVectorType(); - SmallVector strides(targetShape->size(), 1); - Location loc = readOp.getLoc(); - ArrayRef originalSize = readOp.getVectorType().getShape(); - - // Prepare the result vector; - Value result = rewriter.create( - loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); - auto targetType = - VectorType::get(*targetShape, sourceVectorType.getElementType()); - SmallVector originalIndices(readOp.getIndices().begin(), - readOp.getIndices().end()); - SmallVector loopOrder = - getUnrollOrder(originalSize.size(), readOp, options); - for (SmallVector elementOffsets : - StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { - SmallVector indices = - sliceTransferIndices(elementOffsets, originalIndices, - readOp.getPermutationMap(), loc, rewriter); - auto slicedRead = rewriter.create( - loc, targetType, readOp.getSource(), indices, - readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(), - readOp.getInBoundsAttr()); - - result = rewriter.create( - loc, slicedRead, result, elementOffsets, strides); - } - rewriter.replaceOp(readOp, result); - return success(); +/// Unroll a `vector.transfer_read`. 
+static FailureOr unrollTransferRead(RewriterBase &rewriter, + TransferReadOp readOp, + ArrayRef targetShape, + ArrayRef loopOrder) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(readOp); + if (!isValidTargetUnrollShape(targetShape, readOp)) + return failure(); + if (readOp.getTransferRank() == 0 && targetShape.empty()) + return UnrollResult({readOp}, Value(readOp)); + if (readOp.getTransferRank() == 0 || readOp.getMask()) + return failure(); + auto sourceVectorType = readOp.getVectorType(); + SmallVector strides(targetShape.size(), 1); + Location loc = readOp.getLoc(); + ArrayRef originalSize = readOp.getVectorType().getShape(); + + // Prepare the result vector; + Value result = rewriter.create( + loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); + auto targetType = + VectorType::get(targetShape, sourceVectorType.getElementType()); + SmallVector originalIndices(readOp.getIndices().begin(), + readOp.getIndices().end()); + SmallVector transferReadOps; + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, targetShape, loopOrder)) { + SmallVector indices = + sliceTransferIndices(elementOffsets, originalIndices, + readOp.getPermutationMap(), loc, rewriter); + auto slicedRead = rewriter.create( + loc, targetType, readOp.getSource(), indices, + readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(), + readOp.getInBoundsAttr()); + result = rewriter.create(loc, slicedRead, result, + elementOffsets, strides); + transferReadOps.push_back(slicedRead); } + return UnrollResult(transferReadOps, result); +} -private: - vector::UnrollVectorOptions options; -}; - -struct UnrollTransferWritePattern - : public OpRewritePattern { - UnrollTransferWritePattern(MLIRContext *context, - const vector::UnrollVectorOptions &options, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - options(options) {} - - LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, - PatternRewriter &rewriter) 
const override { - // TODO: support 0-d corner case. - if (writeOp.getTransferRank() == 0) - return failure(); - - if (writeOp.getMask()) - return failure(); - auto targetShape = getTargetShape(options, writeOp); - if (!targetShape) - return failure(); - auto sourceVectorType = writeOp.getVectorType(); - SmallVector strides(targetShape->size(), 1); - Location loc = writeOp.getLoc(); - ArrayRef originalSize = sourceVectorType.getShape(); - SmallVector originalIndices(writeOp.getIndices().begin(), - writeOp.getIndices().end()); - SmallVector loopOrder = - getUnrollOrder(originalSize.size(), writeOp, options); - Value resultTensor; - for (SmallVector elementOffsets : - StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { - Value slicedVector = rewriter.create( - loc, writeOp.getVector(), elementOffsets, *targetShape, strides); - SmallVector indices = - sliceTransferIndices(elementOffsets, originalIndices, - writeOp.getPermutationMap(), loc, rewriter); - Operation *slicedWrite = rewriter.create( - loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(), - indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); - // For the tensor case update the destination for the next transfer write. - if (!slicedWrite->getResults().empty()) - resultTensor = slicedWrite->getResult(0); - } - if (resultTensor) - rewriter.replaceOp(writeOp, resultTensor); - else - rewriter.eraseOp(writeOp); - return success(); +/// Unroll a `vector.transfer_write` op. +static FailureOr +unrollTransferWrite(RewriterBase &rewriter, TransferWriteOp writeOp, + ArrayRef targetShape, + ArrayRef loopOrder) { + if (!isValidTargetUnrollShape(targetShape, writeOp)) + return failure(); + /// We are already unrolled. 
+ if (writeOp.getTransferRank() == 0 && targetShape.empty()) + return UnrollResult({writeOp}, Value(writeOp->getResult(0))); + if (writeOp.getTransferRank() == 0) + return failure(); + if (writeOp.getMask()) + return failure(); + auto sourceVectorType = writeOp.getVectorType(); + SmallVector strides(targetShape.size(), 1); + Location loc = writeOp.getLoc(); + ArrayRef originalSize = sourceVectorType.getShape(); + SmallVector originalIndices(writeOp.getIndices().begin(), + writeOp.getIndices().end()); + Value resultTensor; + SmallVector writeOps; + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, targetShape, loopOrder)) { + Value slicedVector = rewriter.create( + loc, writeOp.getVector(), elementOffsets, targetShape, strides); + SmallVector indices = + sliceTransferIndices(elementOffsets, originalIndices, + writeOp.getPermutationMap(), loc, rewriter); + auto slicedWrite = rewriter.create( + loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(), + indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); + writeOps.push_back(slicedWrite); + // For the tensor case update the destination for the next transfer write. 
+ if (!slicedWrite->getResults().empty()) + resultTensor = slicedWrite->getResult(0); } + return UnrollResult(writeOps, resultTensor); +} -private: - vector::UnrollVectorOptions options; -}; - +namespace { struct OffsetMapInfo { static SmallVector getEmptyKey() { return {int64_t(-1)}; } @@ -225,382 +219,355 @@ return lhs == rhs; } }; +} // namespace -struct UnrollContractionPattern - : public OpRewritePattern { - UnrollContractionPattern(MLIRContext *context, - const vector::UnrollVectorOptions &options, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - options(options) {} - - LogicalResult matchAndRewrite(vector::ContractionOp contractOp, - PatternRewriter &rewriter) const override { - auto targetShape = getTargetShape(options, contractOp); - if (!targetShape) - return failure(); - auto dstVecType = contractOp.getResultType().cast(); - SmallVector originalSize = *contractOp.getShapeForUnroll(); - - Location loc = contractOp.getLoc(); - unsigned accIndex = vector::ContractionOp::getAccOperandIndex(); - AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex]; - llvm::MapVector< - SmallVector, Value, - llvm::DenseMap, unsigned, OffsetMapInfo>> - accCache; - - SmallVector loopOrder = getUnrollOrder( - contractOp.getIteratorTypes().size(), contractOp, options); - - for (SmallVector offsets : - StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { - SmallVector slicesOperands(contractOp.getNumOperands()); - - // Helper to compute the new shape of each operand and extract the slice. - auto extractOperand = [&](unsigned index, Value operand, - AffineMap permutationMap, - ArrayRef operandOffets) { - SmallVector operandShape = applyPermutationMap( - permutationMap, ArrayRef(*targetShape)); - SmallVector operandStrides(operandOffets.size(), 1); - slicesOperands[index] = rewriter.create( - loc, operand, operandOffets, operandShape, operandStrides); - }; - - // Extract the new lhs operand. 
- AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0]; - SmallVector lhsOffets = - applyPermutationMap(lhsPermutationMap, ArrayRef(offsets)); - extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets); - - // Extract the new rhs operand. - AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1]; - SmallVector rhsOffets = - applyPermutationMap(rhsPermutationMap, ArrayRef(offsets)); - extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets); - - AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2]; - SmallVector accOffets = - applyPermutationMap(accPermutationMap, ArrayRef(offsets)); - // If a version of the accumulator has already been computed, use it - // otherwise extract the first version from the original operand. - auto accIt = accCache.find(accOffets); - if (accIt != accCache.end()) - slicesOperands[2] = accIt->second; - else - extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets); - - SmallVector dstShape = - applyPermutationMap(dstAffineMap, ArrayRef(*targetShape)); - auto targetType = VectorType::get(dstShape, dstVecType.getElementType()); - Operation *newOp = cloneOpWithOperandsAndTypes( - rewriter, loc, contractOp, slicesOperands, targetType); - - SmallVector dstOffets = - applyPermutationMap(dstAffineMap, ArrayRef(offsets)); - // Save the accumulated value untill all the loops are unrolled since - // reduction loop keep updating the accumulator. - accCache[dstOffets] = newOp->getResult(0); - } - // Assemble back the accumulator into a single vector. - Value result = rewriter.create( - loc, dstVecType, rewriter.getZeroAttr(dstVecType)); - for (const auto &it : accCache) { - SmallVector dstStrides(it.first.size(), 1); - result = rewriter.create( - loc, it.second, result, it.first, dstStrides); - } - rewriter.replaceOp(contractOp, result); - return success(); +/// Unroll a `vector.contract` operation. 
+static FailureOr unrollContraction(RewriterBase &rewriter, + ContractionOp contractOp, + ArrayRef targetShape, + ArrayRef loopOrder) { + auto dstVecType = contractOp.getResultType().cast(); + SmallVector originalSize = *contractOp.getShapeForUnroll(); + + Location loc = contractOp.getLoc(); + unsigned accIndex = vector::ContractionOp::getAccOperandIndex(); + AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex]; + llvm::MapVector, Value, + llvm::DenseMap, unsigned, OffsetMapInfo>> + accCache; + SmallVector newContractOps; + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, targetShape, loopOrder)) { + SmallVector slicesOperands(contractOp.getNumOperands()); + + // Helper to compute the new shape of each operand and extract the slice. + auto extractOperand = [&](unsigned index, Value operand, + AffineMap permutationMap, + ArrayRef operandOffets) { + SmallVector operandShape = + applyPermutationMap(permutationMap, ArrayRef(targetShape)); + SmallVector operandStrides(operandOffets.size(), 1); + slicesOperands[index] = rewriter.create( + loc, operand, operandOffets, operandShape, operandStrides); + }; + + // Extract the new lhs operand. + AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0]; + SmallVector lhsOffets = + applyPermutationMap(lhsPermutationMap, ArrayRef(offsets)); + extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets); + + // Extract the new rhs operand. + AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1]; + SmallVector rhsOffets = + applyPermutationMap(rhsPermutationMap, ArrayRef(offsets)); + extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets); + + AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2]; + SmallVector accOffets = + applyPermutationMap(accPermutationMap, ArrayRef(offsets)); + // If a version of the accumulator has already been computed, use it + // otherwise extract the first version from the original operand. 
+ auto accIt = accCache.find(accOffets); + if (accIt != accCache.end()) + slicesOperands[2] = accIt->second; + else + extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets); + + SmallVector dstShape = + applyPermutationMap(dstAffineMap, ArrayRef(targetShape)); + auto targetType = VectorType::get(dstShape, dstVecType.getElementType()); + Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, contractOp, + slicesOperands, targetType); + newContractOps.push_back(newOp); + + SmallVector dstOffets = + applyPermutationMap(dstAffineMap, ArrayRef(offsets)); + // Save the accumulated value until all the loops are unrolled since + // reduction loops keep updating the accumulator. + accCache[dstOffets] = newOp->getResult(0); } + // Assemble back the accumulator into a single vector. + Value result = rewriter.create( + loc, dstVecType, rewriter.getZeroAttr(dstVecType)); + for (const auto &it : accCache) { + SmallVector dstStrides(it.first.size(), 1); + result = rewriter.create( + loc, it.second, result, it.first, dstStrides); + } + return UnrollResult(newContractOps, result); +} -private: - vector::UnrollVectorOptions options; -}; - -struct UnrollMultiReductionPattern - : public OpRewritePattern { - UnrollMultiReductionPattern(MLIRContext *context, - const vector::UnrollVectorOptions &options, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - options(options) {} - - LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, - PatternRewriter &rewriter) const override { - std::optional> targetShape = - getTargetShape(options, reductionOp); - if (!targetShape) - return failure(); - SmallVector originalSize = *reductionOp.getShapeForUnroll(); - llvm::MapVector< - SmallVector, Value, - llvm::DenseMap, unsigned, OffsetMapInfo>> - accCache; - Location loc = reductionOp.getLoc(); - - // Stride of the ratios, this gives us the offsets of sliceCount in a basis - // of multiples of the targetShape.
- for (SmallVector offsets : - StaticTileOffsetRange(originalSize, *targetShape)) { - SmallVector operands; - SmallVector operandStrides(offsets.size(), 1); - Value slicedOperand = rewriter.create( - loc, reductionOp.getSource(), offsets, *targetShape, operandStrides); - operands.push_back(slicedOperand); - SmallVector dstShape; - SmallVector destOffset; - for (size_t i : llvm::seq(size_t(0), targetShape->size())) { - if (!reductionOp.isReducedDim(i)) { - destOffset.push_back(offsets[i]); - dstShape.push_back((*targetShape)[i]); - } +/// Unroll a `vector.multi_dim_reduction`. +static FailureOr +unrollMultiDimReduce(RewriterBase &rewriter, MultiDimReductionOp reductionOp, + ArrayRef targetShape, + ArrayRef unrollOrder) { + if (!isValidTargetUnrollShape(targetShape, reductionOp)) + return failure(); + SmallVector originalSize = *reductionOp.getShapeForUnroll(); + llvm::MapVector, Value, + llvm::DenseMap, unsigned, OffsetMapInfo>> + accCache; + Location loc = reductionOp.getLoc(); + + SmallVector newReduceOps; + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, targetShape)) { + SmallVector operands; + SmallVector operandStrides(offsets.size(), 1); + Value slicedOperand = rewriter.create( + loc, reductionOp.getSource(), offsets, targetShape, operandStrides); + operands.push_back(slicedOperand); + SmallVector dstShape; + SmallVector destOffset; + for (size_t i : llvm::seq(size_t(0), targetShape.size())) { + if (!reductionOp.isReducedDim(i)) { + destOffset.push_back(offsets[i]); + dstShape.push_back(targetShape[i]); } - Value acc; - SmallVector accStrides(destOffset.size(), 1); - // If a version of the accumulator has already been computed, use it - // otherwise extract the first version from the original operand. 
- auto accIt = accCache.find(destOffset); - if (accIt != accCache.end()) - acc = accIt->second; - else - acc = rewriter.create( - loc, reductionOp.getAcc(), destOffset, dstShape, accStrides); - operands.push_back(acc); - auto targetType = VectorType::get( - dstShape, reductionOp.getSourceVectorType().getElementType()); - Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp, - operands, targetType); - Value result = newOp->getResult(0); - accCache[destOffset] = result; } - // Assemble back the accumulator into a single vector. - Value result = rewriter.create( - loc, reductionOp.getDestType(), - rewriter.getZeroAttr(reductionOp.getDestType())); - for (const auto &it : accCache) { - SmallVector dstStrides(it.first.size(), 1); - result = rewriter.create( - loc, it.second, result, it.first, dstStrides); - } - rewriter.replaceOp(reductionOp, result); - return success(); + Value acc; + SmallVector accStrides(destOffset.size(), 1); + // If a version of the accumulator has already been computed, use it + // otherwise extract the first version from the original operand. + auto accIt = accCache.find(destOffset); + if (accIt != accCache.end()) + acc = accIt->second; + else + acc = rewriter.create( + loc, reductionOp.getAcc(), destOffset, dstShape, accStrides); + operands.push_back(acc); + auto targetType = VectorType::get( + dstShape, reductionOp.getSourceVectorType().getElementType()); + Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp, + operands, targetType); + Value result = newOp->getResult(0); + accCache[destOffset] = result; + newReduceOps.push_back(newOp); } + // Assemble back the accumulator into a single vector. 
+ Value result = rewriter.create( + loc, reductionOp.getDestType(), + rewriter.getZeroAttr(reductionOp.getDestType())); + for (const auto &it : accCache) { + SmallVector dstStrides(it.first.size(), 1); + result = rewriter.create( + loc, it.second, result, it.first, dstStrides); + } + return UnrollResult(newReduceOps, result); +} -private: - vector::UnrollVectorOptions options; -}; - -struct UnrollElementwisePattern : public RewritePattern { - UnrollElementwisePattern(MLIRContext *context, - const vector::UnrollVectorOptions &options, - PatternBenefit benefit = 1) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), - options(options) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) - return failure(); - auto targetShape = getTargetShape(options, op); - if (!targetShape) - return failure(); - auto dstVecType = op->getResult(0).getType().cast(); - SmallVector originalSize = - *cast(op).getShapeForUnroll(); - Location loc = op->getLoc(); - // Prepare the result vector. - Value result = rewriter.create( - loc, dstVecType, rewriter.getZeroAttr(dstVecType)); - SmallVector strides(targetShape->size(), 1); - VectorType newVecType = - VectorType::get(*targetShape, dstVecType.getElementType()); - - // Create the unrolled computation. - for (SmallVector offsets : - StaticTileOffsetRange(originalSize, *targetShape)) { - SmallVector extractOperands; - for (OpOperand &operand : op->getOpOperands()) { - auto vecType = operand.get().getType().template dyn_cast(); - if (!vecType) { - extractOperands.push_back(operand.get()); - continue; - } - extractOperands.push_back( - rewriter.create( - loc, operand.get(), offsets, *targetShape, strides)); +/// Unroll any op with the ElementwiseMappable trait that is performing some +/// computation on vector types. 
+static FailureOr +unrollElementwise(RewriterBase &rewriter, Operation *op, + ArrayRef targetShape, + ArrayRef unrollOrder) { + if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) + return failure(); + auto dstVecType = op->getResult(0).getType().cast(); + SmallVector originalSize = + *cast(op).getShapeForUnroll(); + Location loc = op->getLoc(); + // Prepare the result vector. + Value result = rewriter.create( + loc, dstVecType, rewriter.getZeroAttr(dstVecType)); + SmallVector strides(targetShape.size(), 1); + VectorType newVecType = + VectorType::get(targetShape, dstVecType.getElementType()); + + // Create the unrolled computation. + SmallVector newOps; + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, targetShape)) { + SmallVector extractOperands; + for (OpOperand &operand : op->getOpOperands()) { + auto vecType = operand.get().getType().template dyn_cast(); + if (!vecType) { + extractOperands.push_back(operand.get()); + continue; } - Operation *newOp = cloneOpWithOperandsAndTypes( - rewriter, loc, op, extractOperands, newVecType); - result = rewriter.create( - loc, newOp->getResult(0), result, offsets, strides); + extractOperands.push_back(rewriter.create( + loc, operand.get(), offsets, targetShape, strides)); } - rewriter.replaceOp(op, result); - return success(); + Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, op, + extractOperands, newVecType); + result = rewriter.create( + loc, newOp->getResult(0), result, offsets, strides); + newOps.push_back(newOp); } + return UnrollResult(newOps, result); +} -private: - vector::UnrollVectorOptions options; -}; - -struct UnrollReductionPattern : public OpRewritePattern { - UnrollReductionPattern(MLIRContext *context, - const vector::UnrollVectorOptions &options, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - options(options) {} - - LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, - PatternRewriter &rewriter) const override { - 
std::optional> targetShape = - getTargetShape(options, reductionOp); - if (!targetShape) - return failure(); - SmallVector originalSize = *reductionOp.getShapeForUnroll(); - - // Create unrolled vector reduction. - Location loc = reductionOp.getLoc(); - Value accumulator = nullptr; - for (SmallVector offsets : - StaticTileOffsetRange(originalSize, *targetShape)) { - SmallVector strides(offsets.size(), 1); - Value slicedOperand = rewriter.create( - loc, reductionOp.getVector(), offsets, *targetShape, strides); - Operation *newOp = cloneOpWithOperandsAndTypes( - rewriter, loc, reductionOp, slicedOperand, reductionOp.getType()); - Value result = newOp->getResult(0); - - if (!accumulator) { - // This is the first reduction. - accumulator = result; - } else { - // On subsequent reduction, combine with the accumulator. - accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(), - accumulator, result); - } +/// Unroll a `vector.reduction` operation. +/// +/// Each `targetShape`-sized slice of the source vector is reduced by a clone of +/// `reductionOp`, and the partial results are chained together with +/// `makeArithReduction` using the kind of the original op. The returned +/// `UnrollResult` carries the cloned reductions and the final combined value, +/// which is the replacement for the result of `reductionOp`. +/// +/// NOTE: `unrollOrder` is not consulted here; `StaticTileOffsetRange` is built +/// without a loop order. +static FailureOr unrollReduction(RewriterBase &rewriter, + ReductionOp reductionOp, + ArrayRef targetShape, + ArrayRef unrollOrder) { + // Create unrolled vector reduction. + SmallVector originalSize = *reductionOp.getShapeForUnroll(); + Location loc = reductionOp.getLoc(); + Value accumulator = nullptr; + // Cloned per-slice reductions, reported back like the sibling unrollers do. + SmallVector<Operation *> newReductionOps; + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, targetShape)) { + SmallVector strides(offsets.size(), 1); + Value slicedOperand = rewriter.create( + loc, reductionOp.getVector(), offsets, targetShape, strides); + Operation *newOp = cloneOpWithOperandsAndTypes( + rewriter, loc, reductionOp, slicedOperand, reductionOp.getType()); + Value result = newOp->getResult(0); + newReductionOps.push_back(newOp); + + if (!accumulator) { + // This is the first reduction. + accumulator = result; + } else { + // On subsequent reduction, combine with the accumulator.
+ accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(), + accumulator, result); } - - rewriter.replaceOp(reductionOp, accumulator); - return success(); } + return UnrollResult(newReductionOps, accumulator); +} -private: - const vector::UnrollVectorOptions options; -}; - -struct UnrollTransposePattern : public OpRewritePattern { - UnrollTransposePattern(MLIRContext *context, - const vector::UnrollVectorOptions &options, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - options(options) {} - - LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, - PatternRewriter &rewriter) const override { - if (transposeOp.getResultVectorType().getRank() == 0) - return failure(); - auto targetShape = getTargetShape(options, transposeOp); - if (!targetShape) - return failure(); - auto originalVectorType = transposeOp.getResultVectorType(); - SmallVector strides(targetShape->size(), 1); - Location loc = transposeOp.getLoc(); - ArrayRef originalSize = originalVectorType.getShape(); - - // Prepare the result vector; - Value result = rewriter.create( - loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); - SmallVector permutation; - transposeOp.getTransp(permutation); - - // Unroll the computation. - for (SmallVector elementOffsets : - StaticTileOffsetRange(originalSize, *targetShape)) { - SmallVector permutedOffsets(elementOffsets.size()); - SmallVector permutedShape(elementOffsets.size()); - // Compute the source offsets and shape.
- for (auto indices : llvm::enumerate(permutation)) { - permutedOffsets[indices.value()] = elementOffsets[indices.index()]; - permutedShape[indices.value()] = (*targetShape)[indices.index()]; - } - Value slicedOperand = rewriter.create( - loc, transposeOp.getVector(), permutedOffsets, permutedShape, - strides); - Value transposedSlice = - rewriter.create(loc, slicedOperand, permutation); - result = rewriter.create( - loc, transposedSlice, result, elementOffsets, strides); +/// Unroll a `vector.transpose` operation. +/// +/// For every `targetShape`-sized tile of the result, the matching source slice +/// is extracted (tile offsets and sizes mapped through the transpose +/// permutation), transposed on its own and inserted into the result vector. +/// Fails on rank-0 vectors. +/// +/// NOTE: `loopOrder` is not consulted here; `StaticTileOffsetRange` is built +/// without a loop order. +static FailureOr unrollTranspose(RewriterBase &rewriter, + TransposeOp transposeOp, + ArrayRef targetShape, + ArrayRef loopOrder) { + auto originalVectorType = transposeOp.getResultVectorType(); + // Rank-0 vectors have no dimension to tile; the rewrite pattern this helper + // replaces rejected them as well. + if (originalVectorType.getRank() == 0) + return failure(); + SmallVector strides(targetShape.size(), 1); + Location loc = transposeOp.getLoc(); + ArrayRef originalSize = originalVectorType.getShape(); + + // Prepare the result vector. + Value result = rewriter.create( + loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); + SmallVector permutation; + transposeOp.getTransp(permutation); + + // Unroll the computation. + SmallVector newTransposeOps; + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, targetShape)) { + SmallVector permutedOffsets(elementOffsets.size()); + SmallVector permutedShape(elementOffsets.size()); + // Compute the source offsets and shape.
+ for (auto indices : llvm::enumerate(permutation)) { + permutedOffsets[indices.value()] = elementOffsets[indices.index()]; + permutedShape[indices.value()] = targetShape[indices.index()]; } - rewriter.replaceOp(transposeOp, result); - return success(); + Value slicedOperand = rewriter.create( + loc, transposeOp.getVector(), permutedOffsets, permutedShape, strides); + auto transposedSlice = + rewriter.create(loc, slicedOperand, permutation); + result = rewriter.create( + loc, transposedSlice, result, elementOffsets, strides); + newTransposeOps.push_back(transposedSlice); } + return UnrollResult(newTransposeOps, result); +} -private: - vector::UnrollVectorOptions options; -}; +/// Unroll a `vector.gather` operation. +/// +/// Fails on rank-0 vectors. +static FailureOr unrollGather(RewriterBase &rewriter, + GatherOp gatherOp, + ArrayRef targetShape, + ArrayRef loopOrder) { + VectorType sourceVectorType = gatherOp.getVectorType(); + // Rank-0 vectors have no dimension to tile; the rewrite pattern this helper + // replaces rejected them as well. + if (sourceVectorType.getRank() == 0) + return failure(); + SmallVector strides(targetShape.size(), 1); + Location loc = gatherOp.getLoc(); + ArrayRef originalSize = gatherOp.getVectorType().getShape(); + + // Prepare the result vector. + Value result = rewriter.create( + loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); + auto targetType = + VectorType::get(targetShape, sourceVectorType.getElementType()); + SmallVector newGatherOps; + for (SmallVector elementOffsets : + StaticTileOffsetRange(originalSize, targetShape, loopOrder)) { + // To get the unrolled gather, extract the same slice based on the + // decomposed shape from each of the index, mask, and pass-through + // vectors.
+ Value indexSubVec = rewriter.create( + loc, gatherOp.getIndexVec(), elementOffsets, targetShape, strides); + Value maskSubVec = rewriter.create( + loc, gatherOp.getMask(), elementOffsets, targetShape, strides); + Value passThruSubVec = rewriter.create( + loc, gatherOp.getPassThru(), elementOffsets, targetShape, strides); + auto slicedGather = rewriter.create( + loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), indexSubVec, + maskSubVec, passThruSubVec); + newGatherOps.push_back(slicedGather); + + result = rewriter.create( + loc, slicedGather, result, elementOffsets, strides); + } + return UnrollResult(newGatherOps, result); +} + +/// Unroll `target` into pieces of shape `targetShape`. +/// +/// Dispatches to the op-specific unrolling helper (transfer read, transfer +/// write, contraction, multi-dim reduction, reduction, transpose, gather) and +/// falls back to the element-wise helper for ops with the ElementwiseMappable +/// traits. Exactly one helper runs and its result is returned as is; +/// `unrollOrder` is forwarded to it unchanged. Returns failure when `target` is +/// handled by none of the helpers. +FailureOr vector::unroll(RewriterBase &rewriter, + VectorUnrollOpInterface target, + ArrayRef targetShape, + ArrayRef unrollOrder) { + // Return from the first matching branch so that a later branch can neither + // overwrite the result nor emit a second expansion of the same op. + if (auto readOp = dyn_cast(*target)) + return unrollTransferRead(rewriter, readOp, targetShape, unrollOrder); + if (auto writeOp = dyn_cast(*target)) + return unrollTransferWrite(rewriter, writeOp, targetShape, unrollOrder); + if (auto contractOp = dyn_cast(*target)) + return unrollContraction(rewriter, contractOp, targetShape, unrollOrder); + if (auto redOp = dyn_cast(*target)) + return unrollMultiDimReduce(rewriter, redOp, targetShape, unrollOrder); + if (auto reductionOp = dyn_cast(*target)) + return unrollReduction(rewriter, reductionOp, targetShape, unrollOrder); + if (auto transposeOp = dyn_cast(*target)) + return unrollTranspose(rewriter, transposeOp, targetShape, unrollOrder); + if (auto gatherOp = dyn_cast(*target)) + return unrollGather(rewriter, gatherOp, targetShape, unrollOrder); + if (OpTrait::hasElementwiseMappableTraits(target)) + return unrollElementwise(rewriter, target, targetShape, unrollOrder); + return failure(); +} -struct UnrollGatherPattern : public OpRewritePattern { - UnrollGatherPattern(MLIRContext *context, +namespace { +struct UnrollVectorRewrite + : public OpInterfaceRewritePattern { + UnrollVectorRewrite(MLIRContext
*context, const vector::UnrollVectorOptions &options, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), options(options) { - } - - LogicalResult matchAndRewrite(vector::GatherOp gatherOp, + : OpInterfaceRewritePattern(context, benefit), + options(options) {} + LogicalResult matchAndRewrite(VectorUnrollOpInterface op, PatternRewriter &rewriter) const override { - VectorType sourceVectorType = gatherOp.getVectorType(); - if (sourceVectorType.getRank() == 0) - return failure(); - auto targetShape = getTargetShape(options, gatherOp); + std::optional> targetShape = + getTargetShapeFromOptions(options, op); if (!targetShape) - return failure(); - SmallVector strides(targetShape->size(), 1); - Location loc = gatherOp.getLoc(); - ArrayRef originalSize = gatherOp.getVectorType().getShape(); - - // Prepare the result vector; - Value result = rewriter.create( - loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); - auto targetType = - VectorType::get(*targetShape, sourceVectorType.getElementType()); - - SmallVector loopOrder = - getUnrollOrder(originalSize.size(), gatherOp, options); - for (SmallVector elementOffsets : - StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { - // To get the unrolled gather, extract the same slice based on the - // decomposed shape from each of the index, mask, and pass-through - // vectors. 
- Value indexSubVec = rewriter.create( - loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides); - Value maskSubVec = rewriter.create( - loc, gatherOp.getMask(), elementOffsets, *targetShape, strides); - Value passThruSubVec = rewriter.create( - loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides); - auto slicedGather = rewriter.create( - loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), - indexSubVec, maskSubVec, passThruSubVec); - - result = rewriter.create( - loc, slicedGather, result, elementOffsets, strides); + return rewriter.notifyMatchFailure( + op, "failed to get a valid unroll target shape"); + SmallVector unrollOrder = getUnrollOrderFromOptions(op, options); + FailureOr result = + unroll(rewriter, op, *targetShape, unrollOrder); + if (failed(result)) + return rewriter.notifyMatchFailure( + op, "failed to apply vector unrolling transformation"); + + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + return success(); } - rewriter.replaceOp(gatherOp, result); + assert(op->getNumResults() == 1 && result->replacement && + "expected single-result op and a valid replacement"); + rewriter.replaceOp(op, result->replacement); return success(); } private: vector::UnrollVectorOptions options; }; - } // namespace void mlir::vector::populateVectorUnrollPatterns( RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit) { - patterns.add( - patterns.getContext(), options, benefit); + patterns.add(patterns.getContext(), options, benefit); } diff --git a/mlir/test/Dialect/Vector/unroll-transform-op.mlir b/mlir/test/Dialect/Vector/unroll-transform-op.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/unroll-transform-op.mlir @@ -0,0 +1,138 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -cse --split-input-file | FileCheck %s + +func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> { + %c0 = arith.constant 0 : index + %cf0 = 
arith.constant 0.0 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + %0 = transform.structured.match ops{["vector.transfer_read"]} in %arg0 : (!pdl.operation) -> !pdl.operation + transform.vector.unroll %0 { + unroll_shape = array + } : !pdl.operation +} + +// CHECK-LABEL: func @transfer_read_unroll +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: return %[[VEC3]] : vector<4x4xf32> +// ----- + +func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> { + %c0 = arith.constant 0 : index + %cf0 = arith.constant 0.0 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32> + return %0 : vector<4x4xf32> +} + 
+transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + %0 = transform.structured.match ops{["vector.transfer_read"]} in %arg0 : (!pdl.operation) -> !pdl.operation + transform.vector.unroll %0 { + unroll_shape = array, + unroll_order = array + } : !pdl.operation +} + +// CHECK-LABEL: func @transfer_read_unroll +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK-NEXT: return %[[VEC3]] : vector<4x4xf32> + +// ----- + + +func.func @transfer_readwrite_unroll(%arg0 : memref<4x4xf32>) { + %c0 = arith.constant 0 : index + %cf0 = arith.constant 0.0 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32> + vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32> + return +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + %0 = 
transform.structured.match ops{["vector.transfer_read"]} in %arg0 : (!pdl.operation) -> !pdl.operation + transform.vector.unroll %0 { + unroll_shape = array + } : !pdl.operation + %1 = transform.structured.match ops{["vector.transfer_write"]} in %arg0 : (!pdl.operation) -> !pdl.operation + transform.vector.unroll %1 { + unroll_shape = array + } : !pdl.operation +} + +// CHECK-LABEL: func @transfer_readwrite_unroll +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: vector.transfer_write %[[VTR0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: vector.transfer_write %[[VTR1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: vector.transfer_write %[[VTR2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: vector.transfer_write %[[VTR3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: return + +// ----- + +func.func @vector_tranpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> { + %t = vector.transpose %v, [0, 2, 3, 1] : vector<2x4x3x8xf32> to vector<2x3x8x4xf32> + return %t : vector<2x3x8x4xf32> +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + %0 = transform.structured.match ops{["vector.transpose"]} in %arg0 : (!pdl.operation) -> !pdl.operation + transform.vector.unroll %0 { + unroll_shape = array + } : !pdl.operation +} + +// CHECK-LABEL: func 
@vector_tranpose +// CHECK: %[[VI:.*]] = arith.constant dense<0.000000e+00> : vector<2x3x8x4xf32> +// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T0:.*]] = vector.transpose %[[E0]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V0:.*]] = vector.insert_strided_slice %[[T0]], %[[VI]] {offsets = [0, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T1:.*]] = vector.transpose %[[E1]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[T1]], %[[V0]] {offsets = [0, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T2:.*]] = vector.transpose %[[E2]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[T2]], %[[V1]] {offsets = [0, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T3:.*]] = vector.transpose %[[E3]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[T3]], %[[V2]] {offsets = [0, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 0], sizes = [1, 2, 
3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T4:.*]] = vector.transpose %[[E4]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V4:.*]] = vector.insert_strided_slice %[[T4]], %[[V3]] {offsets = [1, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T5:.*]] = vector.transpose %[[E5]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[T5]], %[[V4]] {offsets = [1, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T6:.*]] = vector.transpose %[[E6]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V6:.*]] = vector.insert_strided_slice %[[T6]], %[[V5]] {offsets = [1, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32> +// CHECK: %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32> +// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32> +// CHECK: return %[[V7]] : vector<2x3x8x4xf32>