diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -292,11 +292,22 @@
     vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
                               linalg::LinalgCopyVTWForwardingPattern>(
         funcOp.getContext(), /*benefit=*/2);
-    if (vectorizePadding) {
-      linalg::populatePadTensorOpVectorizationPatterns(vectorizationPatterns);
-    }
+    TransferReadOp::getCanonicalizationPatterns(vectorizationPatterns,
+                                                funcOp.getContext());
+    TransferWriteOp::getCanonicalizationPatterns(vectorizationPatterns,
+                                                 funcOp.getContext());
     (void)applyPatternsAndFoldGreedily(funcOp,
                                        std::move(vectorizationPatterns));
+
+    // Apply the pad tensor op vectorization separately to avoid running the
+    // GenericPadTensorOpVectorizationPattern too early.
+    // TODO: Improve once we have better infrastructure to control pattern
+    // application.
+    if (vectorizePadding) {
+      RewritePatternSet patterns(funcOp.getContext());
+      linalg::populatePadTensorOpVectorizationPatterns(patterns);
+      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+    }
   }

   LinalgVectorizationOptions options;
diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -3,6 +3,7 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 tile-interchange=1,2,0 generalize iterator-interchange=0,2,1" -split-input-file | FileCheck %s --check-prefix=CHECK-INTERCHANGE
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 pad pack-paddings=1,1,0 hoist-paddings=3,3,0" -split-input-file | FileCheck %s --check-prefix=CHECK-PAD
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 fuse pad vectorize" -split-input-file | FileCheck %s --check-prefix=CHECK-FUSE
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=conv anchor-op=linalg.conv_2d_nhwc_hwcf tile-sizes=1,1,8,32,1,1,8 fuse pad decompose vectorize vectorize-padding" -split-input-file | FileCheck %s --check-prefix=CHECK-DECOMP

 // CHECK-INTRINSIC: func @matmul(
 // CHECK-OUTER: func @matmul(
@@ -74,3 +75,18 @@
   %1 = linalg.matmul ins(%arg0, %arg1: tensor<72x72xf32>, tensor<72x72xf32>) outs(%0: tensor<72x72xf32>) -> tensor<72x72xf32>
   return %1 : tensor<72x72xf32>
 }
+
+// -----
+
+// CHECK-DECOMP: func @conv(
+func @conv(%arg0: tensor<8x18x17x32xf32>, %arg1: tensor<3x3x32x64xf32>, %arg2: tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill(%cst, %arg2) : f32, tensor<8x16x15x64xf32> -> tensor<8x16x15x64xf32>
+
+  // Check the padded conv is vectorized using a rank-reducing vector transfer op pair.
+  //     CHECK-DECOMP: vector.transfer_read {{.*}}: tensor<1x1x?x8xf32>, vector<1x8x8xf32>
+  //     CHECK-DECOMP: vector.outerproduct
+  //     CHECK-DECOMP: vector.transfer_write {{.*}}: vector<1x8x32xf32>, tensor<1x1x?x32xf32>
+  %1 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<8x18x17x32xf32>, tensor<3x3x32x64xf32>) outs(%0 : tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32>
+  return %1 : tensor<8x16x15x64xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -120,6 +120,10 @@
       *this, "vectorize",
       llvm::cl::desc("Rewrite the linalg op as a vector operation."),
       llvm::cl::init(false)};
+  Option<bool> vectorizePadding{
+      *this, "vectorize-padding",
+      llvm::cl::desc("Rewrite pad tensor ops as vector operations."),
+      llvm::cl::init(false)};
   Option<std::string> splitVectorTransfersTo{
       *this, "split-transfers",
       llvm::cl::desc(
@@ -186,7 +190,7 @@
         .decomposeIf(decompose)
         .generalizeIf(generalize, "")
         .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
-        .vectorizeIf(vectorize, "")
+        .vectorizeIf(vectorize, "", nullptr, vectorizePadding)
         .vectorLowering(
             LinalgVectorLoweringOptions()
                 .setVectorTransformsOptions(
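
Note (not part of the patch): for downstream users of the CodegenStrategy API, the sketch below shows how the new vectorize-padding knob reaches the strategy through the fourth argument of vectorizeIf, mirroring the test pass above. Only the vectorizeIf(vectorize, "", nullptr, vectorizePadding) call shape comes from this change; the helper name applyConvStrategy, the /*filter=*/ interpretation of the nullptr argument, and the transform() entry point are assumptions for illustration.

// Minimal usage sketch, assuming MLIR at this revision and a FuncOp anchor.
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"

using namespace mlir;
using namespace mlir::linalg;

// Hypothetical helper, not part of the patch.
static void applyConvStrategy(FuncOp funcOp) {
  CodegenStrategy strategy;
  strategy
      // Decompose the 2-D convolution into 1-D convolutions before
      // vectorization.
      .decomposeIf(/*decompose=*/true)
      // Vectorize the anchored op; the trailing flag added by this patch also
      // rewrites the pad tensor ops into vector transfer op pairs.
      .vectorizeIf(/*vectorize=*/true, /*opName=*/"", /*filter=*/nullptr,
                   /*vectorizePadding=*/true);
  // Assumed entry point that runs the configured pass pipeline on funcOp.
  strategy.transform(funcOp);
}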