diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -848,7 +848,8 @@
        DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
        DefaultValuedAttr<
           TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
-          "{}">:$transpose_paddings);
+          "{}">:$transpose_paddings,
+       DefaultValuedAttr<BoolAttr, "true">:$copy_back);
   let results = (outs TransformHandleTypeInterface:$transformed);
 
   let assemblyFormat =
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -370,13 +370,13 @@
 /// and `packPaddings` to set padding value and nofold attribute of the created
 /// tensor::PadOps, respectively. Update `paddedOp` to the cloned operation with
 /// statically shaped `paddingDimensions` and return the extracted dynamically
-/// shaped results. If padding fails, return failure.
-FailureOr<SmallVector<Value>>
-rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
-                  ArrayRef<int64_t> paddingDimensions,
-                  ArrayRef<int64_t> padToMultipleOf,
-                  ArrayRef<Attribute> paddingValues,
-                  ArrayRef<bool> packPaddings, LinalgOp &paddedOp);
+/// shaped results. If padding fails, return failure. If `copyBack` is set, the
+/// unpadded result is copied back into the original destination tensor.
+FailureOr<SmallVector<Value>> rewriteAsPaddedOp(
+    RewriterBase &rewriter, LinalgOp opToPad,
+    ArrayRef<int64_t> paddingDimensions, ArrayRef<int64_t> padToMultipleOf,
+    ArrayRef<Attribute> paddingValues, ArrayRef<bool> packPaddings,
+    LinalgOp &paddedOp, bool copyBack);
 
 namespace detail {
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1623,7 +1623,7 @@
     padToMultipleOf = extractFromI64ArrayAttr(*getPadToMultipleOf());
   FailureOr<SmallVector<Value>> result =
       rewriteAsPaddedOp(rewriter, target, paddingDimensions, padToMultipleOf,
-                        paddingValues, packPaddings, paddedOp);
+                        paddingValues, packPaddings, paddedOp, getCopyBack());
   if (succeeded(result)) {
     // We need to perform our own replacement here because this API is still
     // used in patterns that "pad and hoist", for which the replacement values
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -8,6 +8,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
@@ -113,12 +114,11 @@
       opOperand->get(), paddingValue, nofold);
 }
 
-FailureOr<SmallVector<Value>>
-linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
-                          ArrayRef<int64_t> paddingDimensions,
-                          ArrayRef<int64_t> padToMultipleOf,
-                          ArrayRef<Attribute> paddingValues,
-                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
+FailureOr<SmallVector<Value>> linalg::rewriteAsPaddedOp(
+    RewriterBase &rewriter, LinalgOp opToPad,
+    ArrayRef<int64_t> paddingDimensions, ArrayRef<int64_t> padToMultipleOf,
+    ArrayRef<Attribute> paddingValues, ArrayRef<bool> packPaddings,
+    LinalgOp &paddedOp, bool copyBack) {
   LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
   Location loc = opToPad->getLoc();
 
@@ -178,7 +178,21 @@
         loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
         strides));
   }
-  return paddedSubtensorResults;
+
+  if (!copyBack)
+    return paddedSubtensorResults;
+
+  // Copy back unpadded results to the original destination (i.e., inits of
+  // the linalg op), so that the destination buffer of the computation does
+  // not change. If the padding folds away, this will materialize as a memcpy
+  // between two identical buffers, which will then also fold away.
+  SmallVector<Value> copiedBack;
+  for (auto it :
+       llvm::zip(paddedSubtensorResults, opToPad.getDpsInitOperands())) {
+    copiedBack.push_back(rewriter.create<bufferization::CopyTensorOp>(
+        loc, std::get<0>(it), std::get<1>(it)->get()));
+  }
+  return copiedBack;
 }
 
 FailureOr<LinalgOp>
@@ -194,9 +208,10 @@
   if (options.padToMultipleOf.has_value())
     padToMultipleOf.assign(options.padToMultipleOf->begin(),
                            options.padToMultipleOf->end());
-  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
-      rewriter, linalgOp, options.paddingDimensions, padToMultipleOf,
-      options.paddingValues, options.packPaddings, paddedOp);
+  FailureOr<SmallVector<Value>> newResults =
+      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
+                        padToMultipleOf, options.paddingValues,
+                        options.packPaddings, paddedOp, /*copyBack=*/false);
   if (failed(newResults))
     return rewriter.notifyMatchFailure(linalgOp,
                                        "failed to rewrite as a padded op");
diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
@@ -19,7 +19,8 @@
 
   %matmul_padded = transform.structured.pad %matmul_l1 {
     padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
-    padding_dimensions=[0, 1, 2]
+    padding_dimensions=[0, 1, 2],
+    copy_back = false
   } : (!transform.any_op) -> !transform.any_op
 
   // In this case, the pad op is actually empty: we only tile the first dimension
@@ -54,7 +55,8 @@
 
   %matmul_padded = transform.structured.pad %matmul_l1 {
     padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
-    padding_dimensions=[0, 1, 2]
+    padding_dimensions=[0, 1, 2],
+    copy_back = false
   } : (!transform.any_op) -> !transform.any_op
 
   %pad = transform.get_producer_of_operand %matmul_padded[2]
@@ -96,7 +98,8 @@
 
   %matmul_padded = transform.structured.pad %matmul_l1 {
     padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
-    padding_dimensions=[0, 1, 2]
+    padding_dimensions=[0, 1, 2],
+    copy_back = false
   } : (!transform.any_op) -> !transform.any_op
 
   %pad = transform.get_producer_of_operand %matmul_padded[0]
@@ -140,7 +143,8 @@
 
   %matmul_padded = transform.structured.pad %matmul_l1 {
     padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
-    padding_dimensions=[0, 1, 2]
+    padding_dimensions=[0, 1, 2],
+    copy_back = false
   } : (!transform.any_op) -> !transform.any_op
 
   %pad = transform.get_producer_of_operand %matmul_padded[0]
@@ -183,7 +187,8 @@
 
   %matmul_padded = transform.structured.pad %matmul_l1 {
     padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
-    padding_dimensions=[0, 1, 2]
+    padding_dimensions=[0, 1, 2],
+    copy_back = false
   } : (!transform.any_op) -> !transform.any_op
 
   %pad = transform.get_producer_of_operand %matmul_padded[2]
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -241,3 +241,54 @@
     pack_paddings=[1, 1, 1]
   } : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+
+// CHECK-LABEL: @pack_everything
+func.func @pack_everything(%arg0: tensor<24x12xf32>,
+                           %arg1: tensor<12x25xf32>,
+                           %arg2: tensor<24x25xf32>,
+                           %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+  %0 = affine.min #map()[%iv2]
+
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %
+  // CHECK: %[[T1:.*]] = tensor.extract_slice %
+  // CHECK: %[[T2:.*]] = tensor.extract_slice %
+  %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+  %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+  // CHECK-DAG: %[[CST:.*]] = arith.constant 0.
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+  // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] nofold
+  // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] nofold
+  // CHECK: %[[PAD2:.*]] = tensor.pad %[[T2]] nofold
+
+  // CHECK: %[[T5:.*]] = linalg.matmul
+  // CHECK-SAME: ins(%[[PAD0]], %[[PAD1]] : tensor<4x7xf32>, tensor<7x5xf32>)
+  // CHECK-SAME: outs(%[[PAD2]] : tensor<4x5xf32>)
+
+  // Get unpadded result (no-op in this example).
+  // CHECK: %[[T6:.*]] = tensor.extract_slice %[[T5]]
+  // Copy back result to the original buffer, so that the destination of the
+  // computation does not change.
+  // CHECK: %[[T7:.*]] = bufferization.copy_tensor %[[T6]], %[[T2]]
+  %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+
+  // CHECK: %[[T8:.*]] = tensor.insert_slice %[[T7]] into %{{.*}}
+  %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+  func.return %5 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.pad %0 {
+    padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+    padding_dimensions=[0, 1, 2],
+    pack_paddings=[1, 1, 1]
+  } : (!transform.any_op) -> !transform.any_op
+}
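
Note for downstream C++ callers (illustrative, not part of the patch): the new
`copyBack` parameter is positional and has no default, so existing call sites
of linalg::rewriteAsPaddedOp must be updated. Below is a minimal sketch of a
caller that keeps the copy-back behavior; `padMatmulWithCopyBack` is a
hypothetical helper, and `rewriter` and `op` are assumed to come from a
matched rewrite pattern:

  // Pads a matched linalg op on all three matmul dimensions and copies the
  // unpadded results back into the original init tensors.
  static LogicalResult padMatmulWithCopyBack(RewriterBase &rewriter,
                                             linalg::LinalgOp op) {
    // Zero padding for the two inputs and the init; nofold-pack all three.
    SmallVector<Attribute> paddingValues(3, rewriter.getF32FloatAttr(0.0f));
    SmallVector<bool> packPaddings(3, true);
    linalg::LinalgOp paddedOp;
    FailureOr<SmallVector<Value>> newResults = linalg::rewriteAsPaddedOp(
        rewriter, op, /*paddingDimensions=*/{0, 1, 2},
        /*padToMultipleOf=*/{1, 1, 1}, paddingValues, packPaddings, paddedOp,
        /*copyBack=*/true);
    if (failed(newResults))
      return failure();
    // With copyBack=true, the returned values are the results of the
    // bufferization.copy_tensor ops, so replacing the original op keeps its
    // inits as the destination of the computation.
    rewriter.replaceOp(op, *newResults);
    return success();
  }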