diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -848,7 +848,8 @@
          DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
          DefaultValuedAttr<
           TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
-          "{}">:$transpose_paddings);
+          "{}">:$transpose_paddings,
+         DefaultValuedAttr<BoolAttr, "true">:$copy_back);
   let results = (outs TransformHandleTypeInterface:$transformed);
 
   let assemblyFormat =
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -370,10 +370,12 @@
 /// and `packPaddings` to set padding value and nofold attribute of the created
 /// tensor::PadOps, respectively. Update `paddedOp` to the cloned operation with
 /// statically shaped `paddingDimensions` and return the extracted dynamically
-/// shaped results. If padding fails, return failure.
+/// shaped results. If padding fails, return failure. If `copyBack` is set, the
+/// unpadded result is copied back into the original destination tensor.
 FailureOr<SmallVector<Value>>
 rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
-                  const LinalgPaddingOptions &options, LinalgOp &paddedOp);
+                  const LinalgPaddingOptions &options, LinalgOp &paddedOp,
+                  bool copyBack);
 
 namespace detail {
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1601,7 +1601,7 @@
   options.paddingValues = paddingValues;
   options.packPaddings = packPaddings;
   FailureOr<SmallVector<Value>> result =
-      rewriteAsPaddedOp(rewriter, target, options, paddedOp);
+      rewriteAsPaddedOp(rewriter, target, options, paddedOp, getCopyBack());
   if (succeeded(result)) {
     // We need to perform our own replacement here because this API is still
     // used in patterns that "pad and hoist", for which the replacement values
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
@@ -138,7 +139,7 @@
 FailureOr<SmallVector<Value>>
 linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                           const LinalgPaddingOptions &options,
-                          LinalgOp &paddedOp) {
+                          LinalgOp &paddedOp, bool copyBack) {
   LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
   Location loc = opToPad->getLoc();
 
@@ -197,7 +198,21 @@
         loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
         strides));
   }
-  return paddedSubtensorResults;
+
+  if (!copyBack)
+    return paddedSubtensorResults;
+
+  // Copy back unpadded results to the original destination (i.e., inits of the
+  // linalg op), so that the destination buffer of the computation does not
+  // change. If the padding folds away, this will materialize as a memcpy
+  // between two identical buffers, which will then also fold away.
+  SmallVector<Value> copiedBack;
+  for (auto it :
+       llvm::zip(paddedSubtensorResults, opToPad.getDpsInitOperands())) {
+    copiedBack.push_back(rewriter.create<bufferization::CopyTensorOp>(
+        loc, std::get<0>(it), std::get<1>(it)->get()));
+  }
+  return copiedBack;
 }
 
 FailureOr<LinalgOp>
@@ -209,8 +224,8 @@
 
   // Pad the operation.
   LinalgOp paddedOp;
-  FailureOr<SmallVector<Value>> newResults =
-      rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp);
+  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
+      rewriter, linalgOp, options, paddedOp, /*copyBack=*/false);
   if (failed(newResults))
     return rewriter.notifyMatchFailure(linalgOp,
                                        "failed to rewrite as a padded op");
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
@@ -19,7 +19,8 @@
 
   %matmul_padded = transform.structured.pad %matmul_l1 {
     padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
-    padding_dimensions=[0, 1, 2]
+    padding_dimensions=[0, 1, 2],
+    copy_back = false
   } : (!transform.any_op) -> !transform.any_op
 
 // In this case, the pad op is actually empty: we only tile the first dimension
@@ -54,7 +55,8 @@
 
   %matmul_padded = transform.structured.pad %matmul_l1 {
     padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
-    padding_dimensions=[0, 1, 2]
+    padding_dimensions=[0, 1, 2],
+    copy_back = false
   } : (!transform.any_op) -> !transform.any_op
 
   %pad = transform.get_producer_of_operand %matmul_padded[2]
@@ -96,7 +98,8 @@
 
   %matmul_padded = transform.structured.pad %matmul_l1 {
     padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
-    padding_dimensions=[0, 1, 2]
+    padding_dimensions=[0, 1, 2],
+    copy_back = false
   } : (!transform.any_op) -> !transform.any_op
 
   %pad = transform.get_producer_of_operand %matmul_padded[0]
@@ -140,7 +143,8 @@
 
   %matmul_padded = transform.structured.pad %matmul_l1 {
     padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
-    padding_dimensions=[0, 1, 2]
+    padding_dimensions=[0, 1, 2],
+    copy_back = false
   } : (!transform.any_op) -> !transform.any_op
 
   %pad = transform.get_producer_of_operand %matmul_padded[0]
@@ -183,7 +187,8 @@
 
   %matmul_padded = transform.structured.pad %matmul_l1 {
     padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
-    padding_dimensions=[0, 1, 2]
+    padding_dimensions=[0, 1, 2],
+    copy_back = false
   } : (!transform.any_op) -> !transform.any_op
 
   %pad = transform.get_producer_of_operand %matmul_padded[2]
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -241,3 +241,54 @@
     pack_paddings=[1, 1, 1]
   } : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+
+// CHECK-LABEL: @pack_everything
+func.func @pack_everything(%arg0: tensor<24x12xf32>,
+                           %arg1: tensor<12x25xf32>,
+                           %arg2: tensor<24x25xf32>,
+                           %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+  %0 = affine.min #map()[%iv2]
+
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %
+  // CHECK: %[[T1:.*]] = tensor.extract_slice %
+  // CHECK: %[[T2:.*]] = tensor.extract_slice %
+  %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+  %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+  // CHECK-DAG: %[[CST:.*]] = arith.constant 0.
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+  // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] nofold
+  // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] nofold
+  // CHECK: %[[PAD2:.*]] = tensor.pad %[[T2]] nofold
+
+  // CHECK: %[[T5:.*]] = linalg.matmul
+  // CHECK-SAME:   ins(%[[PAD0]], %[[PAD1]] : tensor<4x7xf32>, tensor<7x5xf32>)
+  // CHECK-SAME:   outs(%[[PAD2]] : tensor<4x5xf32>)
+
+  // Get unpadded result (no-op in this example).
+  // CHECK: %[[T6:.*]] = tensor.extract_slice %[[T5]]
+  // Copy back result to the original buffer, so that the destination of the
+  // computation does not change.
+  // CHECK: %[[T7:.*]] = bufferization.copy_tensor %[[T6]], %[[T2]]
+  %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+
+  // CHECK: %[[T8:.*]] = tensor.insert_slice %[[T7]] into %{{.*}}
+  %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+  func.return %5 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.pad %0 {
+    padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+    padding_dimensions=[0, 1, 2],
+    pack_paddings=[1, 1, 1]
+  } : (!transform.any_op) -> !transform.any_op
+}
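Note (not part of the patch): a minimal sketch of how downstream C++ code might call the updated rewriteAsPaddedOp entry point. The wrapper padAndReplace and the direct replaceOp call are illustrative assumptions; only the five-argument signature with the trailing copyBack flag comes from the change above.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical helper (for illustration only): pad `opToPad` and replace it
// with the values returned by rewriteAsPaddedOp. With /*copyBack=*/true these
// values are the bufferization.copy_tensor results, so the original inits of
// the op remain the destination of the computation.
static LogicalResult padAndReplace(RewriterBase &rewriter,
                                   linalg::LinalgOp opToPad,
                                   const linalg::LinalgPaddingOptions &options) {
  linalg::LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = linalg::rewriteAsPaddedOp(
      rewriter, opToPad, options, paddedOp, /*copyBack=*/true);
  if (failed(newResults))
    return failure();
  rewriter.replaceOp(opToPad, *newResults);
  return success();
}

Passing /*copyBack=*/false preserves the previous behavior, as the second caller in Padding.cpp above does for the pad-and-hoist patterns, which compute their own replacement values.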