[mlir][transform] Let payload replacement update out-of-scope handles

TransformState::getHandlesForPayloadOp and getHandlesForPayloadValue gain an
`includeOutOfScope` parameter, and TransformState::getMapping gains an
`allowOutOfScope` parameter. Both default to `false`, which keeps the
previous behavior: handle lookups stop at the most recent region whose
parent op is isolated from above, and (in debug builds) getMapping asserts
when asked for a mapping beyond such a region.

The code paths that drop or rewrite handle mappings when payload ops/values
are replaced (including replacePayloadValue and compactOpHandles), as well
as the "Replace op handle." lookup, now pass `true`, so handles defined
outside that region are updated too. The new test erases a tracked op while
applying patterns inside an included named sequence and expects no remark
from the handle defined in the calling sequence afterwards.

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -222,14 +222,19 @@
 
   /// Populates `handles` with all handles pointing to the given Payload IR op.
   /// Returns success if such handles exist, failure otherwise.
+  /// If `includeOutOfScope` is set to "true", handles that are defined in
+  /// regions beyond the most recent isolated from above region are included.
   LogicalResult getHandlesForPayloadOp(Operation *op,
-                                       SmallVectorImpl<Value> &handles) const;
+                                       SmallVectorImpl<Value> &handles,
+                                       bool includeOutOfScope = false) const;
 
   /// Populates `handles` with all handles pointing to the given payload IR
   /// value. Returns success if such handles exist, failure otherwise.
-  LogicalResult
-  getHandlesForPayloadValue(Value payloadValue,
-                            SmallVectorImpl<Value> &handles) const;
+  /// If `includeOutOfScope` is set to "true", handles that are defined in
+  /// regions beyond the most recent isolated from above region are included.
+  LogicalResult getHandlesForPayloadValue(Value payloadValue,
+                                          SmallVectorImpl<Value> &handles,
+                                          bool includeOutOfScope = false) const;
 
   /// Applies the transformation specified by the given transform op and updates
   /// the state accordingly.
@@ -410,42 +415,53 @@
                  const TransformOptions &options = TransformOptions());
 
   /// Returns the mappings frame for the region in which the value is defined.
