diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -40,15 +40,16 @@
 /// Analysis class to support PadTensorOp hoisting across multiple enclosing
 /// loops. The failure conditions are:
 ///   1. Pad op has a use that is not an input of a LinalgOp.
-///   2. There is no immediately enclosing scf::ForOp.
-///   3. The backward slice from the pad op to the scf::ForOp to hoist above
+///   2. Pad op does not have a constant padding value.
+///   3. There is no immediately enclosing scf::ForOp.
+///   4. The backward slice from the pad op to the scf::ForOp to hoist above
 ///      contains an unknown op with a region.
-///   4. The backward slice from the pad op to the scf::ForOp to hoist above is
+///   5. The backward slice from the pad op to the scf::ForOp to hoist above is
 ///      empty.
-///   5. The source tensor of pad op is not defined by an extract slice op.
-///   6. The source tensor of the extract slice op is not defined outside of
+///   6. The source tensor of pad op is not defined by an extract slice op.
+///   7. The source tensor of the extract slice op is not defined outside of
 ///      the outermost enclosing scf::ForOp.
-///   7. There is no enclosing scf::ForOp that indexes the padded data.
+///   8. There is no enclosing scf::ForOp that indexes the padded data.
 /// Other cases succeed and will trigger hoisting of the pad op.
 struct HoistingAnalysis {
   HoistingAnalysis(PadTensorOp padTensorOp, int numLoops);
@@ -183,6 +184,16 @@
     return;
   }
 
+  // Check that the region of `padTensorOp` depends on a constant only.
+  // Adding hoisting support for arbitrary padding regions would require
+  // cloning all dependencies captured by the padding region.
+  Value paddingValue = padTensorOp.getConstantPaddingValue();
+  if (!paddingValue ||
+      !isa_and_nonnull<arith::ConstantOp>(paddingValue.getDefiningOp())) {
+    LLVM_DEBUG(DBGS() << "Cannot find constant padding value -> skip\n");
+    return;
+  }
+
   // Get all the ops in the backwards slice starting from `padTensorOp` and that
   // are dominated by the outermost enclosing loop.
   DominanceInfo domInfo(outermostEnclosingForOp);
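
The new condition rejects a pad op in two ways: `getConstantPaddingValue` returns null when the region computes its value internally rather than yielding a single captured value, and the `isa_and_nonnull` check fails when the captured value is defined by something other than an `arith.constant`. A minimal sketch of the accepted and rejected shapes (hypothetical IR, mirroring the test below):

  // Hoistable: the region yields a value defined by arith.constant.
  %cst = arith.constant 0.000000e+00 : f32
  %p0 = linalg.pad_tensor %src low[%c0, %c0] high[%h0, %h1] {
  ^bb0(%i: index, %j: index):
    linalg.yield %cst : f32
  } : tensor<?x?xf32> to tensor<5x7xf32>

  // Rejected: the yielded value is computed from an enclosing loop
  // induction variable, so hoisting would have to clone this dependency.
  %iv32 = arith.index_cast %arg7 : index to i32
  %val = arith.sitofp %iv32 : i32 to f32
  %p1 = linalg.pad_tensor %src nofold low[%c0, %c0] high[%h0, %h1] {
  ^bb0(%i: index, %j: index):
    linalg.yield %val : f32
  } : tensor<?x?xf32> to tensor<5x7xf32>
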
diff --git a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
--- a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
+++ b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
@@ -358,3 +358,81 @@
   }
   return %0 : tensor<24x25xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0) -> (5, -d0 + 24)>
+#map1 = affine_map<(d0) -> (7, -d0 + 25)>
+#map2 = affine_map<(d0) -> (-d0 + 5)>
+#map3 = affine_map<(d0) -> (-d0 + 7)>
+
+//        CHECK: non_constant_padding
+// CHECK-DOUBLE: non_constant_padding
+//   CHECK-SAME:   %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+func @non_constant_padding(%arg0: tensor<24x12xf32>,
+                           %arg1: tensor<12x25xf32>,
+                           %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c25 = arith.constant 25 : index
+  %c24 = arith.constant 24 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %c5 = arith.constant 5 : index
+  %cst = arith.constant 0.000000e+00 : f32
+
+  //      CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %0 = scf.for %arg3 = %c0 to %c24 step %c5 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+    // CHECK-NEXT: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %1 = scf.for %arg5 = %c0 to %c25 step %c7 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+
+      // CHECK-NEXT: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+      %2 = scf.for %arg7 = %c0 to %c12 step %c6 iter_args(%arg8 = %arg6) -> (tensor<24x25xf32>) {
+        %3 = affine.min #map0(%arg3)
+        %4 = tensor.extract_slice %arg0[%arg3, %arg7] [%3, 6] [1, 1] : tensor<24x12xf32> to tensor<?x6xf32>
+        %5 = affine.min #map1(%arg5)
+        %6 = tensor.extract_slice %arg1[%arg7, %arg5] [6, %5] [1, 1] : tensor<12x25xf32> to tensor<6x?xf32>
+        %7 = tensor.extract_slice %arg8[%arg3, %arg5] [%3, %5] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
+        %8 = affine.apply #map2(%3)
+
+        // Check a pad with a non-constant padding value is not hoisted.
+        //   CHECK: %[[T0:.*]] = linalg.pad_tensor
+        //   CHECK:   %[[V0:.*]] = arith.index_cast
+        //   CHECK:   %[[V1:.*]] = arith.sitofp %[[V0]]
+        //   CHECK:   linalg.yield %[[V1]]
+        %9 = linalg.pad_tensor %4 nofold low[%c0, %c0] high[%8, %c0] {
+        ^bb0(%arg9: index, %arg10: index):  // no predecessors
+          %17 = arith.index_cast %arg7 : index to i32
+          %18 = arith.sitofp %17 : i32 to f32
+          linalg.yield %18 : f32
+        } : tensor<?x6xf32> to tensor<5x6xf32>
+        %10 = affine.apply #map3(%5)
+
+        // Check a pad whose padding value is defined by a non-constant op is not hoisted.
+        //   CHECK: %[[V2:.*]] = tensor.extract %[[ARG1]][%[[IV2]], %[[IV1]]
+        //   CHECK: %[[T1:.*]] = linalg.pad_tensor
+        //   CHECK:   linalg.yield %[[V2]]
+        %11 = tensor.extract %arg1[%arg7, %arg5] : tensor<12x25xf32>
+        %12 = linalg.pad_tensor %6 nofold low[%c0, %c0] high[%c0, %10] {
+        ^bb0(%arg9: index, %arg10: index):  // no predecessors
+          linalg.yield %11 : f32
+        } : tensor<6x?xf32> to tensor<6x7xf32>
+        %13 = linalg.pad_tensor %7 low[%c0, %c0] high[%8, %10] {
+        ^bb0(%arg9: index, %arg10: index):  // no predecessors
+          linalg.yield %cst : f32
+        } : tensor<?x?xf32> to tensor<5x7xf32>
+
+        //      CHECK: = linalg.matmul ins(%[[T0]], %[[T1]]
+        %14 = linalg.matmul ins(%9, %12 : tensor<5x6xf32>, tensor<6x7xf32>) outs(%13 : tensor<5x7xf32>) -> tensor<5x7xf32>
+        %15 = tensor.extract_slice %14[0, 0] [%3, %5] [1, 1] : tensor<5x7xf32> to tensor<?x?xf32>
+        %16 = tensor.insert_slice %15 into %arg8[%arg3, %arg5] [%3, %5] [1, 1] : tensor<?x?xf32> into tensor<24x25xf32>
+        scf.yield %16 : tensor<24x25xf32>
+      }
+      scf.yield %2 : tensor<24x25xf32>
+    }
+    scf.yield %1 : tensor<24x25xf32>
+  }
+  return %0 : tensor<24x25xf32>
+}
+
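
For contrast with this negative test, the constant-padding cases earlier in the test file do hoist: when the analysis succeeds, the transformation builds a packing loop above the hoisted loop level that materializes all padded tiles up front. A rough sketch of that output shape, with hypothetical names and sizes (not part of this patch):

  // Packing loop inserted above the hoisted scf.for: pad each tile once
  // and stack the results along a new leading dimension.
  %init = linalg.init_tensor [2, 5, 6] : tensor<2x5x6xf32>
  %packed = scf.for %iv = %c0 to %c12 step %c6 iter_args(%p = %init) -> (tensor<2x5x6xf32>) {
    %tile = tensor.extract_slice %arg0[%arg3, %iv] [%sz, 6] [1, 1] : tensor<24x12xf32> to tensor<?x6xf32>
    %pad = linalg.pad_tensor %tile low[%c0, %c0] high[%h, %c0] {
    ^bb0(%i: index, %j: index):
      linalg.yield %cst : f32
    } : tensor<?x6xf32> to tensor<5x6xf32>
    %idx = affine.apply affine_map<(d0) -> (d0 floordiv 6)>(%iv)
    %ins = tensor.insert_slice %pad into %p[%idx, 0, 0] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<2x5x6xf32>
    scf.yield %ins : tensor<2x5x6xf32>
  }
  // The original innermost loop then reads each precomputed tile back via
  // tensor.extract_slice from %packed instead of re-padding per iteration.

Cloning a non-constant padding region into such a packing loop would require replicating everything the region captures (here, values depending on the loop induction variables), which is exactly what the new analysis condition rules out.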