diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -92,7 +92,7 @@ void transform::TransformState::dropReverseMapping(Mappings &mappings, Operation *op, Value value) { auto it = mappings.reverse.find(op); - if (it != mappings.reverse.end()) + if (it == mappings.reverse.end()) return; llvm::erase_value(it->getSecond(), value); diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -895,3 +895,28 @@ transform.cast %2 : !transform.op<"test.some_op"> to !pdl.operation } } + +// ----- + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 : !pdl.operation failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation + // Handles tracked for %arg1's payload and the ops nested in it: {%arg0, %arg1, %0}. + // expected-remark @below {{3 handles nested under}} + transform.test_report_number_of_tracked_handles_nested_under %arg1 + // expected-remark @below {{erased}} + transform.test_emit_remark_and_erase_operand %0, "erased" + // After %0's payload is erased, %0 must no longer be tracked: only {%arg0, %arg1} remain. + // expected-remark @below {{2 handles nested under}} + transform.test_report_number_of_tracked_handles_nested_under %arg1 + } + + pdl.pattern @some : benefit(1) { + %0 = pdl.operation "test.some_op" + pdl.rewrite %0 with "transform.dialect" + } +} + +"test.some_op"() : () -> () diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@
-328,6 +328,31 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+// The op only reads its target handle: it neither consumes the handle nor
+// produces new ones.
+void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+}
+
+// Walks each payload op associated with the target handle (the op itself and
+// every op nested in it) and emits a remark with the total number of handles
+// currently mapped to those ops.
+DiagnosedSilenceableFailure
+mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  int64_t count = 0;
+  for (Operation *op : state.getPayloadOps(getTarget())) {
+    op->walk([&](Operation *nested) {
+      SmallVector<Value> handles;
+      (void)state.getHandlesForPayloadOp(nested, handles);
+      count += handles.size();
+    });
+  }
+  emitRemark() << count << " handles nested under";
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -253,4 +253,15 @@
   let assemblyFormat = "$handle attr-dict";
 }
 
+// Emits a remark with the number of handles associated with the payload ops
+// of `target` and with all ops nested in them.
+def TestReportNumberOfTrackedHandlesNestedUnder
+  : Op<Transform_Dialect, "test_report_number_of_tracked_handles_nested_under",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins PDL_Operation:$target);
+  let assemblyFormat = "$target attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD