mlir: let replaceSubElements map functions control the sub-element walk

Adds replaceSubElements overloads whose attribute/type map functions return
std::pair<T, WalkResult>. WalkResult::skip() keeps the returned element but
does not replace its nested sub-elements; WalkResult::interrupt() (like a null
result) makes the replacement fail, which the interface documents as
returning nullptr. The existing T(T) overloads forward to the new ones with
WalkResult::advance().

Symbol RAUW in SymbolTable.cpp returns skip() for every SymbolRefAttr, so a
reference nested inside an already-handled SymbolRefAttr (e.g. the inner @foo
of @foo::@foo) is not rewritten a second time; see the new case in
test-symbol-rauw.mlir.

updateSubElementImpl keeps the mapped element in a local and stores it into
`visited` by key, since replaceSubElementFn may grow that map and would leave
a reference into it dangling.

diff --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h --- a/mlir/include/mlir/IR/SubElementInterfaces.h +++ b/mlir/include/mlir/IR/SubElementInterfaces.h @@ -16,6 +16,14 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Visitors.h" + +namespace mlir { +template +using SubElementReplFn = function_ref; +template +using SubElementResultReplFn = function_ref(T)>; +} // namespace mlir /// Include the definitions of the sub elemnt interfaces. #include "mlir/IR/SubElementAttrInterfaces.h.inc" diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td --- a/mlir/include/mlir/IR/SubElementInterfaces.td +++ b/mlir/include/mlir/IR/SubElementInterfaces.td @@ -56,8 +56,22 @@ /// Recursively replace all of the nested sub-attributes and sub-types using the /// provided map functions. Returns nullptr in the case of failure. }] # attrOrType # [{ replaceSubElements( - llvm::function_ref replaceAttrFn, - llvm::function_ref replaceTypeFn + mlir::SubElementReplFn replaceAttrFn, + mlir::SubElementReplFn replaceTypeFn + ) { + return replaceSubElements( + [&](Attribute attr) { return std::make_pair(replaceAttrFn(attr), WalkResult::advance()); }, + [&](Type type) { return std::make_pair(replaceTypeFn(type), WalkResult::advance()); } + ); + } + /// Recursively replace all of the nested sub-attributes and sub-types using the + /// provided map functions. This variant allows for the map function to return an + /// additional walk result. Returns nullptr in the case of failure. + }] # attrOrType # [{ replaceSubElements( + llvm::function_ref< + std::pair(mlir::Attribute)> replaceAttrFn, + llvm::function_ref< + std::pair(mlir::Type)> replaceTypeFn ); }]; code extraTraitClassDeclaration = [{ @@ -71,18 +85,16 @@ /// Recursively replace all of the nested sub-attributes and sub-types using the /// provided map functions. Returns nullptr in the case of failure.
}] # attrOrType # [{ replaceSubElements( - llvm::function_ref replaceAttrFn, - llvm::function_ref replaceTypeFn) { + mlir::SubElementReplFn replaceAttrFn, + mlir::SubElementReplFn replaceTypeFn) { }] # interfaceName # " interface(" # derivedValue # [{); return interface.replaceSubElements(replaceAttrFn, replaceTypeFn); } - - /// Recursively replace all of the nested sub-attributes and sub-types using the - /// provided map functions. Returns nullptr in the case of failure. - }] # attrOrType # [{ replaceImmediateSubElements( - llvm::ArrayRef replAttrs, - llvm::function_ref replTypes) { - return nullptr; + }] # attrOrType # [{ replaceSubElements( + mlir::SubElementResultReplFn replaceAttrFn, + mlir::SubElementResultReplFn replaceTypeFn) { + }] # interfaceName # " interface(" # derivedValue # [{); + return interface.replaceSubElements(replaceAttrFn, replaceTypeFn); } }]; code extraSharedClassDeclaration = [{ @@ -98,17 +110,31 @@ /// Recursively replace all of the nested sub-attributes using the provided /// map function. Returns nullptr in the case of failure. }] # attrOrType # [{ replaceSubElements( - llvm::function_ref replaceAttrFn) { + mlir::SubElementReplFn replaceAttrFn) { return replaceSubElements( replaceAttrFn, [](mlir::Type type) { return type; }); } + }] # attrOrType # [{ replaceSubElements( + mlir::SubElementResultReplFn replaceAttrFn) { + return replaceSubElements( + replaceAttrFn, + [](mlir::Type type) { return std::make_pair(type, WalkResult::advance()); } + ); + } /// Recursively replace all of the nested sub-types using the provided map /// function. Returns nullptr in the case of failure.
}] # attrOrType # [{ replaceSubElements( - llvm::function_ref replaceTypeFn) { + mlir::SubElementReplFn replaceTypeFn) { return replaceSubElements( [](mlir::Attribute attr) { return attr; }, replaceTypeFn); } + }] # attrOrType # [{ replaceSubElements( + mlir::SubElementResultReplFn replaceTypeFn) { + return replaceSubElements( + [](mlir::Attribute attr) { return std::make_pair(attr, WalkResult::advance()); }, + replaceTypeFn + ); + } }]; } diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp --- a/mlir/lib/IR/SubElementInterfaces.cpp +++ b/mlir/lib/IR/SubElementInterfaces.cpp @@ -102,11 +102,10 @@ } template -static void updateSubElementImpl(T element, function_ref walkFn, - DenseMap &visited, - SmallVectorImpl &newElements, - FailureOr &changed, - ReplaceSubElementFnT &&replaceSubElementFn) { +static void updateSubElementImpl( + T element, function_ref(T)> walkFn, + DenseMap &visited, SmallVectorImpl &newElements, + FailureOr &changed, ReplaceSubElementFnT &&replaceSubElementFn) { // Bail early if we failed at any point. if (failed(changed)) return; @@ -120,17 +119,28 @@
// yet. - T &mappedElement = visited[element]; + T mappedElement = visited.lookup(element); if (!mappedElement) { + WalkResult result = WalkResult::advance(); + std::tie(mappedElement, result) = walkFn(element); + + // Try walking this element. - if (!(mappedElement = walkFn(element))) { + if (result.wasInterrupted() || !mappedElement) { changed = failure(); return; } + // Record the walked element by key instead of holding a reference into + // `visited`: `replaceSubElementFn` below may re-enter this function with the + // same map, which can grow and rehash it, invalidating references into it. + // The entry is still written before recursing, as the reference-based code did. + visited[element] = mappedElement; // Handle replacing sub-elements if this element is also a container. - if (auto interface = mappedElement.template dyn_cast()) { - if (!(mappedElement = replaceSubElementFn(interface))) { - changed = failure(); - return; + if (!result.wasSkipped()) { + if (auto interface = mappedElement.template dyn_cast()) { + if (!(mappedElement = replaceSubElementFn(interface))) { + changed = failure(); + return; + } + visited[element] = mappedElement; } } } @@ -145,8 +155,8 @@ template static typename InterfaceT::ValueType replaceSubElementsImpl(InterfaceT interface, - function_ref walkAttrsFn, - function_ref walkTypesFn, + SubElementResultReplFn walkAttrsFn, + SubElementResultReplFn walkTypesFn, DenseMap &visitedAttrs, DenseMap &visitedTypes) { // Walk the current sub-elements, replacing them as necessary. @@ -186,8 +196,8 @@ } Attribute SubElementAttrInterface::replaceSubElements( - function_ref replaceAttrFn, - function_ref replaceTypeFn) { + SubElementResultReplFn replaceAttrFn, + SubElementResultReplFn replaceTypeFn) { assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions"); DenseMap visitedAttrs; DenseMap visitedTypes; @@ -196,8 +206,8 @@ } Type SubElementTypeInterface::replaceSubElements( - function_ref replaceAttrFn, - function_ref replaceTypeFn) { + SubElementResultReplFn replaceAttrFn, + SubElementResultReplFn replaceTypeFn) { assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions"); DenseMap visitedAttrs; DenseMap visitedTypes; diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -853,23 +853,30 @@ SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); auto walkFn = [&](Operation *op) -> Optional { - auto remapAttrFn = [&](Attribute attr) -> Attribute { + auto remapAttrFn = + [&](Attribute attr) -> std::pair { + // Regardless of the match, don't walk nested SymbolRefAttrs, we don't + // want to accidentally replace an inner reference.
if (attr == oldAttr) - return newAttr; + return {newAttr, WalkResult::skip()}; // Handle prefix matches. if (SymbolRefAttr symRef = attr.dyn_cast()) { if (isReferencePrefixOf(oldAttr, symRef)) { auto oldNestedRefs = oldAttr.getNestedReferences(); auto nestedRefs = symRef.getNestedReferences(); if (oldNestedRefs.empty()) - return SymbolRefAttr::get(newSymbol, nestedRefs); + return {SymbolRefAttr::get(newSymbol, nestedRefs), + WalkResult::skip()}; auto newNestedRefs = llvm::to_vector<4>(nestedRefs); newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr; - return SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs); + return { + SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs), + WalkResult::skip()}; } + return {attr, WalkResult::skip()}; } - return attr; + return {attr, WalkResult::advance()}; }; // Generate a new attribute dictionary by replacing references to the old // symbol. diff --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir --- a/mlir/test/IR/test-symbol-rauw.mlir +++ b/mlir/test/IR/test-symbol-rauw.mlir @@ -94,3 +94,19 @@ } : () -> () } } + +// ----- + +module { + // CHECK: module @replaced_foo + module @foo attributes {sym.new_name = "replaced_foo" } { + // CHECK: func.func private @foo + func.func private @foo() + } + + // CHECK: foo.op + // CHECK-SAME: use = @replaced_foo::@foo + "foo.op"() { + use = @foo::@foo + } : () -> () +}