diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -279,14 +279,12 @@
     /// Forgets the mapping from or to values defined in the associated
     /// transform IR region, and restores the mapping that existed before
     /// entering this scope.
-    ~RegionScope() {
-      state.mappings.erase(region);
-      if (storedMappings.has_value())
-        state.mappings.swap(*storedMappings);
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-      state.regionStack.pop_back();
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-    }
+    ///
+    /// Also drops invalidation notices of handles defined directly in the
+    /// region, as the region may be re-entered later. With ABI-breaking checks
+    /// enabled, additionally forgets the cached names of payload ops that were
+    /// mapped in this scope and are no longer associated with any handle.
+    ~RegionScope();
 
   private:
     /// Tag structure for differentiating the constructor for isolated regions.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1075,6 +1075,51 @@
   return state.replacePayloadValue(value, replacement);
 }
 
+//===----------------------------------------------------------------------===//
+// TransformState::RegionScope
+//===----------------------------------------------------------------------===//
+
+transform::TransformState::RegionScope::~RegionScope() {
+  // Remove handle invalidation notices as handles are going out of scope.
+  // The same region may be re-entered leading to incorrect invalidation
+  // errors.
+  for (Block &block : *region) {
+    for (Value handle : block.getArguments()) {
+      state.invalidatedHandles.erase(handle);
+    }
+    for (Operation &op : block) {
+      for (Value handle : op.getResults()) {
+        state.invalidatedHandles.erase(handle);
+      }
+    }
+  }
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  // Remember pointers to payload ops referenced by the handles going out of
+  // scope.
+  SmallVector<Operation *> referencedOps =
+      llvm::to_vector(llvm::make_first_range(state.mappings[region].reverse));
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+
+  state.mappings.erase(region);
+  if (storedMappings.has_value())
+    state.mappings.swap(*storedMappings);
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  // If the last handle to a payload op has gone out of scope, we no longer
+  // need to store the cached name. Pointers may get reused, leading to
+  // incorrect associations in the cache.
+  for (Operation *op : referencedOps) {
+    SmallVector<Value> handles;
+    if (succeeded(state.getHandlesForPayloadOp(op, handles)))
+      continue;
+    state.cachedNames.erase(op);
+  }
+
+  state.regionStack.pop_back();
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+}
+
 //===----------------------------------------------------------------------===//
 // TransformResults
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir
--- a/mlir/test/Dialect/Transform/expensive-checks.mlir
+++ b/mlir/test/Dialect/Transform/expensive-checks.mlir
@@ -364,3 +364,49 @@
     transform.test_consume_operand %0 { allow_repeated_handles } : !transform.any_op
   }
 }
+
+// -----
+
+// Re-entering the region should not trigger the consumption error from previous
+// execution of the region.
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  transform.test_re_enter_region {
+    %0 = transform.test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+    transform.test_consume_operand %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Re-entering the region should not trigger the consumption error from previous
+// execution of the region when the consumed handle is a block argument.
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  transform.test_re_enter_region %0 : !transform.any_op {
+  ^bb0(%arg1: !transform.any_op):
+    transform.test_consume_operand %arg1 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Consuming a handle defined outside the region must fail on re-entry.
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-note @below {{payload op}}
+  // expected-note @below {{handle to invalidated ops}}
+  %0 = transform.test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  transform.test_re_enter_region {
+    // expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
+    // expected-note @below {{invalidated by this transform op}}
+    transform.test_consume_operand %0 : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -831,6 +831,62 @@
   patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext());
 }
 
+void mlir::test::TestReEnterRegionOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getOperands(), effects);
+  transform::modifiesPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter,
+                                       transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  Block &body = getBody().front();
+
+  // Collect the payload of every operand once, before running the body: the
+  // body may consume handles associated with the same payload ops, which may
+  // invalidate the operand handles.
+  SmallVector<SmallVector<transform::MappedValue>> mappings;
+  for (BlockArgument arg : body.getArguments()) {
+    mappings.emplace_back(llvm::to_vector(llvm::map_range(
+        state.getPayloadOps(getOperand(arg.getArgNumber())),
+        [](Operation *op) -> transform::MappedValue { return op; })));
+  }
+
+  // Run the body several times, each time in a fresh region scope, so that
+  // state left behind by a previous execution (e.g., handle invalidation
+  // notices) would surface as errors on re-entry.
+  constexpr unsigned kNumReEntries = 4;
+  for (unsigned i = 0; i < kNumReEntries; ++i) {
+    auto scope = state.make_region_scope(getBody());
+    for (BlockArgument arg : body.getArguments()) {
+      if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()])))
+        return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    for (Operation &op : body.without_terminator()) {
+      DiagnosedSilenceableFailure diag =
+          state.applyTransform(cast<transform::TransformOpInterface>(op));
+      if (!diag.succeeded())
+        return diag;
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult mlir::test::TestReEnterRegionOp::verify() {
+  Block &body = getBody().front();
+  if (getNumOperands() != body.getNumArguments())
+    return emitOpError() << "expects as many operands as block arguments";
+  // `apply` casts every non-terminator op of the body to
+  // TransformOpInterface; diagnose other ops here instead of crashing there.
+  for (Operation &op : body.without_terminator()) {
+    if (!isa<transform::TransformOpInterface>(op))
+      return emitOpError() << "expects only transform ops in the body, found '"
+                           << op.getName() << "'";
+  }
+  return success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -553,4 +553,22 @@
   let cppNamespace = "::mlir::test";
 }
 
+def TestReEnterRegionOp
+  : Op<Transform_Dialect, "test_re_enter_region",
+      [DeclareOpInterfaceMethods<TransformOpInterface>,
+       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Applies the body region several times, for testing";
+  let description = [{
+    Maps the payload of each operand to the corresponding block argument and
+    applies the non-terminator ops of the body region several times, each time
+    in a fresh region scope. Consumes its operands. Used to check that state
+    associated with handles defined in the region is reset on re-entry.
+  }];
+  let arguments = (ins Variadic<TransformHandleTypeInterface>:$args);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = "($args^ `:` type($args))? attr-dict-with-keyword regions";
+  let cppNamespace = "::mlir::test";
+  let hasVerifier = 1;
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD