Fold SubViewOp into vector.transfer_read / vector.transfer_write during
legalization of the standard dialect for SPIR-V lowering.

LoadOpOfSubViewFolder and StoreOpOfSubViewFolder become templates over the
memory op type. matchAndRewrite resolves the indices against the subview's
source and delegates creation of the replacement op to a per-op-type
replaceOp specialization (LoadOp / vector::TransferReadOp and StoreOp /
vector::TransferWriteOp). All four instantiations are registered in
populateStdLegalizationPatternsForSPIRVLowering, and two FileCheck tests
cover the new transfer cases.

NOTE(review): the transfer specializations rebuild the op from the vector
type/value, the subview source and the resolved indices only. Confirm that
the original op's permutation map and padding value do not need to be
carried over, and that the fold is valid when the subview's innermost
stride is not 1 (the tests use stride 3 with vector<4xf32>).

diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -15,28 +15,41 @@
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 
 using namespace mlir;
 
 namespace {
-/// Merges subview operation with load operation.
-class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> {
+/// Merges subview operation with load/transferRead operation.
+template <typename OpTy>
+class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
 public:
-  using OpRewritePattern<LoadOp>::OpRewritePattern;
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(LoadOp loadOp,
+  LogicalResult matchAndRewrite(OpTy loadOp,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  void replaceOp(OpTy loadOp, SubViewOp subViewOp,
+                 ArrayRef<Value> sourceIndices,
+                 PatternRewriter &rewriter) const;
 };
 
-/// Merges subview operation with store operation.
-class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> {
+/// Merges subview operation with store/transferWrite operation.
+template <typename OpTy>
+class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
 public:
-  using OpRewritePattern<StoreOp>::OpRewritePattern;
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(StoreOp storeOp,
+  LogicalResult matchAndRewrite(OpTy storeOp,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  void replaceOp(OpTy storeOp, SubViewOp subViewOp,
+                 ArrayRef<Value> sourceIndices,
+                 PatternRewriter &rewriter) const;
 };
 
 } // namespace
@@ -85,13 +98,14 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Folding SubViewOp and LoadOp.
+// Folding SubViewOp and LoadOp/TransferReadOp.
 //===----------------------------------------------------------------------===//
 
+template <typename OpTy>
 LogicalResult
-LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
-                                       PatternRewriter &rewriter) const {
-  auto subViewOp = loadOp.memref().getDefiningOp<SubViewOp>();
+LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
+                                             PatternRewriter &rewriter) const {
+  auto subViewOp = loadOp.memref().template getDefiningOp<SubViewOp>();
   if (!subViewOp) {
     return failure();
   }
@@ -100,19 +114,36 @@
                                   loadOp.indices(), sourceIndices)))
     return failure();
 
+  replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
+  return success();
+}
+
+template <>
+void LoadOpOfSubViewFolder<LoadOp>::replaceOp(LoadOp loadOp,
+                                              SubViewOp subViewOp,
+                                              ArrayRef<Value> sourceIndices,
+                                              PatternRewriter &rewriter) const {
   rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
                                       sourceIndices);
-  return success();
+}
+
+template <>
+void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
+    vector::TransferReadOp loadOp, SubViewOp subViewOp,
+    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+      loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices);
 }
 
 //===----------------------------------------------------------------------===//
-// Folding SubViewOp and StoreOp.
+// Folding SubViewOp and StoreOp/TransferWriteOp.
 //===----------------------------------------------------------------------===//
 
+template <typename OpTy>
 LogicalResult
-StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
-                                        PatternRewriter &rewriter) const {
-  auto subViewOp = storeOp.memref().getDefiningOp<SubViewOp>();
+StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
+                                              PatternRewriter &rewriter) const {
+  auto subViewOp = storeOp.memref().template getDefiningOp<SubViewOp>();
   if (!subViewOp) {
     return failure();
   }
@@ -121,9 +152,25 @@
                                   storeOp.indices(), sourceIndices)))
     return failure();
 
+  replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
+  return success();
+}
+
+template <>
+void StoreOpOfSubViewFolder<StoreOp>::replaceOp(
+    StoreOp storeOp, SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
+    PatternRewriter &rewriter) const {
   rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
                                        subViewOp.source(), sourceIndices);
-  return success();
+}
+
+template <>
+void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
+    vector::TransferWriteOp transferWriteOp, SubViewOp subViewOp,
+    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+      transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
+      sourceIndices);
 }
 
 //===----------------------------------------------------------------------===//
@@ -132,7 +179,10 @@
 
 void mlir::populateStdLegalizationPatternsForSPIRVLowering(
     MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<LoadOpOfSubViewFolder, StoreOpOfSubViewFolder>(context);
+  patterns.insert<LoadOpOfSubViewFolder<LoadOp>,
+                  LoadOpOfSubViewFolder<vector::TransferReadOp>,
+                  StoreOpOfSubViewFolder<StoreOp>,
+                  StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
@@ -62,3 +62,37 @@
   store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
   return
 }
+
+// CHECK-LABEL: @fold_static_stride_subview_with_transfer_read
+// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index
+func @fold_static_stride_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> vector<4xf32> {
+  // CHECK-NOT: subview
+  // CHECK: [[C2:%.*]] = constant 2 : index
+  // CHECK: [[C3:%.*]] = constant 3 : index
+  // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
+  // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
+  // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
+  // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
+  // CHECK: vector.transfer_read [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+  %f0 = constant 0.0 : f32
+  %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+  %1 = vector.transfer_read %0[%arg3, %arg4], %f0 : memref<4x4xf32, offset:?, strides: [64, 3]>, vector<4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @fold_static_stride_subview_with_transfer_write
+// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: vector<4xf32>
+func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : vector<4xf32>) {
+  // CHECK-NOT: subview
+  // CHECK: [[C2:%.*]] = constant 2 : index
+  // CHECK: [[C3:%.*]] = constant 3 : index
+  // CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
+  // CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
+  // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
+  // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
+  // CHECK: vector.transfer_write [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+  %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] :
+    memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+  vector.transfer_write %arg5, %0[%arg3, %arg4] : vector<4xf32>, memref<4x4xf32, offset:?, strides: [64, 3]>
+  return
+}