diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -262,12 +262,6 @@
   // class body to comply with visibility and full-declaration requirements.
   inline RegionScope make_region_scope(Region &region);
 
-  /// Creates a new region scope for the given isolated-from-above region.
-  /// Unlike the non-isolated counterpart, there is no nesting expectation.
-  // Implementation note: this method is inline but implemented outside of the
-  // class body to comply with visibility and full-declaration requirements
-  inline RegionScope make_isolated_region_scope(Region &region);
-
   /// A RAII object maintaining a "stack frame" for a transform IR region. When
   /// applying a transform IR operation that contains a region, the caller is
   /// expected to create a RegionScope before applying the ops contained in the
@@ -282,51 +276,25 @@
     ~RegionScope();
 
   private:
-    /// Tag structure for differentiating the constructor for isolated regions.
-    struct Isolated {};
-
     /// Creates a new scope for mappings between values defined in the given
     /// transform IR region and payload IR objects.
     RegionScope(TransformState &state, Region &region)
         : state(state), region(&region) {
-      auto res = state.mappings.try_emplace(this->region);
+      auto res = state.mappings.insert(std::make_pair(&region, Mappings()));
       assert(res.second && "the region scope is already present");
       (void)res;
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
-      assert(((state.regionStack.size() == 1 && !state.regionStack.back()) ||
-              state.regionStack.back()->isProperAncestor(&region)) &&
-             "scope started at a non-nested region");
       state.regionStack.push_back(&region);
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
     }
 
-    /// Creates a new scope for mappings between values defined in the given
-    /// isolated-from-above transform IR region and payload IR objects.
-    RegionScope(TransformState &state, Region &region, Isolated)
-        : state(state), region(&region) {
-      // Store the previous mapping stack locally.
-      storedMappings = llvm::SmallDenseMap<Region *, Mappings>();
-      storedMappings->swap(state.mappings);
-      state.mappings.try_emplace(this->region);
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-      state.regionStack.push_back(this->region);
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-    }
-
     /// Back-reference to the transform state.
     TransformState &state;
 
     /// The region this scope is associated with.
     Region *region;
 
-    /// Local copy of the mappings that existed before entering the current
-    /// region. Used only when the current region is isolated so we don't
-    /// accidentally look up the values defined outside the isolated region.
-    std::optional<llvm::SmallDenseMap<Region *, Mappings>> storedMappings =
-        std::nullopt;
-
     friend RegionScope TransformState::make_region_scope(Region &);
-    friend RegionScope TransformState::make_isolated_region_scope(Region &);
   };
   friend class RegionScope;
 
@@ -446,9 +414,19 @@
     return const_cast<TransformState *>(this)->getMapping(value);
   }
   Mappings &getMapping(Value value) {
-    auto it = mappings.find(value.getParentRegion());
+    Region *region = value.getParentRegion();
+    auto it = mappings.find(region);
     assert(it != mappings.end() &&
            "trying to find a mapping for a value from an unmapped region");
+#ifndef NDEBUG
+    for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+      if (r == region)
+        break;
+      if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+        llvm_unreachable(
+            "trying to get mapping beyond region that is isolated from above");
+    }
+#endif // NDEBUG
     return it->second;
   }
 
@@ -457,9 +435,19 @@
     return const_cast<TransformState *>(this)->getMapping(operation);
   }
   Mappings &getMapping(Operation *operation) {
-    auto it = mappings.find(operation->getParentRegion());
+    Region *region = operation->getParentRegion();
+    auto it = mappings.find(region);
     assert(it != mappings.end() &&
            "trying to find a mapping for an operation from an unmapped region");
+#ifndef NDEBUG
+    for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+      if (r == region)
+        break;
+      if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+        llvm_unreachable(
+            "trying to get mapping beyond region that is isolated from above");
+    }
+#endif // NDEBUG
     return it->second;
   }
 
@@ -676,9 +664,9 @@
   /// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
   void compactOpHandles();
 
-  /// The mappings between transform IR values and payload IR ops, aggregated by
-  /// the region in which the transform IR values are defined.
-  llvm::SmallDenseMap<Region *, Mappings> mappings;
+  /// A stack of mappings between transform IR values and payload IR ops,
+  /// aggregated by the region in which the transform IR values are defined.
+  llvm::MapVector<Region *, Mappings> mappings;
 
   /// Op handles may be temporarily mapped to nullptr to avoid invalidating
   /// payload op iterators. This set contains all op handles with nullptrs.
