diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -930,7 +930,14 @@ private: void notifyOperationRemoved(Operation *op) override; - void notifyOperationReplaced(Operation *op, ValueRange newValues) override; + void notifyOperationReplaced(Operation *op, Operation *replacement) override; + + void notifyOperationReplaced(Operation *op, ValueRange replacement) override; + + /// Return "true" if a payload op (that is currently in the process of being + /// replaced) that is mapped to the given handles can be silently dropped + /// from the mapping. This is the case for dead/consumed handles. + bool canSilentlyDropPayloadOpFromMapping(ValueRange handles) const; /// The transform op in which this TrackingListener is used. TransformOpInterface transformOp; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -405,9 +405,21 @@ /// Notify the listener that the specified operation was modified in-place. virtual void notifyOperationModified(Operation *op) {} + /// Notify the listener that the specified operation is about to be replaced + /// with another operation. This is called before the uses of the operation + /// have been changed. + /// + /// Note: The "op replaced with values" notification is triggered in + /// addition to this notification. + virtual void notifyOperationReplaced(Operation *op, + Operation *replacement) {} + /// Notify the listener that the specified operation is about to be replaced /// with the set of values potentially produced by new operations. This is /// called before the uses of the operation have been changed. + /// + /// Note: The list of replacement values is empty in case the specified op + /// has no results.
virtual void notifyOperationReplaced(Operation *op, ValueRange replacement) {} @@ -444,6 +456,10 @@ void notifyOperationModified(Operation *op) override { listener->notifyOperationModified(op); } + void notifyOperationReplaced(Operation *op, + Operation *replacement) override { + listener->notifyOperationReplaced(op, replacement); + } void notifyOperationReplaced(Operation *op, ValueRange replacement) override { listener->notifyOperationReplaced(op, replacement); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -1279,63 +1279,93 @@ return false; } +bool transform::TrackingListener::canSilentlyDropPayloadOpFromMapping( + ValueRange handles) const { + // Check if any of the handles is consumed by the current transform op. + // + // Note: If a handle was consumed, there shouldn't be any alive handles + // mapped to the payload op. So it is not really necessary to check for + // consumed handles here. However, in case there are indeed alive handles + // mapped to the payload op (which is invalid IR) and a replacement op could + // not be found, we want to fail with a nicer error message: "op uses a handle + // invalidated..." instead of "could not find replacement op". This nicer + // error is produced later. + if (llvm::any_of(handles, [&](Value handle) { + return consumedHandles.contains(handle); + })) + return true; + + // Otherwise, it can be dropped only if it is not mapped to any alive handle.
+ for (Value handle : handles) + for (Operation *user : handle.getUsers()) + if (user != transformOp && !happensBefore(user, transformOp)) + return false; + + return true; +} + void transform::TrackingListener::notifyOperationReplaced( - Operation *op, ValueRange newValues) { - assert(op->getNumResults() == newValues.size() && + Operation *op, Operation *replacement) { + assert(op->getNumResults() == replacement->getNumResults() && + "invalid replacement"); + + // Nothing to do if op is not tracked. + SmallVector<Value> opHandles; + if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) + return; + + // Drop the op from the mapping if there are no more alive handles. + if (canSilentlyDropPayloadOpFromMapping(opHandles)) { + (void)replacePayloadOp(op, nullptr); + return; + } + + // Handle ops with 0 results here. All other replacements are handled in the + // other `notifyOperationReplaced` overload. + if (op->getNumResults() != 0) + return; + + // Replace op handle if both ops have the same name. + if (op->getName() == replacement->getName()) { + (void)replacePayloadOp(op, replacement); + return; + } + + // The replacement op has a different name. This case is handled in the other + // `notifyOperationReplaced` overload. +} + +void transform::TrackingListener::notifyOperationReplaced( + Operation *op, ValueRange replacement) { + assert(op->getNumResults() == replacement.size() && "invalid number of replacement values"); // Replace value handles. - for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) + for (auto [oldValue, newValue] : llvm::zip(op->getResults(), replacement)) (void)replacePayloadValue(oldValue, newValue); - // Replace op handle. + // Nothing more to do if op is not tracked. SmallVector<Value> opHandles; - if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) { - // Op is not tracked.
+ if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) return; - } - // Helper function to check if the current transform op consumes any handle - // that is mapped to `op`. - // - // Note: If a handle was consumed, there shouldn't be any alive users, so it - // is not really necessary to check for consumed handles. However, in case - // there are indeed alive handles that were consumed (which is invalid IR) and - // a replacement op could not be found, we want to fail with a nicer error - // message: "op uses a handle invalidated..." instead of "could not find - // replacement op". This nicer error is produced later. - auto handleWasConsumed = [&] { - return llvm::any_of(opHandles, - [&](Value h) { return consumedHandles.contains(h); }); - }; - - // Helper function to check if the handle is alive. - auto hasAliveUser = [&]() { - for (Value v : opHandles) { - for (Operation *user : v.getUsers()) - if (user != transformOp && !happensBefore(user, transformOp)) - return true; - } - return false; - }; - - if (!hasAliveUser() || handleWasConsumed()) { - // The op is tracked but the corresponding handles are dead or were - // consumed. Drop the op form the mapping. + // Drop the op from the mapping if there are no more alive handles. + if (canSilentlyDropPayloadOpFromMapping(opHandles)) { (void)replacePayloadOp(op, nullptr); return; } - FailureOr<Operation *> replacement = findReplacementOp(op, newValues); + // Replace op handle. + FailureOr<Operation *> replacementOp = findReplacementOp(op, replacement); // If the op is tracked but no replacement op was found, send a // notification.
- if (failed(replacement)) { - notifyPayloadReplacementNotFound(op, newValues); + if (failed(replacementOp)) { + notifyPayloadReplacementNotFound(op, replacement); (void)replacePayloadOp(op, nullptr); return; } - (void)replacePayloadOp(op, *replacement); + (void)replacePayloadOp(op, *replacementOp); } transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() { diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -370,8 +370,11 @@ Operation *newOp) { assert(op->getNumResults() == newOp->getNumResults() && "replacement op doesn't match results of original op"); - if (op->getNumResults() == 1) - return replaceOp(op, newOp->getResult(0)); + + // Notify the listener that we're about to replace this op. + if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener)) + rewriteListener->notifyOperationReplaced(op, newOp); + return replaceOp(op, newOp->getResults()); } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1606,9 +1606,9 @@ // CHECK-NEXT: return // CHECK-NEXT: } func.func @test_tracked_rewrite() { - %0 = transform.test_dummy_payload_op {replace_me} : () -> (i1) + %0 = transform.test_dummy_payload_op {replace_me, test_replace_op} : () -> (i1) %1 = transform.test_dummy_payload_op {erase_me} : () -> (i1) - %2 = transform.test_dummy_payload_op {replace_me} : () -> (i1) + %2 = transform.test_dummy_payload_op {replace_me, test_replace_op} : () -> (i1) func.return } @@ -1622,6 +1622,29 @@ test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op } +// ----- + +// CHECK-LABEL: func @test_tracked_rewrite_no_results() { +// CHECK-NEXT: transform.test_dummy_payload_op {new_op} : () -> () +// CHECK-NEXT: transform.test_dummy_payload_op {new_op} : () -> () +// CHECK-NEXT: return
+// CHECK-NEXT: } +func.func @test_tracked_rewrite_no_results() { + transform.test_dummy_payload_op {replace_me} : () -> () + transform.test_dummy_payload_op {erase_me} : () -> () + transform.test_dummy_payload_op {replace_me} : () -> () + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["transform.test_dummy_payload_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-remark @below {{2 iterations}} + transform.test_tracked_rewrite %0 : (!transform.any_op) -> () + // One replacement op is dropped from the mapping. + // expected-remark @below {{2}} + test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op +} // ----- diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -748,14 +748,25 @@ continue; } - SmallVector<NamedAttribute> attributes; - attributes.emplace_back(rewriter.getStringAttr("new_op"), - rewriter.getUnitAttr()); - OperationState opState(op->getLoc(), op->getName().getIdentifier(), - /*operands=*/ValueRange(), - /*types=*/op->getResultTypes(), attributes); - Operation *newOp = rewriter.create(opState); - rewriter.replaceOp(op, newOp->getResults()); + // Test RewriterBase::replaceOp, which triggers an "op replaced with + // values" notification.
+ if (op->hasAttr("test_replace_op")) { + SmallVector<NamedAttribute> attributes; + attributes.emplace_back(rewriter.getStringAttr("new_op"), + rewriter.getUnitAttr()); + OperationState opState(op->getLoc(), op->getName().getIdentifier(), + /*operands=*/ValueRange(), + /*types=*/op->getResultTypes(), attributes); + Operation *newOp = rewriter.create(opState); + rewriter.replaceOp(op, newOp->getResults()); + continue; + } + + // Test RewriterBase::replaceOpWithNewOp, which triggers an "op replaced + // with op" notification. + Operation *newOp = rewriter.replaceOpWithNewOp<TestDummyPayloadOp>( + op, op->getResultTypes(), op->getOperands()); + newOp->setAttr(rewriter.getStringAttr("new_op"), rewriter.getUnitAttr()); } }