-  const Mappings &getMapping(Value value) const {
-    return const_cast<TransformState *>(this)->getMapping(value);
+  /// If `allowOutOfScope` is set to "false", asserts that the value is in
+  /// scope, based on the current stack of frames.
+  const Mappings &getMapping(Value value, bool allowOutOfScope = false) const {
+    return const_cast<TransformState *>(this)->getMapping(value,
+                                                          allowOutOfScope);
   }
-  Mappings &getMapping(Value value) {
+  Mappings &getMapping(Value value, bool allowOutOfScope = false) {
     Region *region = value.getParentRegion();
     auto it = mappings.find(region);
     assert(it != mappings.end() &&
            "trying to find a mapping for a value from an unmapped region");
 #ifndef NDEBUG
-    for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
-      if (r == region)
-        break;
-      if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
-        llvm_unreachable(
-            "trying to get mapping beyond region that is isolated from above");
+    if (!allowOutOfScope) {
+      for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+        if (r == region)
+          break;
+        if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+          llvm_unreachable("trying to get mapping beyond region that is "
+                           "isolated from above");
+      }
     }
 #endif // NDEBUG
     return it->second;
   }
 
   /// Returns the mappings frame for the region in which the operation resides.
-  const Mappings &getMapping(Operation *operation) const {
-    return const_cast<TransformState *>(this)->getMapping(operation);
+  /// If `allowOutOfScope` is set to "false", asserts that the operation is in
+  /// scope, based on the current stack of frames.
+  const Mappings &getMapping(Operation *operation,
+                             bool allowOutOfScope = false) const {
+    return const_cast<TransformState *>(this)->getMapping(operation,
+                                                          allowOutOfScope);
   }
-  Mappings &getMapping(Operation *operation) {
+  Mappings &getMapping(Operation *operation, bool allowOutOfScope = false) {
     Region *region = operation->getParentRegion();
     auto it = mappings.find(region);
     assert(it != mappings.end() &&
            "trying to find a mapping for an operation from an unmapped region");
 #ifndef NDEBUG
-    for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
-      if (r == region)
-        break;
-      if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
-        llvm_unreachable(
-            "trying to get mapping beyond region that is isolated from above");
+    if (!allowOutOfScope) {
+      for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+        if (r == region)
+          break;
+        if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+          llvm_unreachable("trying to get mapping beyond region that is "
+                           "isolated from above");
+      }
     }
 #endif // NDEBUG
     return it->second;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -83,7 +83,8 @@
 }
 
 LogicalResult transform::TransformState::getHandlesForPayloadOp(
-    Operation *op, SmallVectorImpl<Value> &handles) const {
+    Operation *op, SmallVectorImpl<Value> &handles,
+    bool includeOutOfScope) const {
   bool found = false;
   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
     auto iterator = mapping.reverse.find(op);
@@ -92,7 +93,8 @@
       found = true;
     }
     // Stop looking when reaching a region that is isolated from above.
-    if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+    if (!includeOutOfScope &&
+        region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
       break;
   }
 
@@ -100,7 +102,8 @@
 }
 
 LogicalResult transform::TransformState::getHandlesForPayloadValue(
-    Value payloadValue, SmallVectorImpl<Value> &handles) const {
+    Value payloadValue, SmallVectorImpl<Value> &handles,
+    bool includeOutOfScope) const {
   bool found = false;
   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
     auto iterator = mapping.reverseValues.find(payloadValue);
@@ -109,7 +112,8 @@
       found = true;
     }
     // Stop looking when reaching a region that is isolated from above.
-    if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+    if (!includeOutOfScope &&
+        region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
       break;
   }
 
@@ -343,7 +347,8 @@
 #ifndef NDEBUG
   for (Value opResult : op->getResults()) {
     SmallVector<Value> valueHandles;
-    (void)getHandlesForPayloadValue(opResult, valueHandles);
+    (void)getHandlesForPayloadValue(opResult, valueHandles,
+                                    /*includeOutOfScope=*/true);
     assert(valueHandles.empty() && "expected no mapping to old results");
   }
 #endif // NDEBUG
@@ -351,10 +356,10 @@
   // Drop the mapping between the op and all handles that point to it. Fail if
   // there are no handles.
   SmallVector<Value> opHandles;
-  if (failed(getHandlesForPayloadOp(op, opHandles)))
+  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
     return failure();
   for (Value handle : opHandles) {
-    Mappings &mappings = getMapping(handle);
+    Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
     dropMappingEntry(mappings.reverse, op, handle);
   }
 
@@ -385,7 +390,7 @@
   // element from an array invalidates iterators; merely changing the value of
   // elements does not.
   for (Value handle : opHandles) {
-    Mappings &mappings = getMapping(handle);
+    Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
     auto it = mappings.direct.find(handle);
     if (it == mappings.direct.end())
       continue;
@@ -410,11 +415,12 @@
 LogicalResult
 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
   SmallVector<Value> valueHandles;
-  if (failed(getHandlesForPayloadValue(value, valueHandles)))
+  if (failed(getHandlesForPayloadValue(value, valueHandles,
+                                       /*includeOutOfScope=*/true)))
     return failure();
 
   for (Value handle : valueHandles) {
-    Mappings &mappings = getMapping(handle);
+    Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
     dropMappingEntry(mappings.reverseValues, value, handle);
 
     // If replacing with null, that is erasing the mapping, drop the mapping
@@ -764,7 +770,7 @@
 
 void transform::TransformState::compactOpHandles() {
   for (Value handle : opHandlesToCompact) {
-    Mappings &mappings = getMapping(handle);
+    Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
     llvm::erase_value(mappings.direct[handle], nullptr);
   }
   opHandlesToCompact.clear();
@@ -1346,7 +1352,8 @@
 
   // Replace op handle.
   SmallVector<Value> opHandles;
-  if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) {
+  if (failed(getTransformState().getHandlesForPayloadOp(
+          op, opHandles, /*includeOutOfScope=*/true))) {
     // Op is not tracked.
     return;
   }
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -131,11 +131,47 @@
   transform.apply_patterns to %0 {
     transform.apply_patterns.transform.test_patterns
   } : !transform.any_op
+  // No marker should be printed.
   transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
 }
 
 // -----
 
+// CHECK-LABEL: func @erase_tracked_op_in_named_sequence()
+//       CHECK:   "test.container"() ({
+//  CHECK-NEXT:   ^bb0:
+//  CHECK-NEXT:   }) : () -> ()
+module {
+  func.func @erase_tracked_op_in_named_sequence() {
+    "test.container"() ({
+      // expected-remark @below {{matched op}}
+      %0 = "test.erase_op"() {replace_with_new_op = "test.foo"} : () -> (i32)
+    }) : () -> ()
+    return
+  }
+
+  module attributes { transform.with_named_sequence } {
+    transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
+      transform.apply_patterns to %arg0 {
+        transform.apply_patterns.transform.test_patterns
+      } : !transform.any_op
+      transform.yield
+    }
+
+    transform.sequence failures(propagate) {
+    ^bb1(%arg1: !transform.any_op):
+      %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      transform.test_print_remark_at_operand %1, "matched op" : !transform.any_op
+      include @foo failures(propagate) (%0) : (!transform.any_op) -> ()
+      // No marker should be printed.
+      transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
+    }
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @canonicalization(
 //       CHECK:   %[[c5:.*]] = arith.constant 5 : index
 //       CHECK:   return %[[c5]]