@@ -834,14 +822,6 @@
   return RegionScope(*this, region);
 }
 
-/// Creates a RAII object the lifetime of which corresponds to the new mapping
-/// for transform IR values defined in the given isolated-from-above region.
-/// Values defined in surrounding regions cannot be accessed.
-TransformState::RegionScope
-TransformState::make_isolated_region_scope(Region &region) {
-  return RegionScope(*this, region, RegionScope::Isolated());
-}
-
 /// A listener that updates a TransformState based on IR modifications. This
 /// listener can be used during a greedy pattern rewrite to keep the transform
 /// state up-to-date.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -45,7 +45,7 @@
   for (ArrayRef<MappedValue> mapping : extraMappings)
     topLevelMappedValues.push_back(mapping);
 
-  auto result = mappings.try_emplace(region);
+  auto result = mappings.insert(std::make_pair(region, Mappings()));
   assert(result.second && "the region scope is already present");
   (void)result;
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -85,12 +85,17 @@
 LogicalResult transform::TransformState::getHandlesForPayloadOp(
     Operation *op, SmallVectorImpl<Value> &handles) const {
   bool found = false;
-  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
     auto iterator = mapping.reverse.find(op);
     if (iterator != mapping.reverse.end()) {
       llvm::append_range(handles, iterator->getSecond());
       found = true;
     }
+    // Stop looking when reaching a region that is isolated from above. The
+    // top-level scope may be registered with a null region.
+    if (!region ||
+        region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      break;
   }
 
   return success(found);
@@ -99,12 +104,17 @@
 LogicalResult transform::TransformState::getHandlesForPayloadValue(
     Value payloadValue, SmallVectorImpl<Value> &handles) const {
   bool found = false;
-  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
     auto iterator = mapping.reverseValues.find(payloadValue);
     if (iterator != mapping.reverseValues.end()) {
       llvm::append_range(handles, iterator->getSecond());
       found = true;
     }
+    // Stop looking when reaching a region that is isolated from above. The
+    // top-level scope may be registered with a null region.
+    if (!region ||
+        region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      break;
   }
 
   return success(found);
@@ -590,8 +600,17 @@
   // number of IR objects (operations and values). Alternatively, we could walk
   // the IR nested in each payload op associated with the given handle and look
   // for handles associated with each operation and value.
-  for (const transform::TransformState::Mappings &mapping :
-       llvm::make_second_range(mappings)) {
+  //
+  // Only the mappings visible from the innermost region are considered: the
+  // innermost region that is isolated from above is still processed, but the
+  // regions enclosing it are not.
+  bool reachedIsolatedRegion = false;
+  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
+    if (reachedIsolatedRegion)
+      break;
+    reachedIsolatedRegion =
+        !region ||
+        region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
     // Go over all op handle mappings and mark as invalidated any handle
     // pointing to any of the payload ops associated with the given handle or
     // any op nested in them.
@@ -1102,8 +1121,6 @@
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
   state.mappings.erase(region);
-  if (storedMappings.has_value())
-    state.mappings.swap(*storedMappings);
 
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   // If the last handle to a payload op has gone out of scope, we no longer
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -446,7 +446,7 @@
 matchBlock(Block &block, Operation *op, transform::TransformState &state,
            SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
   assert(block.getParent() && "cannot match using a detached block");
-  auto matchScope = state.make_isolated_region_scope(*block.getParent());
+  auto matchScope = state.make_region_scope(*block.getParent());
   if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
     return DiagnosedSilenceableFailure::definiteFailure();
 
@@ -524,7 +524,7 @@
           continue;
         }
 
-        auto scope = state.make_isolated_region_scope(action.getFunctionBody());
+        auto scope = state.make_region_scope(action.getFunctionBody());
         for (auto &&[arg, map] : llvm::zip_equal(
                  action.getFunctionBody().front().getArguments(), mappings)) {
           if (failed(state.mapBlockArgument(arg, map)))
@@ -1029,7 +1029,7 @@
   // Map operands to block arguments.
   SmallVector<SmallVector<MappedValue>> mappings;
   detail::prepareValueMappings(mappings, getOperands(), state);
-  auto scope = state.make_isolated_region_scope(callee.getBody());
+  auto scope = state.make_region_scope(callee.getBody());
   for (auto &&[arg, map] :
        llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
     if (failed(state.mapBlockArgument(arg, map)))