[mlir][transform] Heap-allocate TransformState mapping frames

TransformState::mappings is an llvm::MapVector, which keeps its values in a
contiguous std::vector. With Mappings stored by value, inserting or erasing a
region scope can move every frame and invalidate the references handed out by
getMapping(). Store each frame behind a std::unique_ptr so that its address is
stable for the lifetime of its region scope.

getMapping() keeps returning a reference: a frame always exists once its
region scope has been created (this is asserted), so a nullable pointer would
only weaken the contract and force churn at every call site. Ownership of a
frame is unique, so std::unique_ptr is used rather than std::shared_ptr, and
loops over the frames bind them by const reference instead of copying smart
pointers.

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -285,7 +285,8 @@
     /// transform IR region and payload IR objects.
     RegionScope(TransformState &state, Region &region)
         : state(state), region(&region) {
-      auto res = state.mappings.insert(std::make_pair(&region, Mappings()));
+      auto res = state.mappings.insert(
+          std::make_pair(&region, std::make_unique<Mappings>()));
       assert(res.second && "the region scope is already present");
       (void)res;
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -437,7 +438,7 @@
       }
     }
 #endif // NDEBUG
-    return it->second;
+    return *it->second;
   }
 
   /// Returns the mappings frame for the region in which the operation resides.
@@ -464,7 +465,7 @@
       }
     }
 #endif // NDEBUG
-    return it->second;
+    return *it->second;
   }
 
   /// Updates the state to include the associations between op results and the
@@ -683,7 +684,10 @@
 
   /// A stack of mappings between transform IR values and payload IR ops,
   /// aggregated by the region in which the transform IR values are defined.
-  llvm::MapVector<Region *, Mappings> mappings;
+  /// Each frame is heap-allocated so that the references returned by
+  /// getMapping() stay valid when the MapVector moves its elements, e.g., on
+  /// insertion or removal of a region scope.
+  llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
 
   /// Op handles may be temporarily mapped to nullptr to avoid invalidating
   /// payload op iterators. This set contains all op handles with nullptrs.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -45,7 +45,8 @@
   for (ArrayRef<MappedValue> mapping : extraMappings)
     topLevelMappedValues.push_back(mapping);
 
-  auto result = mappings.insert(std::make_pair(region, Mappings()));
+  auto result =
+      mappings.insert(std::make_pair(region, std::make_unique<Mappings>()));
   assert(result.second && "the region scope is already present");
   (void)result;
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -87,8 +88,8 @@
     bool includeOutOfScope) const {
   bool found = false;
   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
-    auto iterator = mapping.reverse.find(op);
-    if (iterator != mapping.reverse.end()) {
+    auto iterator = mapping->reverse.find(op);
+    if (iterator != mapping->reverse.end()) {
       llvm::append_range(handles, iterator->getSecond());
       found = true;
     }
@@ -106,8 +107,8 @@
     bool includeOutOfScope) const {
   bool found = false;
   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
-    auto iterator = mapping.reverseValues.find(payloadValue);
-    if (iterator != mapping.reverseValues.end()) {
+    auto iterator = mapping->reverseValues.find(payloadValue);
+    if (iterator != mapping->reverseValues.end()) {
       llvm::append_range(handles, iterator->getSecond());
       found = true;
     }
@@ -611,7 +612,7 @@
     // Go over all op handle mappings and mark as invalidated any handle
     // pointing to any of the payload ops associated with the given handle or
     // any op nested in them.
-    for (const auto &[payloadOp, otherHandles] : mapping.reverse) {
+    for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
       for (Value otherHandle : otherHandles)
         recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
                                       otherHandle, throughValue,
@@ -622,7 +623,7 @@
     // or any op nested in them. Similarly invalidate handles to argument of
     // blocks belonging to any region of any payload op associated with the
     // given handle or any op nested in them.
-    for (const auto &[payloadValue, valueHandles] : mapping.reverseValues) {
+    for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
       for (Value valueHandle : valueHandles)
         recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
                                                    payloadValue, valueHandle,
@@ -842,8 +843,9 @@
     // Cache Operation* -> OperationName mappings. These will be checked after
     // the transform has been applied to detect incorrect memory side effects
     // and missing op tracking.
-    for (Mappings &mapping : llvm::make_second_range(mappings)) {
-      for (Operation *op : llvm::make_first_range(mapping.reverse)) {
+    for (const std::unique_ptr<Mappings> &mapping :
+         llvm::make_second_range(mappings)) {
+      for (Operation *op : llvm::make_first_range(mapping->reverse)) {
         auto insertion = cachedNames.insert({op, op->getName()});
         if (!insertion.second) {
           if (insertion.first->second != op->getName()) {
@@ -993,8 +995,9 @@
     }
 
     // Check cached operation names.
-    for (Mappings &mapping : llvm::make_second_range(mappings)) {
-      for (Operation *op : llvm::make_first_range(mapping.reverse)) {
+    for (const std::unique_ptr<Mappings> &mapping :
+         llvm::make_second_range(mappings)) {
+      for (Operation *op : llvm::make_first_range(mapping->reverse)) {
         // Make sure that the name of the op has not changed. If it has changed,
         // the op was removed and a new op was allocated at the same memory
         // location. This means that we are missing op tracking somewhere.
@@ -1106,7 +1109,7 @@
   // Remember pointers to payload ops referenced by the handles going out of
   // scope.
   SmallVector<Operation *> referencedOps =
-      llvm::to_vector(llvm::make_first_range(state.mappings[region].reverse));
+      llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
   state.mappings.erase(region);