diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -285,7 +285,8 @@ /// transform IR region and payload IR objects. RegionScope(TransformState &state, Region &region) : state(state), region(&region) { - auto res = state.mappings.insert(std::make_pair(&region, Mappings())); + auto res = state.mappings.insert( + std::make_pair(&region, std::make_unique<Mappings>())); assert(res.second && "the region scope is already present"); (void)res; #if LLVM_ENABLE_ABI_BREAKING_CHECKS @@ -437,7 +438,7 @@ } } #endif // NDEBUG - return it->second; + return *it->second.get(); } /// Returns the mappings frame for the region in which the operation resides. @@ -464,7 +465,7 @@ } } #endif // NDEBUG - return it->second; + return *it->second.get(); } /// Updates the state to include the associations between op results and the @@ -683,7 +684,10 @@ /// A stack of mappings between transform IR values and payload IR ops, /// aggregated by the region in which the transform IR values are defined. - llvm::MapVector<Region *, Mappings> mappings; + /// We use a pointer to the Mappings struct so that reallocations inside + /// MapVector don't invalidate iterators when we apply nested transform ops + /// while also iterating over the mappings. + llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings; /// Op handles may be temporarily mapped to nullptr to avoid invalidating /// payload op iterators. This set contains all op handles with nullptrs. 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -45,7 +45,8 @@ for (ArrayRef mapping : extraMappings) topLevelMappedValues.push_back(mapping); - auto result = mappings.insert(std::make_pair(region, Mappings())); + auto result = + mappings.insert(std::make_pair(region, std::make_unique<Mappings>())); assert(result.second && "the region scope is already present"); (void)result; #if LLVM_ENABLE_ABI_BREAKING_CHECKS @@ -87,8 +88,8 @@ bool includeOutOfScope) const { bool found = false; for (const auto &[region, mapping] : llvm::reverse(mappings)) { - auto iterator = mapping.reverse.find(op); - if (iterator != mapping.reverse.end()) { + auto iterator = mapping->reverse.find(op); + if (iterator != mapping->reverse.end()) { llvm::append_range(handles, iterator->getSecond()); found = true; } @@ -106,8 +107,8 @@ bool includeOutOfScope) const { bool found = false; for (const auto &[region, mapping] : llvm::reverse(mappings)) { - auto iterator = mapping.reverseValues.find(payloadValue); - if (iterator != mapping.reverseValues.end()) { + auto iterator = mapping->reverseValues.find(payloadValue); + if (iterator != mapping->reverseValues.end()) { llvm::append_range(handles, iterator->getSecond()); found = true; } @@ -611,7 +612,7 @@ // Go over all op handle mappings and mark as invalidated any handle // pointing to any of the payload ops associated with the given handle or // any op nested in them. - for (const auto &[payloadOp, otherHandles] : mapping.reverse) { + for (const auto &[payloadOp, otherHandles] : mapping->reverse) { for (Value otherHandle : otherHandles) recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp, otherHandle, throughValue, @@ -622,7 +623,7 @@ // or any op nested in them. 
Similarly invalidate handles to argument of // blocks belonging to any region of any payload op associated with the // given handle or any op nested in them. - for (const auto &[payloadValue, valueHandles] : mapping.reverseValues) { + for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) { for (Value valueHandle : valueHandles) recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors, payloadValue, valueHandle, @@ -842,8 +843,9 @@ // Cache Operation* -> OperationName mappings. These will be checked after // the transform has been applied to detect incorrect memory side effects // and missing op tracking. - for (Mappings &mapping : llvm::make_second_range(mappings)) { - for (Operation *op : llvm::make_first_range(mapping.reverse)) { + for (std::unique_ptr<Mappings> &mapping : + llvm::make_second_range(mappings)) { + for (Operation *op : llvm::make_first_range(mapping->reverse)) { auto insertion = cachedNames.insert({op, op->getName()}); if (!insertion.second) { if (insertion.first->second != op->getName()) { @@ -993,8 +995,9 @@ } // Check cached operation names. - for (Mappings &mapping : llvm::make_second_range(mappings)) { - for (Operation *op : llvm::make_first_range(mapping.reverse)) { + for (std::unique_ptr<Mappings> &mapping : + llvm::make_second_range(mappings)) { + for (Operation *op : llvm::make_first_range(mapping->reverse)) { // Make sure that the name of the op has not changed. If it has changed, // the op was removed and a new op was allocated at the same memory // location. This means that we are missing op tracking somewhere. @@ -1106,7 +1109,7 @@ // Remember pointers to payload ops referenced by the handles going out of // scope. SmallVector<Operation *> referencedOps = - llvm::to_vector(llvm::make_first_range(state.mappings[region].reverse)); + llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse)); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS state.mappings.erase(region);