diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -181,9 +181,16 @@
   if (failed(paddingValue))
     return failure(hasDynamicShape);
 
-  // Cannot construct a static bounding box if the operand is not defined by an
-  // ExtractSliceOp.
-  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
+  OpOperand *currOpOperand = opOperand;
+  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
+    OpResult result = currOpOperand->get().cast<OpResult>();
+    currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
+  }
+
+  // Cannot construct a static bounding box if the `currOpOperand` is not
+  // defined by an ExtractSliceOp.
+  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
   if (!sliceOp)
     return failure(hasDynamicShape);
 
diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir
--- a/mlir/test/Dialect/Linalg/pad.mlir
+++ b/mlir/test/Dialect/Linalg/pad.mlir
@@ -1,6 +1,7 @@
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,1,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=MATMUL
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.fill pad pack-paddings=1,1,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=FILL
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,1,0 pad-inputs-only run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=INPUTS-ONLY
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,1,0 run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=MATMUL
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.fill pad pack-paddings=1,1 run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=FILL
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.fill pad pack-paddings=1,0 run-enable-pass=false" -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,0 run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=FILL-MATMUL
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,1,0 pad-inputs-only run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=INPUTS-ONLY
 
 // MATMUL-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<()[s0] -> (7, -s0 + 12)>
 // MATMUL-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<()[s0] -> (-s0 + 7)>
@@ -163,21 +164,45 @@
 #map0 = affine_map<()[s0] -> (64, s0)>
 
-// FILL: pad_multiple
-// FILL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<64x64xf32>
+// FILL-MATMUL: pad_multiple
+// FILL-MATMUL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<64x64xf32>
 func @pad_multiple(%arg0: tensor<64x64xf32>,
-                 %iv0 : index) -> tensor<?x?xf32> {
+                   %iv0 : index) -> tensor<?x?xf32> {
   %cst = arith.constant 0.0 : f32
   %size = affine.min #map0()[%iv0]
+
+  // FILL-MATMUL: %[[T0:.*]] = tensor.extract_slice
   %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
 
-  // Check both fill operations are padded by the same pad tensor operation.
-  // FILL: %[[T0:.*]] = tensor.pad
-  // FILL: %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]]
-  // FILL: %[[T2:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T1]]
-  // FILL: = tensor.extract_slice %[[T2]]
+  // Check the two operations are padded by the same pad tensor operation.
+  // FILL-MATMUL: %[[T1:.*]] = tensor.pad %[[T0]]
+  // FILL-MATMUL: %[[T2:.*]] = linalg.fill {{.*}} outs(%[[T1]]
+  // FILL-MATMUL: %[[T3:.*]] = linalg.matmul {{.*}} outs(%[[T2]]
+  // FILL-MATMUL: = tensor.extract_slice %[[T3]]
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %2 = linalg.matmul ins(%0, %0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+// -----
+
+#map0 = affine_map<()[s0] -> (64, s0)>
+
+// MATMUL: pad_chain
+// MATMUL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<64x64xf32>
+func @pad_chain(%arg0: tensor<64x64xf32>,
+                %iv0 : index) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+  %size = affine.min #map0()[%iv0]
+  %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>
+
+  // Check the matmul at the end of the use-def chain is padded.
+  // MATMUL: %[[T0:.*]] = linalg.fill
+  // MATMUL: %[[T1:.*]] = tensor.pad %[[T0]]
+  // MATMUL: %[[T2:.*]] = linalg.matmul {{.*}} outs(%[[T1]]
+  // MATMUL: = tensor.extract_slice %[[T2]]
   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %2 = linalg.matmul ins(%0, %0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %2 : tensor<?x?xf32>
 }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -183,6 +183,7 @@
     LinalgPaddingOptions paddingOptions,
     vector::VectorContractLowering vectorContractLowering,
     vector::VectorTransferSplit vectorTransferSplit) {
+  std::string anchorOpNameOrWildcard = fuse ? "" : anchorOpName.getValue();
   CodegenStrategy strategy;
   strategy
       .tileAndFuseIf(fuse && !tileSizes.empty(), anchorOpName,
@@ -198,11 +199,11 @@
                  LinalgPromotionOptions()
                      .setAlignment(16)
                      .setUseFullTileBuffersByDefault(registerPromoteFullTile))
-      .padIf(pad, "", std::move(paddingOptions))
+      .padIf(pad, anchorOpNameOrWildcard, std::move(paddingOptions))
       .decomposeIf(decompose)
-      .generalizeIf(generalize, "")
+      .generalizeIf(generalize, anchorOpNameOrWildcard)
       .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
-      .vectorizeIf(vectorize, "", nullptr, vectorizePadding)
+      .vectorizeIf(vectorize, anchorOpNameOrWildcard, nullptr, vectorizePadding)
       .vectorLowering(
           LinalgVectorLoweringOptions()
               .setVectorTransformsOptions(
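Note on the first hunk: the new loop is the heart of the change. Padding previously required the operand to be defined directly by a tensor.extract_slice; the walk now looks through producing LinalgOps by hopping from each result to the output (init) operand it is tied to. Below is a minimal standalone sketch of that walk, factored into a helper. The name traceToExtractSlice is ours, not part of the patch, and it assumes the MLIR C++ API at this revision (Value::getDefiningOp, Value::cast<OpResult>, LinalgOp::getOutputOperand) and the header layout of the tree at that time.

// Sketch only: follow the use-def chain from `opOperand` through producing
// LinalgOps to the tensor.extract_slice (if any) that defines the value.
// Headers assume the MLIR tree at this revision.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::linalg;

static tensor::ExtractSliceOp traceToExtractSlice(OpOperand *opOperand) {
  OpOperand *currOpOperand = opOperand;
  // While the value is a LinalgOp result, continue the search from the output
  // operand tied to that result (destination-passing style guarantees the
  // result and its init operand have matching shapes).
  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
    OpResult result = currOpOperand->get().cast<OpResult>();
    currOpOperand = linalgOp.getOutputOperand(result.getResultNumber());
  }
  // Returns the slice that yields a static bounding box, or null if the
  // chain ends at some other operation.
  return currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
}

In the pad_chain test this walk starts at the matmul's output operand %1, steps through the linalg.fill to its init %0, and lands on the tensor.extract_slice, which is why the matmul can be padded even though its output is not directly a slice.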