diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -9,7 +9,9 @@ #ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H #define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -930,6 +930,14 @@ provides as operation attributes. The operation returns a handle to the padded operation and to the padding operation ("tensor.pad"). + To preserve tensor SSA use-def chains, the unpadded result is copied back to + the original destination tensor of the targeted op. The op that copies back + the result can be customized with `copy_back_op`: + + * "bufferization.copy_tensor" (default) + * "linalg.copy" + * "none" (no copy back) + #### Return modes This operation ignores non-Linalg ops and drops them in the return. 
@@ -951,7 +959,7 @@ DefaultValuedAttr< TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">, "{}">:$transpose_paddings, - DefaultValuedAttr<BoolAttr, "true">:$copy_back); + DefaultValuedAttr<StrAttr, "::mlir::bufferization::CopyTensorOp::getOperationName()">:$copy_back_op); let results = (outs TransformHandleTypeInterface:$padded, TransformHandleTypeInterface:$pad); @@ -970,10 +978,13 @@ CArg<"ArrayRef<int64_t>", "{}">:$padToMultipleOf, CArg<"ArrayRef<int64_t>", "{}">:$packPaddings, CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings, - CArg<"bool", "false">:$copyBack)> + CArg<"StringRef", "::mlir::bufferization::CopyTensorOp::getOperationName()">:$copyBackOp)> ]; let extraClassDeclaration = [{ + /// copy_back_op attribute value indicating that no copy back is desired. + static constexpr StringRef kCopyOpNone = "none"; + ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -291,6 +291,18 @@ transposePaddings.assign(tp.begin(), tp.end()); return *this; } + enum class CopyBackOp : int8_t { + None = 0, + BufferizationCopyTensor = 1, + LinalgCopy = 2 + }; + /// The op to be used for copying the padded result to the original + /// destination tensor. + CopyBackOp copyBackOp = CopyBackOp::BufferizationCopyTensor; + LinalgPaddingOptions &setCopyBackOp(CopyBackOp op) { + copyBackOp = op; + return *this; + } }; /// Callback function type used to perform the allocation for the promoted @@ -473,14 +485,13 @@ /// * The unpadded results (extracted slice of the cloned operation) are /// returned via `replacements`. /// * The tensor::PadOps are returned via `padOps`. -/// * If `copyBack` is set to "true", the unpadded result is copied back to the -/// original destination tensor. 
+/// * "options.copyBackOp" specifies the op type for copying back the unpadded +/// result to the original destination tensor. LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector<Value> &replacements, - SmallVector<tensor::PadOp> &padOps, - bool copyBack); + SmallVector<tensor::PadOp> &padOps); namespace detail { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -16,6 +16,7 @@ LINK_LIBS PUBLIC MLIRAffineDialect MLIRArithDialect + MLIRBufferizationDialect MLIRBufferizationTransforms MLIRFuncDialect MLIRIR diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -1591,7 +1592,7 @@ ArrayRef<int64_t> padToMultipleOf, ArrayRef<int64_t> packPaddings, ArrayRef<Attribute> transposePaddings, - bool copyBack) { + StringRef copyBackOp) { auto resultType = transform::AnyOpType::get(b.getContext()); return build(/*builder=*/b, /*result=*/result, @@ -1604,7 +1605,7 @@ : b.getI64ArrayAttr(padToMultipleOf)), /*packPaddings=*/b.getI64ArrayAttr(packPaddings), /*transposePaddings=*/b.getArrayAttr(transposePaddings), - /*copyBack=*/b.getBoolAttr(copyBack)); + /*copyBackOp=*/b.getStringAttr(copyBackOp)); } DiagnosedSilenceableFailure @@ -1678,11 +1679,21 @@ options.padToMultipleOf = padToMultipleOf; options.paddingValues = paddingValues; 
options.packPaddings = packPaddings; + if (getCopyBackOp() == bufferization::CopyTensorOp::getOperationName()) { + options.copyBackOp = + LinalgPaddingOptions::CopyBackOp::BufferizationCopyTensor; + } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) { + options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy; + } else if (getCopyBackOp() == kCopyOpNone) { + options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None; + } else { + llvm_unreachable("unsupported copy_back op"); + } SmallVector<Value> replacements; SmallVector<tensor::PadOp> newPadOps; if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp, - replacements, newPadOps, getCopyBack()))) { + replacements, newPadOps))) { auto diag = emitSilenceableError() << "failed to pad op"; diag.attachNote(target->getLoc()) << "target op"; return diag; @@ -1738,6 +1749,10 @@ << attr; } } + if (getCopyBackOp() != bufferization::CopyTensorOp::getOperationName() && + getCopyBackOp() != linalg::CopyOp::getOperationName() && + getCopyBackOp() != kCopyOpNone) + return emitOpError() << "invalid copy_back_op"; return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp @@ -151,7 +151,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &constOptions, LinalgOp &paddedOp, SmallVector<Value> &replacements, - SmallVector<tensor::PadOp> &padOps, bool copyBack) { + SmallVector<tensor::PadOp> &padOps) { LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n"); Location loc = opToPad->getLoc(); @@ -225,7 +225,7 @@ strides)); } - if (!copyBack) { + if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None) { replacements = std::move(paddedSubtensorResults); return success(); } @@ -239,8 +239,18 @@ "expected matching number of results"); for (auto it : llvm::zip(paddedSubtensorResults, opToPad.getDpsInitOperands())) { 
- replacements.push_back(rewriter.create<bufferization::CopyTensorOp>( - loc, std::get<0>(it), std::get<1>(it)->get())); + if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) { + replacements.push_back(rewriter + .create<linalg::CopyOp>(loc, std::get<0>(it), + std::get<1>(it)->get()) + .getResult(0)); + } else if (options.copyBackOp == + LinalgPaddingOptions::CopyBackOp::BufferizationCopyTensor) { + replacements.push_back(rewriter.create<bufferization::CopyTensorOp>( + loc, std::get<0>(it), std::get<1>(it)->get())); + } else { + llvm_unreachable("unsupported copy back op"); + } } return success(); } @@ -248,6 +258,9 @@ FailureOr<LinalgOp> mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp, const LinalgPaddingOptions &options) { + assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None && + "invalid options"); + if (!linalgOp.hasTensorSemantics()) return rewriter.notifyMatchFailure( linalgOp, "only applies to Linalg ops with tensor semantics"); @@ -257,7 +270,7 @@ SmallVector<Value> newResults; SmallVector<tensor::PadOp> padOps; if (failed(rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp, - newResults, padOps, /*copyBack=*/false))) + newResults, padOps))) return rewriter.notifyMatchFailure(linalgOp, "failed to rewrite as a padded op"); diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir --- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir @@ -20,7 +20,7 @@ %matmul_padded, %0 = transform.structured.pad %matmul_l1 { padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], - copy_back = false + copy_back_op = "none" } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // In this case, the pad op is actually empty: we only tile the first dimension @@ -56,7 +56,7 @@ %matmul_padded, %0 = transform.structured.pad %matmul_l1 { padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], - copy_back = false + copy_back_op = "none" } 
: (!transform.any_op) -> (!transform.any_op, !transform.any_op) %pad = transform.get_producer_of_operand %matmul_padded[2] @@ -99,7 +99,7 @@ %matmul_padded, %0 = transform.structured.pad %matmul_l1 { padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], - copy_back = false + copy_back_op = "none" } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %pad = transform.get_producer_of_operand %matmul_padded[0] @@ -144,7 +144,7 @@ %matmul_padded, %0 = transform.structured.pad %matmul_l1 { padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], - copy_back = false + copy_back_op = "none" } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %pad = transform.get_producer_of_operand %matmul_padded[0] @@ -188,7 +188,7 @@ %matmul_padded, %0 = transform.structured.pad %matmul_l1 { padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], - copy_back = false + copy_back_op = "none" } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %pad = transform.get_producer_of_operand %matmul_padded[2] diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9652,6 +9652,7 @@ ":Analysis", ":ArithDialect", ":AsmParser", + ":BufferizationDialect", ":BufferizationTransforms", ":DialectUtils", ":FuncDialect",