diff --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h --- a/mlir/include/mlir/IR/SubElementInterfaces.h +++ b/mlir/include/mlir/IR/SubElementInterfaces.h @@ -19,10 +19,114 @@ #include "mlir/IR/Visitors.h" namespace mlir { -template -using SubElementReplFn = function_ref; -template -using SubElementResultReplFn = function_ref(T)>; +//===----------------------------------------------------------------------===// +/// AttrTypeReplacer +//===----------------------------------------------------------------------===// + +/// This class provides a utility for replacing attributes/types, and their sub +/// elements. Multiple replacement functions may be registered. Results are +/// cached, so the replacement functions run at most once per distinct +/// attribute/type for a given replacer instance. +class AttrTypeReplacer { +public: + //===--------------------------------------------------------------------===// + // Application + //===--------------------------------------------------------------------===// + + /// Replace the elements within the given operation. By default this includes + /// the attributes within the operation. If `replaceLocs` is true, this also + /// updates its location and the locations of any nested block arguments. If + /// `replaceTypes` is true, this also updates the result types of the + /// operation, and the types of any nested block arguments. Operations nested + /// within its regions are not visited. + void replaceElementsIn(Operation *op, bool replaceLocs = false, + bool replaceTypes = false); + + /// Replace the given attribute/type, and recursively replace any sub + /// elements. Returns either the new attribute/type, or nullptr in the case of + /// failure. + Attribute replace(Attribute attr); + Type replace(Type type); + + //===--------------------------------------------------------------------===// + // Registration + //===--------------------------------------------------------------------===// + + /// A replacement mapping function, which returns either `llvm::None` (to + /// signal the element wasn't handled), or a pair of the replacement element + /// and a WalkResult.
+ template + using ReplaceFnResult = Optional>; + template + using ReplaceFn = std::function(T)>; + + /// Register a replacement function for mapping a given attribute or type. A + /// replacement function must be convertible to one of the following forms + /// (where `T` is a class derived from `Type` or `Attribute`, and `BaseT` + /// is either `Type` or `Attribute` respectively): + /// + /// * Optional<BaseT>(T) + /// - This either returns a valid Attribute/Type in the case of success, + /// nullptr in the case of failure, or `llvm::None` to signify that + /// additional replacement functions may be applied (i.e. this function + /// doesn't handle that instance). + /// + /// * Optional<std::pair<BaseT, WalkResult>>(T) + /// - Similar to the above, but also allows specifying a WalkResult to + /// control the replacement of sub elements of a given attribute or + /// type. Returning a `skip` result, for example, will not recursively + /// process the resultant attribute or type value, while an `interrupt` + /// result is treated as a failure. + /// + /// Note: When replacing, the most recently added replacement functions will + /// be invoked first. + void addReplacement(ReplaceFn fn) { + attrReplacementFns.emplace_back(std::move(fn)); + } + void addReplacement(ReplaceFn fn) { + typeReplacementFns.push_back(std::move(fn)); + } + + /// Register a replacement function that doesn't match the default signature, + /// either because it uses a derived parameter type, or it uses a simplified + /// result type. The function is only invoked for elements that are instances + /// of its parameter type; all other elements are reported as unhandled. + template >::template arg_t<0>, + typename BaseT = std::conditional_t, + Attribute, Type>, + typename ResultT = std::invoke_result_t> + std::enable_if_t || + !std::is_convertible_v>> + addReplacement(FnT &&callback) { + addReplacement([callback = std::forward(callback)]( + BaseT base) -> ReplaceFnResult { + if (auto derived = dyn_cast(base)) { + if constexpr (std::is_convertible_v>) { + Optional result = callback(derived); + return result ? 
std::make_pair(*result, WalkResult::advance()) + : ReplaceFnResult(); + } else { + return callback(derived); + } + } + return ReplaceFnResult(); + }); + } + +private: + /// Internal implementation of the `replace` methods above. + template + T replaceImpl(T element, ReplaceFns &replaceFns, DenseMap &map); + + /// Replace the sub elements of the given interface. + template + T replaceSubElements(InterfaceT interface, DenseMap &interfaceMap); + + /// The registered replacement functions, invoked in reverse registration + /// order. + std::vector> attrReplacementFns; + std::vector> typeReplacementFns; + + /// The set of cached mappings for attributes/types. A nullptr value records + /// a failed replacement. + DenseMap attrMap; + DenseMap typeMap; +}; //===----------------------------------------------------------------------===// /// AttrTypeSubElementHandler @@ -291,7 +395,7 @@ } // namespace detail } // namespace mlir -/// Include the definitions of the sub elemnt interfaces. +/// Include the definitions of the sub element interfaces. #include "mlir/IR/SubElementAttrInterfaces.h.inc" #include "mlir/IR/SubElementTypeInterfaces.h.inc" diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td --- a/mlir/include/mlir/IR/SubElementInterfaces.td +++ b/mlir/include/mlir/IR/SubElementInterfaces.td @@ -66,25 +66,14 @@ llvm::function_ref walkTypesFn); /// Recursively replace all of the nested sub-attributes and sub-types using the - /// provided map functions. Returns nullptr in the case of failure. - }] # attrOrType # [{ replaceSubElements( - mlir::SubElementReplFn replaceAttrFn, - mlir::SubElementReplFn replaceTypeFn - ) { - return replaceSubElements( - [&](Attribute attr) { return std::make_pair(replaceAttrFn(attr), WalkResult::advance()); }, - [&](Type type) { return std::make_pair(replaceTypeFn(type), WalkResult::advance()); } - ); + /// provided map functions. Returns nullptr in the case of failure. See + /// `AttrTypeReplacer` for information on the supported replacement function types.
+ template + }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... replacementFns) { + AttrTypeReplacer replacer; + (replacer.addReplacement(std::forward(replacementFns)), ...); + return replacer.replace(*this); } - /// Recursively replace all of the nested sub-attributes and sub-types using the - /// provided map functions. This variant allows for the map function to return an - /// additional walk result. Returns nullptr in the case of failure. - }] # attrOrType # [{ replaceSubElements( - llvm::function_ref< - std::pair(mlir::Attribute)> replaceAttrFn, - llvm::function_ref< - std::pair(mlir::Type)> replaceTypeFn - ); }]; code extraTraitClassDeclaration = [{ /// Walk all of the held sub-attributes and sub-types. @@ -95,18 +84,13 @@ } /// Recursively replace all of the nested sub-attributes and sub-types using the - /// provided map functions. Returns nullptr in the case of failure. - }] # attrOrType # [{ replaceSubElements( - mlir::SubElementReplFn replaceAttrFn, - mlir::SubElementReplFn replaceTypeFn) { - }] # interfaceName # " interface(" # derivedValue # [{); - return interface.replaceSubElements(replaceAttrFn, replaceTypeFn); - } - }] # attrOrType # [{ replaceSubElements( - mlir::SubElementResultReplFn replaceAttrFn, - mlir::SubElementResultReplFn replaceTypeFn) { - }] # interfaceName # " interface(" # derivedValue # [{); - return interface.replaceSubElements(replaceAttrFn, replaceTypeFn); + /// provided map functions. Returns nullptr in the case of failure. See + /// `AttrTypeReplacer` for information on the supported replacement function types. + template + }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... 
replacementFns) { + AttrTypeReplacer replacer; + (replacer.addReplacement(std::forward(replacementFns)), ...); + return replacer.replace(}] # derivedValue # [{); } }]; code extraSharedClassDeclaration = [{ @@ -118,35 +102,6 @@ void walkSubTypes(llvm::function_ref walkFn) { walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn); } - - /// Recursively replace all of the nested sub-attributes using the provided - /// map function. Returns nullptr in the case of failure. - }] # attrOrType # [{ replaceSubElements( - mlir::SubElementReplFn replaceAttrFn) { - return replaceSubElements( - replaceAttrFn, [](mlir::Type type) { return type; }); - } - }] # attrOrType # [{ replaceSubElements( - mlir::SubElementResultReplFn replaceAttrFn) { - return replaceSubElements( - replaceAttrFn, - [](mlir::Type type) { return std::make_pair(type, WalkResult::advance()); } - ); - } - /// Recursively replace all of the nested sub-types using the provided map - /// function. Returns nullptr in the case of failure. 
- }] # attrOrType # [{ replaceSubElements( - mlir::SubElementReplFn replaceTypeFn) { - return replaceSubElements( - [](mlir::Attribute attr) { return attr; }, replaceTypeFn); - } - }] # attrOrType # [{ replaceSubElements( - mlir::SubElementResultReplFn replaceTypeFn) { - return replaceSubElements( - [](mlir::Attribute attr) { return std::make_pair(attr, WalkResult::advance()); }, - replaceTypeFn - ); - } }]; } diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp --- a/mlir/lib/IR/SubElementInterfaces.cpp +++ b/mlir/lib/IR/SubElementInterfaces.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/SubElementInterfaces.h" +#include "mlir/IR/Operation.h" #include "llvm/ADT/DenseSet.h" @@ -91,116 +92,146 @@ } //===----------------------------------------------------------------------===// -// ReplaceSubElements +/// AttrTypeReplacer +//===----------------------------------------------------------------------===// -template -static void updateSubElementImpl( - T element, function_ref(T)> walkFn, - DenseMap &visited, SmallVectorImpl &newElements, - FailureOr &changed, ReplaceSubElementFnT &&replaceSubElementFn) { - // Bail early if we failed at any point. - if (failed(changed)) - return; - newElements.push_back(element); +void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceLocs, + bool replaceTypes) { + // Functor that returns the replacement for the given element if it differs + // from the original, otherwise nullptr. A failed replacement also yields + // nullptr, so a failure leaves the corresponding element unchanged. + auto replaceIfDifferent = [&](auto element) { + auto replacement = replace(element); + return (replacement && replacement != element) ? replacement : nullptr; + }; + // Check the attribute dictionary for replacements. + if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary())) + op->setAttrs(cast(newAttrs)); - // Guard against potentially null inputs. We always map null to null. 
- if (!element) + // If we aren't updating locations or types, we're done. + if (!replaceTypes && !replaceLocs) return; - // Check for an existing mapping for this element, and walk it if we haven't - // yet. - T *mappedElement = &visited[element]; - if (!*mappedElement) { - WalkResult result = WalkResult::advance(); - std::tie(*mappedElement, result) = walkFn(element); - - // Try walking this element. - if (result.wasInterrupted() || !*mappedElement) { - changed = failure(); - return; - } + // Update the location. + if (replaceLocs) { + if (Attribute newLoc = replaceIfDifferent(op->getLoc())) + op->setLoc(cast(newLoc)); + } - // Handle replacing sub-elements if this element is also a container. - if (!result.wasSkipped()) { - if (auto interface = mappedElement->template dyn_cast()) { - // Cache the size of the `visited` map since it may grow when calling - // `replaceSubElementFn` and we would need to fetch again the (now - // invalidated) reference to `mappedElement`. - size_t visitedSize = visited.size(); - auto replacedElement = replaceSubElementFn(interface); - if (!replacedElement) { - changed = failure(); - return; + // Update the result types. + if (replaceTypes) { + for (OpResult result : op->getResults()) + if (Type newType = replaceIfDifferent(result.getType())) + result.setType(newType); + } + + // Update the arguments of blocks directly within this operation's regions; + // operations nested in those blocks are not visited. 
+ for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (BlockArgument &arg : block.getArguments()) { + if (replaceLocs) { + if (Attribute newLoc = replaceIfDifferent(arg.getLoc())) + arg.setLoc(cast(newLoc)); + } + + if (replaceTypes) { + if (Type newType = replaceIfDifferent(arg.getType())) + arg.setType(newType); } - if (visitedSize != visited.size()) - mappedElement = &visited[element]; - *mappedElement = replacedElement; } } } +} + +/// Replace a single sub element via `replacer`, appending the result to +/// `newElements` and recording in `changed` whether it differs from the +/// original element, or failure. +template +static void updateSubElementImpl(T element, AttrTypeReplacer &replacer, + DenseMap &elementMap, + SmallVectorImpl &newElements, + FailureOr &changed) { + // NOTE(review): `elementMap` is unused here; caching happens inside + // `replacer.replace`. + // Bail early if we failed at any point. + if (failed(changed)) + return; + + // Guard against potentially null inputs. We always map null to null. + if (!element) { + newElements.push_back(nullptr); + return; + } - // Update to the mapped element. - if (*mappedElement != element) { - newElements.back() = *mappedElement; - changed = true; + // Replace the element. + if (T result = replacer.replace(element)) { + newElements.push_back(result); + if (result != element) + changed = true; + } else { + changed = failure(); } } -template -static typename InterfaceT::ValueType -replaceSubElementsImpl(InterfaceT interface, - SubElementResultReplFn walkAttrsFn, - SubElementResultReplFn walkTypesFn, - DenseMap &visitedAttrs, - DenseMap &visitedTypes) { +/// Replace the immediate sub elements of `interface`. Returns the rebuilt +/// element if any sub element changed, the original element if none did, or +/// nullptr on failure. +template +T AttrTypeReplacer::replaceSubElements(InterfaceT interface, + DenseMap &interfaceMap) { + // NOTE(review): `interfaceMap` is unused; sub elements are cached in + // `attrMap`/`typeMap` through `replace`. // Walk the current sub-elements, replacing them as necessary. 
SmallVector newAttrs; SmallVector newTypes; FailureOr changed = false; - auto replaceSubElementFn = [&](auto subInterface) { - return replaceSubElementsImpl(subInterface, walkAttrsFn, walkTypesFn, - visitedAttrs, visitedTypes); - }; interface.walkImmediateSubElements( [&](Attribute element) { - updateSubElementImpl( - element, walkAttrsFn, visitedAttrs, newAttrs, changed, - replaceSubElementFn); + updateSubElementImpl(element, *this, attrMap, newAttrs, changed); }, [&](Type element) { - updateSubElementImpl( - element, walkTypesFn, visitedTypes, newTypes, changed, - replaceSubElementFn); + updateSubElementImpl(element, *this, typeMap, newTypes, changed); }); if (failed(changed)) - return {}; + return nullptr; - // If the sub-elements didn't change, just return the original value. - if (!*changed) - return interface; + // If any sub-elements changed, use the new elements during the replacement. + T result = interface; + if (*changed) + result = interface.replaceImmediateSubElements(newAttrs, newTypes); + return result; +} + +/// Shared implementation of replacing a given attribute or type element. +/// Every result, including failures (nullptr), is cached in `map`. +template +T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns, + DenseMap &map) { + // Seed the cache with an identity mapping, so repeated or re-entrant + // lookups of `element` return early with the cached value. + auto [it, inserted] = map.try_emplace(element, element); + if (!inserted) + return it->second; + + T result = element; + WalkResult walkResult = WalkResult::advance(); + // Try the replacement functions, most recently registered first; the first + // one that handles the element determines the result. If none handles it, + // the element is kept and its sub elements are still visited. + for (auto &replaceFn : llvm::reverse(replaceFns)) { + if (Optional> newRes = replaceFn(element)) { + std::tie(result, walkResult) = *newRes; + break; + } + } + + // If an error occurred, return nullptr to indicate failure. + if (walkResult.wasInterrupted() || !result) + return map[element] = nullptr; + + // Handle replacing sub-elements if this element is also a container. + if (!walkResult.wasSkipped()) { + if (auto interface = dyn_cast(result)) { + // Replace the sub elements of this element, bailing if we fail.
+ if (!(result = replaceSubElements(interface, map))) + return map[element] = nullptr; + } + } - // Use the new elements during the replacement. - return interface.replaceImmediateSubElements(newAttrs, newTypes); + return map[element] = result; } -Attribute SubElementAttrInterface::replaceSubElements( - SubElementResultReplFn replaceAttrFn, - SubElementResultReplFn replaceTypeFn) { - assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions"); - DenseMap visitedAttrs; - DenseMap visitedTypes; - return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn, - visitedAttrs, visitedTypes); +Attribute AttrTypeReplacer::replace(Attribute attr) { + return replaceImpl(attr, attrReplacementFns, + attrMap); } -Type SubElementTypeInterface::replaceSubElements( - SubElementResultReplFn replaceAttrFn, - SubElementResultReplFn replaceTypeFn) { - assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions"); - DenseMap visitedAttrs; - DenseMap visitedTypes; - return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn, - visitedAttrs, visitedTypes); +Type AttrTypeReplacer::replace(Type type) { + return replaceImpl(type, typeReplacementFns, + typeMap); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -853,40 +853,31 @@ for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { SymbolRefAttr oldAttr = scope.symbol; SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); - - auto walkFn = [&](Operation *op) -> Optional { - auto remapAttrFn = - [&](Attribute attr) -> std::pair { - // Regardless of the match, don't walk nested SymbolRefAttrs, we don't - // want to accidentally replace an inner reference. - if (attr == oldAttr) - return {newAttr, WalkResult::skip()}; - // Handle prefix matches.
- if (SymbolRefAttr symRef = attr.dyn_cast()) { - if (isReferencePrefixOf(oldAttr, symRef)) { + // One replacer is shared by every operation in this scope, so replacement + // results are cached across them. + AttrTypeReplacer replacer; + replacer.addReplacement( + [&](SymbolRefAttr attr) -> std::pair { + // Regardless of the match, don't walk nested SymbolRefAttrs, we don't + // want to accidentally replace an inner reference. + if (attr == oldAttr) + return {newAttr, WalkResult::skip()}; + // Handle prefix matches. + if (isReferencePrefixOf(oldAttr, attr)) { auto oldNestedRefs = oldAttr.getNestedReferences(); - auto nestedRefs = symRef.getNestedReferences(); + auto nestedRefs = attr.getNestedReferences(); if (oldNestedRefs.empty()) return {SymbolRefAttr::get(newSymbol, nestedRefs), WalkResult::skip()}; auto newNestedRefs = llvm::to_vector<4>(nestedRefs); newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr; - return { - SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs), - WalkResult::skip()}; + return {SymbolRefAttr::get(attr.getRootReference(), newNestedRefs), + WalkResult::skip()}; } return {attr, WalkResult::skip()}; - } - return {attr, WalkResult::advance()}; - }; - // Generate a new attribute dictionary by replacing references to the old - // symbol. - auto newDict = op->getAttrDictionary().replaceSubElements(remapAttrFn); - if (!newDict) - return WalkResult::interrupt(); - - op->setAttrs(newDict.template cast()); + }); + + auto walkFn = [&](Operation *op) -> Optional { + // Only the attribute dictionary is updated (locations and types are left + // as-is); a failed replacement leaves the operation unchanged. + replacer.replaceElementsIn(op); return WalkResult::advance(); }; if (!scope.walkSymbolTable(walkFn))