diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -858,6 +858,10 @@ return {rank, rank, rank}; } + /// Return the dimensions of the dest that are omitted to insert a source + /// when the source is rank-reduced. + llvm::SmallBitVector getDroppedDims(); + /// Return the number of leading operands before the `offsets`, `sizes` and /// and `strides` operands. static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; } diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h @@ -12,23 +12,38 @@ #include "mlir/Pass/Pass.h" namespace mlir { +namespace tensor { -#define GEN_PASS_DECL -#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc" +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// -/// Creates an instance of `tensor` dialect bufferization pass. +/// Collects a set of patterns to rewrite ops within the tensor dialect. +void populateExpandOpsPatterns(RewritePatternSet &patterns); + +/// Appends patterns for folding tensor aliasing ops into consumer load/store +/// ops into `patterns`. +void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns); + +//===----------------------------------------------------------------------===// +// Passes +//===----------------------------------------------------------------------===// + +/// Creates an instance of the `tensor` subset folding pass. +std::unique_ptr createFoldTensorSubsetOpsPass(); + +/// Creates an instance of the `tensor` dialect bufferization pass. 
std::unique_ptr createTensorBufferizePass(); //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// -namespace tensor { /// Generate the code for registering passes. #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc" -} // namespace tensor +} // namespace tensor } // namespace mlir #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td @@ -11,9 +11,20 @@ include "mlir/Pass/PassBase.td" +def FoldTensorSubsetOps : Pass<"fold-tensor-subset-ops"> { + let summary = "Fold tensor subset ops into producer/consumer ops"; + let description = [{ + The pass folds tensor subset ops into producer/consumer ops. + }]; + let constructor = "mlir::tensor::createFoldTensorSubsetOpsPass()"; + let dependentDialects = [ + "AffineDialect", "tensor::TensorDialect", "vector::VectorDialect" + ]; +} + def TensorBufferize : Pass<"tensor-bufferize", "func::FuncOp"> { let summary = "Bufferize the `tensor` dialect"; - let constructor = "mlir::createTensorBufferizePass()"; + let constructor = "mlir::tensor::createTensorBufferizePass()"; } #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/ViewLikeInterfaceUtils.h b/mlir/include/mlir/Dialect/Tensor/Utils/ViewLikeInterfaceUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tensor/Utils/ViewLikeInterfaceUtils.h @@ -0,0 +1,55 @@ +//===- ViewLikeInterfaceUtils.h - Utilities for ViewLikeIface --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TENSOR_UTILS_VIEWLIKEINTERFACEUTILS_H_ +#define MLIR_DIALECT_TENSOR_UTILS_VIEWLIKEINTERFACEUTILS_H_ + +#include "mlir/Support/LLVM.h" + +namespace llvm { +class SmallBitVector; +} // namespace llvm + +namespace mlir { +struct LogicalResult; +class RewriterBase; +class Location; +class OpFoldResult; +class ValueRange; +class Value; + +/// Given the 'indicesVals' of a load/store operation involving an op with +/// offsets and strides, return the equivalent indices w.r.t. the source +/// operand of that op. +/// +/// For example, using `memref.load` and `memref.subview` as an illustration: +/// +/// ``` +/// %0 = ... : memref<12x42xf32> +/// %1 = memref.subview %0[%arg0, %arg1][][%stride1, %stride2] : +/// memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]> +/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> +/// ``` +/// +/// is folded into: +/// +/// ``` +/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : +/// memref<12x42xf32> +/// ``` +// TODO: find a better landing pad, ideally closer to ViewLikeInterfaces. This +// does create IR and depends on affine.apply so we don't want to put it in the +// interfaces that should remain free of op dependences. 
+LogicalResult resolveSourceIndicesOffsetsAndStrides( + RewriterBase &rewriter, Location loc, ArrayRef mixedOffsets, + ArrayRef mixedStrides, + const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals, + SmallVectorImpl &sourceIndices); +} // namespace mlir + +#endif // MLIR_DIALECT_TENSOR_UTILS_VIEWLIKEINTERFACEUTILS_H_ diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -249,11 +249,11 @@ /// Returns a new AffineMap with the same number of dims and symbols and one /// less result at `pos`, dropped. - AffineMap dropResult(int64_t pos) { return dropResults({pos}); } + AffineMap dropResult(int64_t pos) const { return dropResults({pos}); } // Returns a new AffineMap with the same number of dims and symbols, but all - // positions in `positions` dropped from results. - AffineMap dropResults(ArrayRef positions) { + // results in `positions` dropped. + AffineMap dropResults(ArrayRef positions) const { SmallVector reverse_sorted_positions = llvm::to_vector(positions); llvm::sort(reverse_sorted_positions, std::greater()); @@ -263,9 +263,13 @@ return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); } + // Returns a new AffineMap with the same number of dims and symbols, but all + // results in `positions` dropped. + AffineMap dropResults(const llvm::SmallBitVector &positions) const; + /// Returns a new AffineMap with the same number of dims and symbols and an /// extra result inserted at `pos`. - AffineMap insertResult(AffineExpr expr, unsigned pos) { + AffineMap insertResult(AffineExpr expr, unsigned pos) const { auto exprs = llvm::to_vector<4>(getResults()); exprs.insert(exprs.begin() + pos, expr); return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); @@ -583,6 +587,12 @@ // by any of the maps in the input array `maps`. 
llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef maps); +/// Expand `map` to operate on `rank` dims while projecting out the dims in +/// `projectedDimensions`. This amounts to composing `map` with +/// `id(rank).dropResults(projectedDimensions)`. +AffineMap expandDimsToRank(AffineMap map, int64_t rank, + const llvm::SmallBitVector &projectedDimensions); + inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) { map.print(os); return os; diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -17,9 +17,13 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Tensor/Utils/ViewLikeInterfaceUtils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/TypeSwitch.h" @@ -150,70 +154,6 @@ return success(); } -/// Given the 'indices' of an load/store operation where the memref is a result -/// of a subview op, returns the indices w.r.t to the source memref of the -/// subview op. For example -/// -/// %0 = ... 
: memref<12x42xf32> -/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to -/// memref<4x4xf32, offset=?, strides=[?, ?]> -/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> -/// -/// could be folded into -/// -/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : -/// memref<12x42xf32> -static LogicalResult -resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter, - memref::SubViewOp subViewOp, ValueRange indices, - SmallVectorImpl &sourceIndices) { - SmallVector mixedOffsets = subViewOp.getMixedOffsets(); - SmallVector mixedSizes = subViewOp.getMixedSizes(); - SmallVector mixedStrides = subViewOp.getMixedStrides(); - - SmallVector useIndices; - // Check if this is rank-reducing case. Then for every unit-dim size add a - // zero to the indices. - int64_t resultDim = 0; - llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims(); - for (auto dim : llvm::seq(0, subViewOp.getSourceType().getRank())) { - if (unusedDims.test(dim)) - useIndices.push_back(rewriter.create(loc, 0)); - else - useIndices.push_back(indices[resultDim++]); - } - if (useIndices.size() != mixedOffsets.size()) - return failure(); - sourceIndices.resize(useIndices.size()); - for (auto index : llvm::seq(0, mixedOffsets.size())) { - SmallVector dynamicOperands; - AffineExpr expr = rewriter.getAffineDimExpr(0); - int64_t numSymbols = 0; - dynamicOperands.push_back(useIndices[index]); - - // Multiply the stride; - if (auto attr = mixedStrides[index].dyn_cast()) { - expr = expr * attr.cast().getInt(); - } else { - dynamicOperands.push_back(mixedStrides[index].get()); - expr = expr * rewriter.getAffineSymbolExpr(numSymbols++); - } - - // Add the offset. 
- if (auto attr = mixedOffsets[index].dyn_cast()) { - expr = expr + attr.cast().getInt(); - } else { - dynamicOperands.push_back(mixedOffsets[index].get()); - expr = expr + rewriter.getAffineSymbolExpr(numSymbols++); - } - Location loc = subViewOp.getLoc(); - OpFoldResult ofr = makeComposedFoldedAffineApply( - rewriter, loc, AffineMap::get(1, numSymbols, expr), dynamicOperands); - sourceIndices[index] = getValueOrCreateConstantIndexOp(rewriter, loc, ofr); - } - return success(); -} - /// Helpers to access the memref operand for each op. template static Value getMemRefOperand(LoadOrStoreOpTy op) { @@ -236,25 +176,6 @@ return op.getDstMemref(); } -/// Given the permutation map of the original -/// `vector.transfer_read`/`vector.transfer_write` operations compute the -/// permutation map to use after the subview is folded with it. -static AffineMapAttr getPermutationMapAttr(MLIRContext *context, - memref::SubViewOp subViewOp, - AffineMap currPermutationMap) { - llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims(); - SmallVector exprs; - int64_t sourceRank = subViewOp.getSourceType().getRank(); - for (auto dim : llvm::seq(0, sourceRank)) { - if (unusedDims.test(dim)) - continue; - exprs.push_back(getAffineDimExpr(dim, context)); - } - auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context); - return AffineMapAttr::get( - currPermutationMap.compose(resultDimToSourceDimMap)); -} - //===----------------------------------------------------------------------===// // Patterns //===----------------------------------------------------------------------===// @@ -390,6 +311,43 @@ return expandedIndices; } +template +static LogicalResult +preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp, + memref::SubViewOp subviewOp) { + static_assert( + !llvm::is_one_of::value, + "must be a vector transfer op"); + if (xferOp.hasOutOfBoundsDim()) + return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim"); + if (xferOp.getMask()) + 
return rewriter.notifyMatchFailure(xferOp, "masked transfer"); + if (!subviewOp.hasUnitStride()) { + return rewriter.notifyMatchFailure( + xferOp, "non-1 stride insert/extract, may require an extract " + "vector.insert/extract_strided_slice"); + } + return success(); +} + +static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, + Operation *op, + memref::SubViewOp subviewOp) { + return success(); +} + +static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, + vector::TransferReadOp readOp, + memref::SubViewOp subviewOp) { + return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp); +} + +static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, + vector::TransferWriteOp writeOp, + memref::SubViewOp subviewOp) { + return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp); +} + template LogicalResult LoadOpOfSubViewOpFolder::matchAndRewrite( OpTy loadOp, PatternRewriter &rewriter) const { @@ -397,7 +355,12 @@ getMemRefOperand(loadOp).template getDefiningOp(); if (!subViewOp) - return failure(); + return rewriter.notifyMatchFailure(loadOp, "not a subview producer"); + + LogicalResult preconditionResult = + preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp); + if (failed(preconditionResult)) + return preconditionResult; SmallVector indices(loadOp.getIndices().begin(), loadOp.getIndices().end()); @@ -410,9 +373,13 @@ indices.assign(expandedIndices.begin(), expandedIndices.end()); } SmallVector sourceIndices; - if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp, - indices, sourceIndices))) - return failure(); + if (failed(resolveSourceIndicesOffsetsAndStrides( + rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(), + subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices, + sourceIndices))) { + return rewriter.notifyMatchFailure(subViewOp, + "could not resolve source indices"); + } llvm::TypeSwitch(loadOp) .Case([&](AffineLoadOp op) { @@ -423,14 +390,13 @@ 
rewriter.replaceOpWithNewOp( loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal()); }) - .Case([&](vector::TransferReadOp transferReadOp) { + .Case([&](vector::TransferReadOp op) { rewriter.replaceOpWithNewOp( - transferReadOp, transferReadOp.getVectorType(), - subViewOp.getSource(), sourceIndices, - getPermutationMapAttr(rewriter.getContext(), subViewOp, - transferReadOp.getPermutationMap()), - transferReadOp.getPadding(), - /*mask=*/Value(), transferReadOp.getInBoundsAttr()); + op, op.getVectorType(), subViewOp.getSource(), sourceIndices, + AffineMapAttr::get(expandDimsToRank( + op.getPermutationMap(), subViewOp.getSourceType().getRank(), + subViewOp.getDroppedDims())), + op.getPadding(), /*mask=*/Value(), op.getInBoundsAttr()); }) .Case([&](gpu::SubgroupMmaLoadMatrixOp op) { rewriter.replaceOpWithNewOp( @@ -512,7 +478,12 @@ getMemRefOperand(storeOp).template getDefiningOp(); if (!subViewOp) - return failure(); + return rewriter.notifyMatchFailure(storeOp, "not a subview producer"); + + LogicalResult preconditionResult = + preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp); + if (failed(preconditionResult)) + return preconditionResult; SmallVector indices(storeOp.getIndices().begin(), storeOp.getIndices().end()); @@ -525,9 +496,13 @@ indices.assign(expandedIndices.begin(), expandedIndices.end()); } SmallVector sourceIndices; - if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp, - indices, sourceIndices))) - return failure(); + if (failed(resolveSourceIndicesOffsetsAndStrides( + rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(), + subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices, + sourceIndices))) { + return rewriter.notifyMatchFailure(subViewOp, + "could not resolve source indices"); + } llvm::TypeSwitch(storeOp) .Case([&](AffineStoreOp op) { @@ -542,8 +517,9 @@ .Case([&](vector::TransferWriteOp op) { rewriter.replaceOpWithNewOp( op, op.getValue(), subViewOp.getSource(), sourceIndices, - 
getPermutationMapAttr(rewriter.getContext(), subViewOp, - op.getPermutationMap()), + AffineMapAttr::get(expandDimsToRank( + op.getPermutationMap(), subViewOp.getSourceType().getRank(), + subViewOp.getDroppedDims())), op.getInBoundsAttr()); }) .Case([&](gpu::SubgroupMmaStoreMatrixOp op) { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2396,6 +2396,26 @@ }; } // namespace +llvm::SmallBitVector InsertSliceOp::getDroppedDims() { + ArrayRef resultShape = getType().getShape(); + SmallVector mixedSizes = getMixedSizes(); + llvm::SmallBitVector droppedDims(mixedSizes.size()); + unsigned shapePos = 0; + for (const auto &size : enumerate(mixedSizes)) { + std::optional sizeVal = getConstantIntValue(size.value()); + // If the size is not 1, or if the current matched dimension of the result + // is the same static shape as the size value (which is 1), then the + // dimension is preserved. 
+ if (!sizeVal || *sizeVal != 1 || + (shapePos < resultShape.size() && resultShape[shapePos] == 1)) { + shapePos++; + continue; + } + droppedDims.set(size.index()); + } + return droppedDims; +} + void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -53,6 +53,6 @@ }; } // namespace -std::unique_ptr mlir::createTensorBufferizePass() { +std::unique_ptr mlir::tensor::createTensorBufferizePass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ EmptyOpPatterns.cpp ExtractSliceFromReshapeUtils.cpp FoldIntoPackAndUnpackPatterns.cpp + FoldTensorSubsetOps.cpp MergeConsecutiveInsertExtractSlicePatterns.cpp ReshapePatterns.cpp SplitPaddingPatterns.cpp @@ -29,4 +30,5 @@ MLIRTensorDialect MLIRTilingInterface MLIRTransforms + MLIRVectorDialect ) diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp @@ -0,0 +1,177 @@ +//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Fold tensor subset ops with producer / consumers. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/Utils/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +namespace tensor { +#define GEN_PASS_DEF_FOLDTENSORSUBSETOPS +#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc" +} // namespace tensor +} // namespace mlir + +using namespace mlir; + +static Value getTensorOperand(vector::TransferReadOp op) { + return op.getSource(); +} + +static Value getTensorOperand(tensor::InsertSliceOp op) { + return op.getSource(); +} + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +namespace { +/// Merge extract_slice operation with load/transferRead operation. +class TransferReadOfExtractSliceOpFolder final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp readOp, + PatternRewriter &rewriter) const override; +}; + +/// Merge insert_slice operation with store/transferWriteOp operation. 
+class InsertSliceOfTransferWriteOpFolder final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp, + PatternRewriter &rewriter) const override; +}; +} // namespace + +template +static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp( + RewriterBase &rewriter, XferOp xferOp, + ExtractOrInsertOp extractOrInsertSliceOp) { + if (xferOp.hasOutOfBoundsDim()) + return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim"); + if (xferOp.getMask()) + return rewriter.notifyMatchFailure(xferOp, "masked transfer"); + if (!extractOrInsertSliceOp.hasUnitStride()) { + return rewriter.notifyMatchFailure( + xferOp, "non-1 stride insert/extract, may require an extract " + "vector.insert/extract_strided_slice"); + } + return success(); +} + +LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite( + vector::TransferReadOp readOp, PatternRewriter &rewriter) const { + auto extractSliceOp = + getTensorOperand(readOp).getDefiningOp(); + if (!extractSliceOp) + return rewriter.notifyMatchFailure(readOp, "not an extract_slice"); + + LogicalResult preconditionResult = + preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp, + extractSliceOp); + if (failed(preconditionResult)) + return preconditionResult; + + SmallVector indices(readOp.getIndices().begin(), + readOp.getIndices().end()); + SmallVector sourceIndices; + if (failed(resolveSourceIndicesOffsetsAndStrides( + rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(), + extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(), + indices, sourceIndices))) { + return rewriter.notifyMatchFailure(readOp, + "could not resolve source indices"); + } + + rewriter.replaceOpWithNewOp( + readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices, + AffineMapAttr::get(expandDimsToRank( + readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(), + 
extractSliceOp.getDroppedDims())), + readOp.getPadding(), + /*mask=*/Value(), readOp.getInBoundsAttr()); + + return success(); +} + +LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite( + tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const { + auto writeOp = getTensorOperand(insertSliceOp) + .template getDefiningOp(); + if (!writeOp) + return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write"); + + LogicalResult preconditionResult = + preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp, + insertSliceOp); + if (failed(preconditionResult)) + return preconditionResult; + + SmallVector indices(writeOp.getIndices().begin(), + writeOp.getIndices().end()); + SmallVector sourceIndices; + if (failed(resolveSourceIndicesOffsetsAndStrides( + rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(), + insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), + indices, sourceIndices))) { + return rewriter.notifyMatchFailure(writeOp, + "could not resolve source indices"); + } + + rewriter.replaceOpWithNewOp( + insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices, + AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(), + insertSliceOp.getDestType().getRank(), + insertSliceOp.getDroppedDims())), + writeOp.getInBoundsAttr()); + + return success(); +} + +void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { + +struct FoldTensorSubsetOpsPass final + : public tensor::impl::FoldTensorSubsetOpsBase { + void runOnOperation() override; +}; + +} // namespace + +void FoldTensorSubsetOpsPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + tensor::populateFoldTensorSubsetOpPatterns(patterns); + 
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +std::unique_ptr tensor::createFoldTensorSubsetOpsPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTensorUtils Utils.cpp + ViewLikeInterfaceUtils.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Tensor/Utils/ViewLikeInterfaceUtils.cpp b/mlir/lib/Dialect/Tensor/Utils/ViewLikeInterfaceUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Utils/ViewLikeInterfaceUtils.cpp @@ -0,0 +1,51 @@ +//===- ViewLikeInterfaceUtils.cpp - Utilities for ViewLikeInterface -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for ViewLikeInterface. These need to be +// attached to a dialect as they produce affine and arith ops. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tensor/Utils/ViewLikeInterfaceUtils.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" + +using namespace mlir; + +LogicalResult mlir::resolveSourceIndicesOffsetsAndStrides( + RewriterBase &rewriter, Location loc, ArrayRef mixedOffsets, + ArrayRef mixedStrides, + const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals, + SmallVectorImpl &sourceIndices) { + OpFoldResult zero = rewriter.getIndexAttr(0); + + // For each dimension that is rank-reduced, add a zero to the indices. + int64_t indicesDim = 0; + SmallVector indices; + for (auto dim : llvm::seq(0, mixedOffsets.size())) { + OpFoldResult ofr = + (rankReducedDims.test(dim)) ? zero : indicesVals[indicesDim++]; + indices.push_back(ofr); + } + + sourceIndices.resize(indices.size()); + sourceIndices.clear(); + for (auto [offset, index, stride] : + llvm::zip_equal(mixedOffsets, indices, mixedStrides)) { + AffineExpr off, idx, str; + bindSymbols(rewriter.getContext(), off, idx, str); + OpFoldResult ofr = makeComposedFoldedAffineApply( + rewriter, loc, AffineMap::get(0, 3, off + idx * str), + {offset, index, stride}); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + return success(); +} diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3733,6 +3733,8 @@ /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]} /// : tensor, vector<4x5xf32> /// ``` +// TODO: this is brittle and should be deprecated in favor of a more general +// pattern that applies on-demand. 
struct FoldExtractSliceIntoTransferRead : public OpRewritePattern { public: @@ -3883,9 +3885,13 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results - .add( - context); + // clang-format off + results.add < + // TODO: this is brittle and should be deprecated in favor of a + // more general pattern that applies on-demand. + FoldExtractSliceIntoTransferRead, + TransferReadAfterWriteToBroadcast>(context); + // clang-format on } //===----------------------------------------------------------------------===// @@ -4235,6 +4241,8 @@ /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]} /// : vector<4x5xf32>, tensor /// ``` +// TODO: this is brittle and should be deprecated in favor of a more general +// pattern that applies on-demand. struct FoldInsertSliceIntoTransferWrite : public OpRewritePattern { public: @@ -4417,8 +4425,13 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + // clang-format on } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/AffineMap.h" #include "AffineMapDetail.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" @@ -15,8 +16,10 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" +#include #include #include #include @@ -467,6 +470,15 @@ return AffineMap::inferFromExprList(newResults).front(); } +AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const { + auto exprs = llvm::to_vector<4>(getResults()); + // TODO: this is a pretty 
terrible API .. is there anything better? + for (auto pos = positions.find_last(); pos != -1; + pos = positions.find_prev(pos)) + exprs.erase(exprs.begin() + pos); + return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); +} + AffineMap AffineMap::compose(AffineMap map) const { assert(getNumDims() == map.getNumResults() && "Number of results mismatch"); // Prepare `map` by concatenating the symbols and rewriting its exprs. @@ -808,6 +820,14 @@ return numSymbolsBitVector; } +AffineMap +mlir::expandDimsToRank(AffineMap map, int64_t rank, + const llvm::SmallBitVector &projectedDimensions) { + auto id = AffineMap::getMultiDimIdentityMap(rank, map.getContext()); + AffineMap proj = id.dropResults(projectedDimensions); + return map.compose(proj); +} + //===----------------------------------------------------------------------===// // MutableAffineMap. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -6,7 +6,7 @@ return %1 : f32 } // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)> // CHECK: func @fold_static_stride_subview_with_load // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index @@ -25,7 +25,7 @@ %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[?, ?], offset: ?>> return %1 : f32 } -// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)> +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)> // CHECK: func @fold_dynamic_stride_subview_with_load // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32> // CHECK-SAME: 
%[[ARG1:[a-zA-Z0-9_]+]]: index @@ -34,8 +34,8 @@ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index -// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]] -// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]] +// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[ARG5]]] +// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[ARG6]]] // CHECK: memref.load %[[ARG0]][%[[I1]], %[[I2]]] // ----- @@ -66,7 +66,7 @@ memref.store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, strided<[?, ?], offset: ?>> return } -// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)> +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)> // CHECK: func @fold_dynamic_stride_subview_with_store // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index @@ -75,8 +75,8 @@ // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index -// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]] -// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]] +// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[ARG5]]] +// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[ARG6]]] // CHECK: memref.store %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]] // ----- @@ -85,7 +85,7 @@ %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> vector { %f1 = arith.constant 1.0 : f32 - %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref> + %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref> %1 = vector.transfer_read %0[], %f1 : memref>, vector return %1 : vector } @@ -100,22 +100,14 @@ func.func 
@fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> vector<4xf32> { %f1 = arith.constant 1.0 : f32 + %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {in_bounds = [true]} : memref<4x4xf32, strided<[?, ?], offset: ?>>, vector<4xf32> return %1 : vector<4xf32> } -// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)> // CHECK: func @fold_subview_with_transfer_read -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index -// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]] -// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]] -// CHECK: vector.transfer_read %[[ARG0]][%[[I1]], %[[I2]]] +// Can't fold this atm since we don't emit the proper vector.extract_strided_slice. 
+// CHECK: memref.subview // ----- @@ -123,7 +115,7 @@ %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %v : vector) { %f1 = arith.constant 1.0 : f32 - %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref> + %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref> vector.transfer_write %v, %0[] {in_bounds = []} : vector, memref> return } @@ -143,18 +135,9 @@ vector.transfer_write %arg7, %0[%arg3, %arg4] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>> return } -// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)> // CHECK: func @fold_static_stride_subview_with_transfer_write -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32> -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index -// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]] -// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]] -// CHECK: vector.transfer_write %{{.+}}, %[[ARG0]][%[[I1]], %[[I2]]] +// Can't fold this atm since we don't emit the proper vector.extract_strided_slice. 
+// CHECK: memref.subview // ----- @@ -168,7 +151,7 @@ %1 = memref.load %0[%arg13, %arg14, %arg15, %arg16] : memref<4x1x4x1xf32, strided<[?, ?, ?, ?], offset: ?>> return %1 : f32 } -// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)> +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)> // CHECK: func @fold_rank_reducing_subview_with_load // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index @@ -187,10 +170,10 @@ // CHECK-SAME: %[[ARG14:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG15:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG16:[a-zA-Z0-9_]+]]: index -// CHECK-DAG: %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG7]], %[[ARG1]], %[[ARG13]]] -// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG9]], %[[ARG3]], %[[ARG14]]] -// CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG10]], %[[ARG4]], %[[ARG15]]] -// CHECK-DAG: %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG11]], %[[ARG5]], %[[ARG16]]] +// CHECK-DAG: %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG13]], %[[ARG7]]] +// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG14]], %[[ARG9]]] +// CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG15]], %[[ARG10]]] +// CHECK-DAG: %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG16]], %[[ARG11]]] // CHECK: memref.load %[[ARG0]][%[[I0]], %[[ARG2]], %[[I2]], %[[I3]], %[[I4]], %[[ARG6]]] // ----- diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir @@ -0,0 +1,262 @@ +// RUN: mlir-opt -fold-tensor-subset-ops -split-input-file %s | FileCheck %s + +func.func @fold_vector_transfer_read_with_rank_reduced_extract_slice( + %arg0 : tensor, + %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index, + %arg6 : index) -> vector<4xf32> { + %cst = arith.constant 0.0 : f32 + %0 = 
tensor.extract_slice %arg0[0, %arg1, %arg2] [1, %arg3, %arg4] [1, 1, 1] + : tensor to + tensor + %1 = vector.transfer_read %0[%arg5, %arg6], %cst {in_bounds = [true]} + : tensor, vector<4xf32> + return %1 : vector<4xf32> +} +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK: func @fold_vector_transfer_read_with_rank_reduced_extract_slice +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[$MAP1]]()[%[[ARG1]], %[[ARG5]]] +// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[$MAP1]]()[%[[ARG2]], %[[ARG6]]] +// CHECK: vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]], %{{.*}} : tensor, + %i1: index, %i2: index, %i3: index, %i4: index) -> vector<4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %f0 = arith.constant 0.000000e+00 : f32 + + // Can't fold this atm since we don't emit the proper vector.extract_strided_slice. 
+// CHECK: tensor.extract_slice + %0 = tensor.extract_slice %src[0, %i1, %i2, %i3] [1, 4, 1, 4] [2, 3, 4, 5] : tensor<1x8x8x8xf32> to tensor<1x4x4xf32> + %1 = vector.transfer_read %0[%c1, %i4, %c2], %f0 {in_bounds = [true]} : tensor<1x4x4xf32>, vector<4xf32> + return %1 : vector<4xf32> +} + +// ----- + +// CHECK-DAG: #[[$ADD_4:.+]] = affine_map<()[s0] -> (s0 + 4)> + +// CHECK-LABEL: func @transfer_read_of_extract_slice( +// CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index +// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index +// CHECK: %[[add:.*]] = affine.apply #[[$ADD_4]]()[%[[s1]]] +// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor, vector<5x6xf32> +// CHECK: return %[[r]] +func.func @transfer_read_of_extract_slice(%t : tensor, %s1 : index, %s2 : index) -> vector<5x6xf32> { + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %cst = arith.constant 0.0 : f32 + %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor to tensor<10x?xf32> + %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32> + return %1 : vector<5x6xf32> +} +// ----- + +func.func @fold_extract_slice_with_transfer_read_0d( + %arg0 : tensor<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) + -> vector { + %f1 = arith.constant 1.0 : f32 + %0 = tensor.extract_slice %arg0[%arg1, %arg2][1, 1][1, 1] : tensor<12x32xf32> to tensor + %1 = vector.transfer_read %0[], %f1 : tensor, vector + return %1 : vector +} +// CHECK: func @fold_extract_slice_with_transfer_read_0d +// CHECK-SAME: %[[T:[a-zA-Z0-9_]+]]: tensor<12x32xf32> +// CHECK-SAME: %[[SZ0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[SZ1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ST1:[a-zA-Z0-9_]+]]: index +// CHECK: vector.transfer_read %[[T]][%[[SZ0]], %[[SZ1]]] + +// ----- + +// CHECK-DAG: #[[$ADD_4:.+]] = affine_map<()[s0] -> (s0 + 4)> + +// CHECK-LABEL: func @transfer_read_of_extract_slice( 
+// CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index +// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index +// CHECK: %[[add:.*]] = affine.apply #[[$ADD_4]]()[%[[s1]]] +// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor, vector<6xf32> +// CHECK: return %[[r]] +func.func @transfer_read_of_extract_slice(%t : tensor, %s1 : index, %s2 : index) -> vector<6xf32> { + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %cst = arith.constant 0.0 : f32 + %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor to tensor<10x?xf32> + %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32> + return %1 : vector<6xf32> +} + +// ----- + +// CHECK-DAG: #[[$ADD_3:.+]] = affine_map<()[s0] -> (s0 + 3)> + +// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing( +// CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index +// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index +// CHECK: %[[add:.*]] = affine.apply #[[$ADD_3]]()[%[[s1]]] +// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor, vector<5x6xf32> +// CHECK: return %[[r]] +func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor, %s1 : index, %s2 : index) -> vector<5x6xf32> { + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %cst = arith.constant 0.0 : f32 + %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor to tensor + %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor, vector<5x6xf32> + return %1 : vector<5x6xf32> +} + +// ----- + +// CHECK-DAG: #[[$ADD_4:.+]] = affine_map<()[s0] -> (s0 + 4)> +// CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> + +// CHECK-LABEL: func @transfer_read_of_extract_slice_swappy_rank_reducing( +// CHECK-SAME: %[[t:.*]]: 
tensor, %[[s1:.*]]: index, %[[s2:.*]]: index +func.func @transfer_read_of_extract_slice_swappy_rank_reducing(%t : tensor, %s1 : index, %s2 : index) -> vector<5x6xf32> { + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %cst = arith.constant 0.0 : f32 + +// CHECK-NOT: extract_slice +// CHECK: %[[c8:.*]] = arith.constant 8 : index +// CHECK: %[[add:.*]] = affine.apply #[[$ADD_4]]()[%[[s2]]] +// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[s1]], %[[add]]] +// CHECK-SAME: permutation_map = #[[$d0d2]] +// CHECK-SAME: tensor, vector<5x6xf32> + %0 = tensor.extract_slice %t[5, %s1, %s2] [%s2, 1, 12] [1, 1, 1] : tensor to tensor + %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor, vector<5x6xf32> + + return %1 : vector<5x6xf32> +} + +// ----- + +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> + +// CHECK: func @fold_vector_transfer_write_with_rank_reduced_insert_slice +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index +func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice( + %arg0 : tensor, + %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index, + %arg5: index, %arg6 : index, %arg7 : index, + %st : tensor) -> tensor { + %cst = arith.constant 0.0 : f32 + +// CHECK-NOT: insert_slice +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]] +// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]] +// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]] {in_bounds = [true]} : vector<4xf32>, tensor, tensor + %1 = 
tensor.insert_slice %0 into %arg0[0, %arg2, %arg3] [1, %arg4, %arg5] [1, 1, 1] + : tensor into tensor + return %1 : tensor +} + +// ----- + +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK: func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index +func.func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice( + %arg0 : tensor, + %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index, + %arg5: index, %arg6 : index, %arg7 : index, + %st : tensor) -> tensor { + %cst = arith.constant 0.0 : f32 + + // CHECK-NOT: insert_slice + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]] + // CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]] + // CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]], %[[C0]]] + // CHECK-SAME: {in_bounds = [true], permutation_map = #[[MAP2]]} : vector<4xf32>, tensor, tensor + %1 = tensor.insert_slice %0 into %arg0[%arg2, %arg3, 0] [%arg4, %arg5, 1] [1, 1, 1] + : tensor into tensor + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: func @insert_slice_of_transfer_write( +// CHECK-SAME: %[[t1:.*]]: tensor, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index +func.func @insert_slice_of_transfer_write(%t1 : tensor, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor { + %c0 = arith.constant 0 : index + + // CHECK-NOT: insert_slice +// CHECK: %[[c3:.*]] = arith.constant 3 : index +// CHECK: %[[r:.*]] = vector.transfer_write 
%[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor +// CHECK: return %[[r]] + %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32> + %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor + return %1 : tensor +} + +// ----- + +// CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> + +// CHECK-LABEL: func @insert_slice_of_transfer_write_swappy_rank_extending( +// CHECK-SAME: %[[t1:.*]]: tensor, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index +func.func @insert_slice_of_transfer_write_swappy_rank_extending( + %t1 : tensor, %v : vector<5x6xf32>, + %s : index, %t2 : tensor<5x6xf32>) -> tensor { + %c0 = arith.constant 0 : index + +// CHECK-NOT: insert_slice +// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index +// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] +// CHECK-SAME: {in_bounds = [true, true], permutation_map = #[[$d0d2]]} : vector<5x6xf32>, tensor +// CHECK: return %[[r]] + %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32> + %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending( +// CHECK-SAME: %[[t1:.*]]: tensor, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index +// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index +// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor +// CHECK: return %[[r]] +func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor { + %c0 = arith.constant 0 : index + %0 = vector.transfer_write %v, %t2[%c0, %c0] 
{in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32> + %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor + return %1 : tensor +}