diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -180,9 +180,17 @@
   if (failed(paddingValue))
     return failure(hasDynamicShape);

+  // Follow the use-def chain while the operand is produced by a LinalgOp,
+  // resolving each result to the output operand it is tied to.
+  OpOperand *operand = opOperand;
+  while (auto linalgOp = operand->get().getDefiningOp<LinalgOp>()) {
+    OpResult result = operand->get().cast<OpResult>();
+    operand = linalgOp.getOutputOperand(result.getResultNumber());
+  }
+
   // Cannot construct a static bounding box if the operand is not defined by an
   // ExtractSliceOp.
-  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  auto sliceOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
   if (!sliceOp)
     return failure(hasDynamicShape);
diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir
--- a/mlir/test/Dialect/Linalg/pad.mlir
+++ b/mlir/test/Dialect/Linalg/pad.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,1,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=MATMUL
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.fill pad pack-paddings=1,1,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=FILL
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.fill pad pack-paddings=2,1,0 run-enable-pass=false pad-anchor-op-only" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=FILL
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic pad pack-paddings=2,1,0 run-enable-pass=false pad-anchor-op-only" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=GENERIC
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,1,0 pad-inputs-only run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=INPUTS-ONLY

 // MATMUL-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<()[s0] -> (7, -s0 + 12)>
@@ -171,11 +172,13 @@
   %size = affine.min #map0()[%iv0]
   %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor<?x?xf32>

-  // Check both fill operations are padded by the same pad tensor operation.
+  // Check both fill operations are padded by the same source tensor operation.
   // FILL:      %[[T0:.*]] = tensor.pad
   // FILL:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
-  // FILL:      %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
-  // FILL:      = tensor.extract_slice %[[T2]]
+  // FILL:      %[[T2:.*]] = tensor.extract_slice %[[T1]]
+  // FILL:      %[[T3:.*]] = tensor.pad %[[T2]]
+  // FILL:      %[[T4:.*]] = linalg.fill(%{{.*}}, %[[T3]])
+  // FILL:      = tensor.extract_slice %[[T4]]
   %1 = linalg.fill(%cst, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
   %2 = linalg.fill(%cst, %1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
   return %2 : tensor<?x?xf32>
@@ -472,3 +475,54 @@
   %1 = linalg.fill(%cst, %0) : f32, tensor<1x?x?xf32> -> tensor<1x?x?xf32>
   return %1 : tensor<1x?x?xf32>
 }
+
+// -----
+
+// GENERIC:      func @matmul_bias_add(
+// GENERIC-SAME:   %[[ARG0:[0-9a-zA-Z]*]]: tensor<25x49xf32>,
+// GENERIC-SAME:   %[[ARG1:[0-9a-zA-Z]*]]: tensor<49x33xf32>,
+// GENERIC-SAME:   %[[ARG2:[0-9a-zA-Z]*]]: tensor<33xf32>,
+// GENERIC-SAME:   %[[ARG3:[0-9a-zA-Z]*]]: tensor<25x33xf32>,
+// GENERIC-SAME:   %[[ARG4:[0-9a-zA-Z]*]]: tensor<25x33xf32>)
+func @matmul_bias_add(%arg0: tensor<25x49xf32>,
+                      %arg1: tensor<49x33xf32>,
+                      %arg2: tensor<33xf32>,
+                      %arg3: tensor<25x33xf32>,
+                      %arg4: tensor<25x33xf32>) -> tensor<25x33xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c33 = arith.constant 33 : index
+  %c25 = arith.constant 25 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %0 = scf.for %arg5 = %c0 to %c25 step %c8 iter_args(%arg6 = %arg4) -> (tensor<25x33xf32>) {
+    %1 = affine.min affine_map<(d0) -> (-d0 + 25, 8)>(%arg5)
+    %2 = tensor.extract_slice %arg0[%arg5, 0] [%1, 49] [1, 1] : tensor<25x49xf32> to tensor<?x49xf32>
+    %3 = affine.min affine_map<(d0) -> (8, -d0 + 25)>(%arg5)
+    %4 = scf.for %arg7 = %c0 to %c33 step %c16 iter_args(%arg8 = %arg6) -> (tensor<25x33xf32>) {
+      %5 = affine.min affine_map<(d0) -> (-d0 + 33, 16)>(%arg7)
+      %6 = tensor.extract_slice %arg1[0, %arg7] [49, %5] [1, 1] : tensor<49x33xf32> to tensor<49x?xf32>
+      %7 = tensor.extract_slice %arg3[%arg5, %arg7] [%1, %5] [1, 1] : tensor<25x33xf32> to tensor<?x?xf32>
+      %8 = linalg.fill(%cst, %7) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+      %9 = linalg.matmul ins(%2, %6 : tensor<?x49xf32>, tensor<49x?xf32>) outs(%8 : tensor<?x?xf32>) -> tensor<?x?xf32>
+      %10 = affine.min affine_map<(d0) -> (16, -d0 + 33)>(%arg7)
+      %11 = tensor.extract_slice %arg2[%arg7] [%10] [1] : tensor<33xf32> to tensor<?xf32>
+      %12 = tensor.extract_slice %arg8[%arg5, %arg7] [%3, %10] [1, 1] : tensor<25x33xf32> to tensor<?x?xf32>
+
+      // GENERIC:      %[[T0:.+]] = linalg.matmul
+      // GENERIC:      %[[PAD:.+]] = tensor.pad %[[T0]]
+      // GENERIC:      %{{.+}} = linalg.generic
+      // GENERIC-SAME:   ins(%[[PAD]]
+
+      %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %11 : tensor<?x?xf32>, tensor<?xf32>) outs(%12 : tensor<?x?xf32>) {
+      ^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
+        %15 = arith.addf %arg9, %arg10 : f32
+        linalg.yield %15 : f32
+      } -> tensor<?x?xf32>
+      %14 = tensor.insert_slice %13 into %arg8[%arg5, %arg7] [%3, %10] [1, 1] : tensor<?x?xf32> into tensor<25x33xf32>
+      scf.yield %14 : tensor<25x33xf32>
+    }
+    scf.yield %4 : tensor<25x33xf32>
+  }
+  return %0 : tensor<25x33xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -101,6 +101,10 @@
       *this, "pad-inputs-only",
       llvm::cl::desc("Only pad input operands when test-pad-pattern"),
       llvm::cl::init(false)};
+  Option<bool> padAnchorOpOnly{
+      *this, "pad-anchor-op-only",
+      llvm::cl::desc("Only pad anchor op operands when test-pad-pattern"),
+      llvm::cl::init(false)};
   ListOption<int64_t> packPaddings{
       *this, "pack-paddings",
       llvm::cl::desc("Operand packing flags when test-pad-pattern."),
@@ -200,7 +204,8 @@
             LinalgPromotionOptions()
                 .setAlignment(16)
                 .setUseFullTileBuffersByDefault(registerPromoteFullTile))
-        .padIf(pad, "", std::move(paddingOptions))
+        .padIf(pad, padAnchorOpOnly ? anchorOpName : std::string(),
+               std::move(paddingOptions))
        .decomposeIf(decompose)
        .generalizeIf(generalize, "")
        .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
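
Note on the Transforms.cpp hunk, for readers without the MLIR sources at hand: it relies on the destination-passing-style property that result i of a structured linalg op is tied to output operand i, so a static bounding box for padding can still be recovered from the tensor.extract_slice at the head of the chain even when fill or matmul ops sit between the slice and the operand being padded. Below is a minimal, dependency-free C++ sketch of that resolution loop; ToyOp, Value, and resolveThroughToyOps are hypothetical stand-ins modeling the result/output-operand tie, not MLIR APIs.

    #include <cassert>
    #include <cstddef>
    #include <vector>

    struct ToyOp;

    // A value either comes from a ToyOp result or from "outside",
    // e.g. a slice of a statically shaped tensor.
    struct Value {
      ToyOp *definingOp = nullptr; // null when not produced by a ToyOp
      std::size_t resultNumber = 0; // which result of definingOp this is
    };

    // Hypothetical stand-in for a destination-style op: result i is tied
    // to outputs[i], mirroring how a LinalgOp ties results to its outs().
    struct ToyOp {
      std::vector<Value *> outputs;
    };

    // Mirrors the while-loop added in Transforms.cpp: follow each result
    // back to its tied output operand until the value is no longer
    // defined by a ToyOp.
    Value *resolveThroughToyOps(Value *v) {
      while (ToyOp *op = v->definingOp)
        v = op->outputs[v->resultNumber];
      return v;
    }

    int main() {
      Value slice;                  // stands in for a tensor.extract_slice
      ToyOp fill{{&slice}};         // fill writes into the slice
      Value fillResult{&fill, 0};
      ToyOp matmul{{&fillResult}};  // matmul accumulates into fill's result
      Value matmulResult{&matmul, 0};
      // The generic op's input resolves through matmul and fill to the
      // slice, which is where the padded bounding box comes from.
      assert(resolveThroughToyOps(&matmulResult) == &slice);
    }

The walk terminates at the first producer that is not a linalg op; only when that producer is an ExtractSliceOp can the pattern construct the static bounding box, which is exactly what the matmul_bias_add test exercises through the fill -> matmul -> generic chain.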