diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -923,9 +923,12 @@ TransformOpInterface getTransformOp() const { return transformOp; } private: + friend class TransformRewriter; + void notifyOperationRemoved(Operation *op) override; void notifyOperationReplaced(Operation *op, ValueRange newValues) override; + using Listener::notifyOperationReplaced; /// The transform op in which this TrackingListener is used. TransformOpInterface transformOp; @@ -981,6 +984,19 @@ /// Silence all tracking failures that have been encountered so far. void silenceTrackingFailure(); + /// Notify the transform dialect interpreter that the given op has been + /// replaced with another op and that the mapping between handles and payload + /// ops/values should be updated. This function should be called before the + /// original op is erased. It fails if the operation could not be replaced, + /// e.g., because the original operation is not tracked. + /// + /// Note: As long as IR modifications are performed through this rewriter, + /// the transform state is usually updated automatically. Use this function + /// when an unsupported rewriter API is used; e.g., updating all uses of a + /// tracked operation one-by-one instead of using `RewriterBase::replaceOp`. 
+ LogicalResult notifyPayloadOperationReplaced(Operation *op, + Operation *replacement); + private: ErrorCheckingTrackingListener *const listener; }; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -688,36 +688,6 @@ return true; } -namespace { -/// Unsafely exposes an internal protected method of TransformState::Extension -/// as public. -/// -/// MUST NOT be used directly. -class UnsafeOpReplacementStateExtension : public TransformState::Extension { -public: - UnsafeOpReplacementStateExtension(TransformState &state) - : TransformState::Extension(state) {} - - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - UnsafeOpReplacementStateExtension) - - LogicalResult doReplacePayloadOp(Operation *op, Operation *replacement) { - return replacePayloadOp(op, replacement); - } -}; -} // namespace - -/// Replaces `payload` with `replacement` in all handles stored in the state. -/// MUST NOT be used except for the case immediately below. -static void forciblyReplaceReferencedPayloadOperation(TransformState &state, - Operation *payload, - Operation *replacement) { - UnsafeOpReplacementStateExtension extension(state); - // This may return failure if the payload is not associated with any handle, - // ignore that. - (void)extension.doReplacePayloadOp(payload, replacement); -} - DiagnosedSilenceableFailure transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -787,6 +757,19 @@ LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp); fusedOps.append(tiledOps); if (newContainingOp) { + // Update handles associated with the containing op so we don't need to + // invalidate them. 
This is a hack to support better composability + // between tiling and fusion while a proper mechanism is being + // investigated. + // + // DO NOT replicate this elsewhere unless you understand what you are + // doing. + LogicalResult replacementStatus = + rewriter.notifyPayloadOperationReplaced(containingOp, + newContainingOp); + (void)replacementStatus; + assert(succeeded(replacementStatus) && + "unable to update transform state mapping"); rewriter.eraseOp(containingOp); containingOp = newContainingOp; } @@ -813,14 +796,6 @@ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } - // Update handles associated with the containing op so we don't need to - // invalidate them. This is a hack to support better composability between - // tiling and fusion while a proper mechanism is being investigated. - // - // DO NOT replicate this elsewhere unless you understand what you are doing. - forciblyReplaceReferencedPayloadOperation(state, *containingOps.begin(), - containingOp); - results.set(cast(getFusedOp()), fusedOps); results.set(cast(getNewContainingOp()), {containingOp}); return DiagnosedSilenceableFailure::success(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -1452,6 +1452,11 @@ } } +LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced( + Operation *op, Operation *replacement) { + return listener->replacePayloadOp(op, replacement); +} + //===----------------------------------------------------------------------===// // Utilities for TransformEachOpTrait. 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s +// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s -verify-diagnostics | FileCheck %s #map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)> #map1 = affine_map<(d0)[s0] -> (d0 * s0)> @@ -323,6 +323,7 @@ // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]]) // CHECK-SAME: -> (tensor, tensor) { + // expected-remark @below{{new containing op}} %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor) { // CHECK: %[[I0:.*]] = affine.apply {{.*}} %3 = affine.apply #map1(%i)[%idx] @@ -350,8 +351,9 @@ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall"> // linalg.generic is tileable. The op is tiled and fused. 
- transform.structured.fuse_into_containing_op %0 into %1 + %fused, %containing = transform.structured.fuse_into_containing_op %0 into %1 : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op) + test_print_remark_at_operand %containing, "new containing op" : !transform.any_op } } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1710,3 +1710,20 @@ transform.annotate %0 "broadcast_attr" = %2 : !transform.any_op, !transform.param transform.annotate %0 "unit_attr" : !transform.any_op } + +// ----- + +func.func @notify_payload_op_replaced(%arg0: index, %arg1: index) { + %0 = arith.muli %arg0, %arg1 {original} : index + // expected-remark @below{{updated handle}} + %1 = arith.muli %arg0, %arg1 {replacement} : index + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match attributes{original} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match attributes{replacement} in %arg1 : (!transform.any_op) -> !transform.any_op + test_notify_payload_op_replaced %0, %1 : (!transform.any_op, !transform.any_op) -> () + test_print_remark_at_operand %0, "updated handle" : !transform.any_op +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -872,6 +872,34 @@ return success(); } +DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { + auto originalOps = 
state.getPayloadOps(getOriginal()); + auto replacementOps = state.getPayloadOps(getReplacement()); + if (llvm::range_size(originalOps) != llvm::range_size(replacementOps)) + return emitSilenceableError() << "expected same number of original and " + "replacement payload operations"; + for (const auto &[original, replacement] : + llvm::zip(originalOps, replacementOps)) { + if (failed( + rewriter.notifyPayloadOperationReplaced(original, replacement))) { + auto diag = emitSilenceableError() + << "unable to replace payload op in transform mapping"; + diag.attachNote(original->getLoc()) << "original payload op"; + diag.attachNote(replacement->getLoc()) << "replacement payload op"; + return diag; + } + } + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getOriginal(), effects); + transform::onlyReadsHandle(getReplacement(), effects); +} + namespace { /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -564,4 +564,15 @@ let hasVerifier = 1; } +def TestNotifyPayloadOpReplacedOp + : Op, + DeclareOpInterfaceMethods]> { + let arguments = (ins TransformHandleTypeInterface:$original, + TransformHandleTypeInterface:$replacement); + let results = (outs); + let assemblyFormat = "$original `,` $replacement attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "::mlir::test"; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD