[mlir] Add an "op replaced with op" rewriter listener notification

RewriterBase::replaceOp(Operation *, Operation *) now calls
Listener::notifyOperationReplaced(op, newOp) and then delegates to
replaceOp(op, newOp->getResults()), which still triggers the "op replaced
with values" notification. ForwardingListener forwards the new hook.
transform::TrackingListener overrides it so that a tracked op with no results,
replaced by an op with the same name, remains tracked (for such ops the
value-based notification only receives an empty replacement range). The
handle-liveness check shared by both overloads is factored out into
TrackingListener::isOpHandleAlive().

NOTE(review): this copy of the patch cannot be applied as-is. Its line breaks
have been collapsed, and text inside angle brackets appears to have been
stripped (e.g. `SmallVector opHandles;`, `FailureOr replacementOp`,
`dyn_cast_if_present(listener)`, `Op {` in the .td file). Regenerate it from
version control rather than repairing it by hand.

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -930,7 +930,12 @@ private: void notifyOperationRemoved(Operation *op) override; - void notifyOperationReplaced(Operation *op, ValueRange newValues) override; + void notifyOperationReplaced(Operation *op, Operation *replacement) override; + + void notifyOperationReplaced(Operation *op, ValueRange replacement) override; + + /// Return "true" if the given op handle is still alive. + bool isOpHandleAlive(Value handle) const; /// The transform op in which this TrackingListener is used. TransformOpInterface transformOp; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -405,9 +405,21 @@ /// Notify the listener that the specified operation was modified in-place. virtual void notifyOperationModified(Operation *op) {} + /// Notify the listener that the specified operation is about to be replaced + /// with another operation. This is called before the uses of the operation + /// have changed. + /// + /// Note: The "op replaced with values" notification is triggered in + /// addition to this notification. + virtual void notifyOperationReplaced(Operation *op, + Operation *replacement) {} + /// Notify the listener that the specified operation is about to be replaced /// with the set of values potentially produced by new operations. This is /// called before the uses of the operation have been changed. + /// + /// Note: The list of replacement values is empty in case the specified op + /// has no results.
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement) {} @@ -444,6 +456,10 @@ void notifyOperationModified(Operation *op) override { listener->notifyOperationModified(op); } + void notifyOperationReplaced(Operation *op, + Operation *replacement) override { + listener->notifyOperationReplaced(op, replacement); + } void notifyOperationReplaced(Operation *op, ValueRange replacement) override { listener->notifyOperationReplaced(op, replacement); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -1279,24 +1279,8 @@ return false; } -void transform::TrackingListener::notifyOperationReplaced( - Operation *op, ValueRange newValues) { - assert(op->getNumResults() == newValues.size() && - "invalid number of replacement values"); - - // Replace value handles. - for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) - (void)replacePayloadValue(oldValue, newValue); - - // Replace op handle. - SmallVector opHandles; - if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) { - // Op is not tracked. - return; - } - - // Helper function to check if the current transform op consumes any handle - // that is mapped to `op`. +bool transform::TrackingListener::isOpHandleAlive(Value handle) const { + // Check if the handle is consumed by the current transform op. // // Note: If a handle was consumed, there shouldn't be any alive users, so it // is not really necessary to check for consumed handles. However, in case @@ -1304,38 +1288,81 @@ // a replacement op could not be found, we want to fail with a nicer error // message: "op uses a handle invalidated..." instead of "could not find // replacement op". This nicer error is produced later.
- auto handleWasConsumed = [&] { - return llvm::any_of(opHandles, - [&](Value h) { return consumedHandles.contains(h); }); - }; - - // Helper function to check if the handle is alive. - auto hasAliveUser = [&]() { - for (Value v : opHandles) { - for (Operation *user : v.getUsers()) - if (user != transformOp && !happensBefore(user, transformOp)) - return true; - } + if (consumedHandles.contains(handle)) return false; - }; - if (!hasAliveUser() || handleWasConsumed()) { - // The op is tracked but the corresponding handles are dead or were - // consumed. Drop the op form the mapping. + // Check if the handle has further uses. + for (Operation *user : handle.getUsers()) + if (user != transformOp && !happensBefore(user, transformOp)) + return true; + + return false; +} + +void transform::TrackingListener::notifyOperationReplaced( + Operation *op, Operation *replacement) { + assert(op->getNumResults() == replacement->getNumResults() && + "invalid replacement"); + + // Nothing to do if op is not tracked. + SmallVector opHandles; + if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) + return; + + // Drop the op from the mapping if there are no more alive handles. + if (!llvm::any_of(opHandles, + [&](Value handle) { return isOpHandleAlive(handle); })) { (void)replacePayloadOp(op, nullptr); return; } - FailureOr replacement = findReplacementOp(op, newValues); + // Handle ops with 0 results here. All other replacements are handled in the + // other `notifyOperationReplaced` overload. + if (op->getNumResults() != 0) + return; + + // Replace op handle if both ops have the same name. + if (op->getName() == replacement->getName()) { + (void)replacePayloadOp(op, replacement); + return; + } + + // The replacement op has a different name. This case is handled in the other + // `notifyOperationReplaced` overload.
+} + +void transform::TrackingListener::notifyOperationReplaced( + Operation *op, ValueRange replacement) { + assert(op->getNumResults() == replacement.size() && + "invalid number of replacement values"); + + // Replace value handles. + for (auto [oldValue, newValue] : llvm::zip(op->getResults(), replacement)) + (void)replacePayloadValue(oldValue, newValue); + + // Nothing more to do if op is not tracked. + SmallVector opHandles; + if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) + return; + + // Drop the op from the mapping if there are no more alive handles. + if (!llvm::any_of(opHandles, + [&](Value handle) { return isOpHandleAlive(handle); })) { + (void)replacePayloadOp(op, nullptr); + return; + } + + // Replace op handle. + FailureOr replacementOp = findReplacementOp(op, replacement); // If the op is tracked but no replacement op was found, send a // notification. - if (failed(replacement)) { - notifyPayloadReplacementNotFound(op, newValues); + if (failed(replacementOp)) { + notifyPayloadReplacementNotFound(op, replacement); (void)replacePayloadOp(op, nullptr); return; } - (void)replacePayloadOp(op, *replacement); + (void)replacePayloadOp(op, *replacementOp); } transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() { diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -370,8 +370,11 @@ Operation *newOp) { assert(op->getNumResults() == newOp->getNumResults() && "replacement op doesn't match results of original op"); - if (op->getNumResults() == 1) - return replaceOp(op, newOp->getResult(0)); + + // Notify the listener that we're about to replace this op.
+ if (auto *rewriteListener = dyn_cast_if_present(listener)) + rewriteListener->notifyOperationReplaced(op, newOp); + return replaceOp(op, newOp->getResults()); } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1601,19 +1601,19 @@ // ----- // CHECK-LABEL: func @test_tracked_rewrite() { -// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"} -// CHECK-NEXT: "test.drop_mapping"() {original_op = "test.replace_me"} -// CHECK-NEXT: "test.update_mapping"() {original_op = "test.replace_me"} -// CHECK-NEXT: } +// CHECK-NEXT: "test.update_mapping"() {original_op = "test.foo"} +// CHECK-NEXT: "test.drop_mapping"() {original_op = "test.foo"} +// CHECK-NEXT: "test.update_mapping"() {original_op = "test.foo"} func.func @test_tracked_rewrite() { - %0 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1) - %1 = "test.replace_me"() {replacement = "test.drop_mapping"} : () -> (i1) - %2 = "test.replace_me"() {replacement = "test.update_mapping"} : () -> (i1) + %0 = "test.foo"() {replace_me, replacement = "test.update_mapping"} : () -> (i1) + %1 = "test.foo"() {replace_me, replacement = "test.drop_mapping"} : () -> (i1) + %2 = "test.foo"() {replace_me, replacement = "test.update_mapping"} : () -> (i1) + return } transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["test.replace_me"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-remark @below {{2 iterations}} transform.test_tracked_rewrite %0 : (!transform.any_op) -> () // One replacement op (test.drop_mapping) is dropped from the mapping.
@@ -1621,6 +1621,28 @@ test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op } +// ----- + +// CHECK-LABEL: func @test_tracked_rewrite_no_results() { +// CHECK-NEXT: transform.test_dummy_payload_op {original_op = "transform.test_dummy_payload_op"} : () -> () +// CHECK-NEXT: transform.test_dummy_payload_op {original_op = "transform.test_dummy_payload_op"} : () -> () +// CHECK-NEXT: transform.test_dummy_payload_op : () -> () +func.func @test_tracked_rewrite_no_results() { + transform.test_dummy_payload_op {replace_me} : () -> () + transform.test_dummy_payload_op {replace_me} : () -> () + transform.test_dummy_payload_op : () -> () + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["transform.test_dummy_payload_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-remark @below {{3 iterations}} + transform.test_tracked_rewrite %0 : (!transform.any_op) -> () + // No op is dropped from the mapping: each replaced op is remapped to its same-name replacement. + // expected-remark @below {{3}} + test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op +} // ----- diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -756,23 +756,37 @@ // loop body. Replacement ops are not enumerated. for (Operation *op : state.getPayloadOps(getIn())) { ++numIterations; - rewriterUnderTest.setInsertionPointToEnd(op->getBlock()); + (void)op; // Erase all payload ops. The outer loop should have only one iteration. for (Operation *op : state.getPayloadOps(getIn())) { - if (op->getName().getStringRef() != "test.replace_me") + rewriterUnderTest.setInsertionPoint(op); + if (!op->hasAttr("replace_me")) continue; + + // Check if replacement op name is specified.
In that case, create the op + // and replace via values. auto replacementName = op->getAttrOfType("replacement"); - if (!replacementName) + if (replacementName) { + SmallVector attributes; + attributes.emplace_back(rewriter.getStringAttr("original_op"), + op->getName().getIdentifier()); + OperationState opState(op->getLoc(), replacementName, + /*operands=*/ValueRange(), + /*types=*/op->getResultTypes(), attributes); + Operation *newOp = rewriterUnderTest.create(opState); + rewriterUnderTest.replaceOp(op, newOp->getResults()); continue; - SmallVector attributes; - attributes.emplace_back(rewriter.getStringAttr("original_op"), - op->getName().getIdentifier()); - OperationState opState(op->getLoc(), replacementName, - /*operands=*/ValueRange(), - /*types=*/op->getResultTypes(), attributes); - Operation *newOp = rewriterUnderTest.create(opState); - rewriterUnderTest.replaceOp(op, newOp->getResults()); + } + + // Otherwise, replace with a dummy op and use the `replaceOpWithNewOp` + // function, which triggers an "op replaced with op" notification. + StringRef oldOpName = op->getName().getStringRef(); + Operation *newOp = + rewriterUnderTest.replaceOpWithNewOp( + op, op->getResultTypes(), op->getOperands()); + newOp->setAttr(rewriter.getStringAttr("original_op"), + rewriter.getStringAttr(oldOpName)); } } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -475,6 +475,30 @@ let cppNamespace = "::mlir::test"; } +// This op is used as a payload op. It must be a registered op, so that it can be +// created with "RewriterBase::replaceOpWithNewOp" (needed for a test case). +// Since only TransformOpInterface can be injected into the transform dialect, +// this op implements the interface, even though it is not used as a transform +// op.
+def TestDummyPayloadOp + : Op { + let arguments = (ins Variadic:$args); + let results = (outs Variadic:$outs); + let assemblyFormat = "$args attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "::mlir::test"; + + let extraClassDeclaration = [{ + void getEffects(SmallVectorImpl &effects) {} + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + llvm_unreachable("op should not be used as a transform"); + return DiagnosedSilenceableFailure::definiteFailure(); + } + }]; +} + def TestTrackedRewriteOp : Op,