diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -951,7 +951,7 @@
                        DefaultValuedAttr<
                           TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
                           "{}">:$transpose_paddings,
-                       DefaultValuedAttr<BoolAttr, "true">:$copy_back);
+                       DefaultValuedAttr<StrAttr, "\"bufferization.copy_tensor\"">:$copy_back_op);
  let results = (outs TransformHandleTypeInterface:$padded,
                      TransformHandleTypeInterface:$pad);

@@ -970,7 +970,7 @@
                    CArg<"ArrayRef<int64_t>", "{}">:$padToMultipleOf,
                    CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
                    CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
-                   CArg<"bool", "false">:$copyBack)>
+                   CArg<"StringRef", "\"bufferization.copy_tensor\"">:$copyBackOp)>
  ];

  let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -291,6 +291,18 @@
     transposePaddings.assign(tp.begin(), tp.end());
     return *this;
   }
+  enum class CopyBackOp : int8_t {
+    None = 0,
+    BufferizationCopyTensor = 1,
+    LinalgCopy = 2
+  };
+  /// The op to be used for copying the padded result to the original
+  /// destination tensor.
+  CopyBackOp copyBackOp = CopyBackOp::BufferizationCopyTensor;
+  LinalgPaddingOptions &setCopyBackOp(CopyBackOp op) {
+    copyBackOp = op;
+    return *this;
+  }
 };

 /// Callback function type used to perform the allocation for the promoted
@@ -454,14 +466,13 @@
 /// * The unpadded results (extracted slice of the cloned operation) are
 ///   returned via `replacements`.
 /// * The tensor::PadOps are returned via `padOps`.
-/// * If `copyBack` is set to "true", the unpadded result is copied back to the
-///   original destination tensor.
+/// * "options.copyBackOp" specifies the op type for copying back the unpadded
+///   result to the original destination tensor.
 LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                                 const LinalgPaddingOptions &options,
                                 LinalgOp &paddedOp,
                                 SmallVector<Value> &replacements,
-                                SmallVector<tensor::PadOp> &padOps,
-                                bool copyBack);
+                                SmallVector<tensor::PadOp> &padOps);

 namespace detail {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1585,7 +1585,7 @@
                              ArrayRef<int64_t> padToMultipleOf,
                              ArrayRef<int64_t> packPaddings,
                              ArrayRef<Attribute> transposePaddings,
-                             bool copyBack) {
+                             StringRef copyBackOp) {
   auto resultType = transform::AnyOpType::get(b.getContext());
   return build(/*builder=*/b,
                /*result=*/result,
@@ -1598,7 +1598,7 @@
                    : b.getI64ArrayAttr(padToMultipleOf)),
                /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
                /*transposePaddings=*/b.getArrayAttr(transposePaddings),
-               /*copyBack=*/b.getBoolAttr(copyBack));
+               /*copyBackOp=*/b.getStringAttr(copyBackOp));
 }

 DiagnosedSilenceableFailure
@@ -1672,11 +1672,21 @@
   options.padToMultipleOf = padToMultipleOf;
   options.paddingValues = paddingValues;
   options.packPaddings = packPaddings;
+  if (getCopyBackOp() == "bufferization.copy_tensor") {
+    options.copyBackOp =
+        LinalgPaddingOptions::CopyBackOp::BufferizationCopyTensor;
+  } else if (getCopyBackOp() == "linalg.copy") {
+    options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
+  } else if (getCopyBackOp() == "none") {
+    options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
+  } else {
+    llvm_unreachable("unsupported copy_back op");
+  }

   SmallVector<Value> replacements;
   SmallVector<tensor::PadOp> newPadOps;
   if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
-                               replacements, newPadOps, getCopyBack()))) {
+                               replacements, newPadOps))) {
     auto diag = emitSilenceableError() << "failed to pad op";
     diag.attachNote(target->getLoc()) << "target op";
     return diag;
@@ -1732,6 +1742,9 @@
             << attr;
     }
   }
+  if (getCopyBackOp() != "bufferization.copy_tensor" &&
+      getCopyBackOp() != "linalg.copy" && getCopyBackOp() != "none")
+    return emitOpError() << "invalid copy_back_op";
   return success();
 }

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -151,7 +151,7 @@
 linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                          const LinalgPaddingOptions &constOptions,
                          LinalgOp &paddedOp, SmallVector<Value> &replacements,
-                         SmallVector<tensor::PadOp> &padOps, bool copyBack) {
+                         SmallVector<tensor::PadOp> &padOps) {
   LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
   Location loc = opToPad->getLoc();

@@ -225,7 +225,7 @@
         strides));
   }

-  if (!copyBack) {
+  if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None) {
     replacements = std::move(paddedSubtensorResults);
     return success();
   }
@@ -239,8 +239,18 @@
          "expected matching number of results");
   for (auto it :
        llvm::zip(paddedSubtensorResults, opToPad.getDpsInitOperands())) {
-    replacements.push_back(rewriter.create<bufferization::CopyTensorOp>(
-        loc, std::get<0>(it), std::get<1>(it)->get()));
+    if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) {
+      replacements.push_back(rewriter
+                                 .create<linalg::CopyOp>(loc, std::get<0>(it),
+                                                         std::get<1>(it)->get())
+                                 .getResult(0));
+    } else if (options.copyBackOp ==
+               LinalgPaddingOptions::CopyBackOp::BufferizationCopyTensor) {
+      replacements.push_back(rewriter.create<bufferization::CopyTensorOp>(
+          loc, std::get<0>(it), std::get<1>(it)->get()));
+    } else {
llvm_unreachable("unsupported copy back op"); + } } return success(); } @@ -248,6 +258,9 @@ FailureOr mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp, const LinalgPaddingOptions &options) { + assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None && + "invalid options"); + if (!linalgOp.hasTensorSemantics()) return rewriter.notifyMatchFailure( linalgOp, "only applies to Linalg ops with tensor semantics"); @@ -257,7 +270,7 @@ SmallVector newResults; SmallVector padOps; if (failed(rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp, - newResults, padOps, /*copyBack=*/false))) + newResults, padOps))) return rewriter.notifyMatchFailure(linalgOp, "failed to rewrite as a padded op"); diff --git a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir --- a/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir @@ -20,7 +20,7 @@ %matmul_padded, %0 = transform.structured.pad %matmul_l1 { padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], - copy_back = false + copy_back_op = "none" } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) // In this case, the pad op is actually empty: we only tile the first dimension @@ -56,7 +56,7 @@ %matmul_padded, %0 = transform.structured.pad %matmul_l1 { padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], - copy_back = false + copy_back_op = "none" } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %pad = transform.get_producer_of_operand %matmul_padded[2] @@ -99,7 +99,7 @@ %matmul_padded, %0 = transform.structured.pad %matmul_l1 { padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], - copy_back = false + copy_back_op = "none" } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %pad = transform.get_producer_of_operand %matmul_padded[0] @@ -144,7 +144,7 @@ %matmul_padded, %0 = transform.structured.pad %matmul_l1 { padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], - copy_back = false + copy_back_op = "none" } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %pad = transform.get_producer_of_operand %matmul_padded[0] @@ -188,7 +188,7 @@ %matmul_padded, %0 = transform.structured.pad %matmul_l1 { padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], - copy_back = false + copy_back_op = "none" } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %pad = transform.get_producer_of_operand %matmul_padded[2]