Index: mlir/include/mlir/IR/BuiltinAttributes.td
===================================================================
--- mlir/include/mlir/IR/BuiltinAttributes.td
+++ mlir/include/mlir/IR/BuiltinAttributes.td
@@ -66,7 +66,7 @@
 //===----------------------------------------------------------------------===//
 
 def Builtin_ArrayAttr : Builtin_Attr<"Array", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    DeclareAttrInterfaceMethods<SubElementAttrInterface, ["replaceImmediateSubAttribute"]>
   ]> {
   let summary = "A collection of other Attribute values";
   let description = [{
@@ -340,7 +340,7 @@
 //===----------------------------------------------------------------------===//
 
 def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    DeclareAttrInterfaceMethods<SubElementAttrInterface, ["replaceImmediateSubAttribute"]>
   ]> {
   let summary = "An dictionary of named Attribute values";
   let description = [{
Index: mlir/include/mlir/IR/SubElementInterfaces.td
===================================================================
--- mlir/include/mlir/IR/SubElementInterfaces.td
+++ mlir/include/mlir/IR/SubElementInterfaces.td
@@ -33,6 +33,19 @@
       (ins "llvm::function_ref<void(::mlir::Attribute)>":$walkAttrsFn,
            "llvm::function_ref<void(::mlir::Type)>":$walkTypesFn)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return this attribute with its immediate sub-attribute at position
+        `index` replaced by `value`. Positions follow the order in which the
+        attribute callback of `walkImmediateSubElements` is invoked: an index
+        of 0 replaces the very first attribute passed to that callback.
+      }], cppNamespace # "::" # interfaceName, "replaceImmediateSubAttribute",
+      (ins "std::size_t":$index, "::mlir::Attribute":$value),
+      [{}],
+      /*defaultImplementation=*/[{
+        llvm_unreachable("attribute does not support replacing sub-attributes");
+      }]
+    >,
   ];
 
   code extraClassDeclaration = [{
Index: mlir/lib/IR/BuiltinAttributes.cpp
===================================================================
--- mlir/lib/IR/BuiltinAttributes.cpp
+++ mlir/lib/IR/BuiltinAttributes.cpp
@@ -53,6 +53,15 @@
     walkAttrsFn(attr);
 }
 
+SubElementAttrInterface
+ArrayAttr::replaceImmediateSubAttribute(std::size_t index,
+                                        Attribute value) const {
+  assert(index < getValue().size() && "sub-attribute index out of range");
+  auto elements = getValue().vec();
+  elements[index] = value;
+  return get(getContext(), elements);
+}
+
 //===----------------------------------------------------------------------===//
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
@@ -215,6 +224,17 @@
     walkAttrsFn(attr);
 }
 
+SubElementAttrInterface
+DictionaryAttr::replaceImmediateSubAttribute(std::size_t index,
+                                             Attribute value) const {
+  assert(index < getValue().size() && "sub-attribute index out of range");
+  auto elements = getValue().vec();
+  // Only the mapped value changes; the names, and therefore the sorted order
+  // of the entries, are untouched, so the sorted-input builder can be used.
+  elements[index].second = value;
+  return getWithSorted(getContext(), elements);
+}
+
 //===----------------------------------------------------------------------===//
 // StringAttr
 //===----------------------------------------------------------------------===//
Index: mlir/lib/IR/SymbolTable.cpp
===================================================================
--- mlir/lib/IR/SymbolTable.cpp
+++ mlir/lib/IR/SymbolTable.cpp
@@ -485,16 +485,30 @@
 
   // A worklist of a container attribute and the current index into the held
   // attribute list.
-  SmallVector<Attribute, 1> attrWorklist(1, attrDict);
+  struct WorklistItem {
+    SubElementAttrInterface container;
+    // The immediate sub-attributes of `container`, in walk order.
+    SmallVector<Attribute> immediateSubElements;
+    explicit WorklistItem(SubElementAttrInterface container)
+        : container(container) {
+      container.walkImmediateSubElements(
+          [&](Attribute attr) { immediateSubElements.push_back(attr); },
+          [](Type) {});
+    }
+  };
+
+  SmallVector<WorklistItem, 1> attrWorklist(1, WorklistItem{attrDict});
   SmallVector<int, 1> curAccessChain(1, /*Value=*/-1);
 
   // Process the symbol references within the given nested attribute range.
-  auto processAttrs = [&](int &index, auto attrRange) -> WalkResult {
-    for (Attribute attr : llvm::drop_begin(attrRange, index)) {
+  auto processAttrs = [&](int &index,
+                          WorklistItem &worklistItem) -> WalkResult {
+    for (Attribute attr :
+         llvm::drop_begin(worklistItem.immediateSubElements, index)) {
       /// Check for a nested container attribute, these will also need to be
       /// walked.
-      if (attr.isa<ArrayAttr, DictionaryAttr>()) {
-        attrWorklist.push_back(attr);
+      if (auto nested = attr.dyn_cast<SubElementAttrInterface>()) {
+        attrWorklist.emplace_back(nested);
         curAccessChain.push_back(-1);
         return WalkResult::advance();
       }
@@ -517,15 +531,12 @@
 
   WalkResult result = WalkResult::advance();
   do {
-    Attribute attr = attrWorklist.back();
+    WorklistItem &item = attrWorklist.back();
     int &index = curAccessChain.back();
     ++index;
 
     // Process the given attribute, which is guaranteed to be a container.
-    if (auto dict = attr.dyn_cast<DictionaryAttr>())
-      result = processAttrs(index, make_second_range(dict.getValue()));
-    else
-      result = processAttrs(index, attr.cast<ArrayAttr>().getValue());
+    result = processAttrs(index, item);
   } while (!attrWorklist.empty() && !result.wasInterrupted());
   return result;
 }
@@ -811,47 +822,47 @@
 
 /// Rebuild the given attribute container after replacing all references to a
 /// symbol with the updated attribute in 'accesses'.
-static Attribute rebuildAttrAfterRAUW(
-    Attribute container,
+static SubElementAttrInterface rebuildAttrAfterRAUW(
+    SubElementAttrInterface container,
     ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses,
     unsigned depth) {
   // Given a range of Attributes, update the ones referred to by the given
   // access chains to point to the new symbol attribute.
-  auto updateAttrs = [&](auto &&attrRange) {
-    auto attrBegin = std::begin(attrRange);
-    for (unsigned i = 0, e = accesses.size(); i != e;) {
-      ArrayRef<int> access = accesses[i].first;
-      Attribute &attr = *std::next(attrBegin, access[depth]);
-
-      // Check to see if this is a leaf access, i.e. a SymbolRef.
-      if (access.size() == depth + 1) {
-        attr = accesses[i].second;
-        ++i;
-        continue;
-      }
-      // Otherwise, this is a container. Collect all of the accesses for this
-      // index and recurse. The recursion here is bounded by the size of the
-      // largest access array.
-      auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
-        ArrayRef<int> nextAccess = it.first;
-        return nextAccess.size() > depth + 1 &&
-               nextAccess[depth] == access[depth];
-      });
-      attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1);
-
-      // Skip over all of the accesses that refer to the nested container.
-      i += nestedAccesses.size();
+  // Snapshot the original immediate sub-attributes. Replacing an entry only
+  // rebuilds `container`, so indices into this snapshot remain valid.
+  SmallVector<Attribute> subElements;
+  container.walkImmediateSubElements(
+      [&](Attribute attribute) { subElements.push_back(attribute); },
+      [](Type) {});
+  for (unsigned i = 0, e = accesses.size(); i != e;) {
+    ArrayRef<int> access = accesses[i].first;
+
+    // Check to see if this is a leaf access, i.e. a SymbolRef.
+    if (access.size() == depth + 1) {
+      container = container.replaceImmediateSubAttribute(access[depth],
+                                                         accesses[i].second);
+      ++i;
+      continue;
     }
-  };
 
-  if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
-    auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
-    updateAttrs(make_second_range(newAttrs));
-    return DictionaryAttr::get(dictAttr.getContext(), newAttrs);
+    // Otherwise, this is a container. Collect all of the accesses for this
+    // index and recurse. The recursion here is bounded by the size of the
+    // largest access array.
+    auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
+      ArrayRef<int> nextAccess = it.first;
+      return nextAccess.size() > depth + 1 &&
+             nextAccess[depth] == access[depth];
+    });
+    auto result = rebuildAttrAfterRAUW(
+        subElements[access[depth]].cast<SubElementAttrInterface>(),
+        nestedAccesses, depth + 1);
+    container = container.replaceImmediateSubAttribute(access[depth], result);
+
+    // Skip over all of the accesses that refer to the nested container.
+    i += nestedAccesses.size();
   }
-  auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
-  updateAttrs(newAttrs);
-  return ArrayAttr::get(container.getContext(), newAttrs);
+
+  return container;
 }
 
 /// Generates a new symbol reference attribute with a new leaf reference.
Index: mlir/test/IR/test-symbol-rauw.mlir
===================================================================
--- mlir/test/IR/test-symbol-rauw.mlir
+++ mlir/test/IR/test-symbol-rauw.mlir
@@ -73,3 +73,24 @@
   "foo.possibly_unknown_symbol_table"() ({
   }) : () -> ()
 }
+
+// -----
+
+// Check that replacement works in any implementation of SubElementAttrInterface
+module {
+  // CHECK: func private @replaced_foo
+  func private @symbol_foo() attributes {sym.new_name = "replaced_foo" }
+
+  // CHECK: func @symbol_bar
+  func @symbol_bar() {
+    // CHECK: foo.op
+    // CHECK-SAME: non_symbol_attr,
+    // CHECK-SAME: use = [#test.sub_elements_access<[@replaced_foo], @symbol_bar, @replaced_foo>],
+    // CHECK-SAME: z_non_symbol_attr_3
+    "foo.op"() {
+      non_symbol_attr,
+      use = [#test.sub_elements_access<[@symbol_foo],@symbol_bar,@symbol_foo>],
+      z_non_symbol_attr_3
+    } : () -> ()
+  }
+}
Index: mlir/test/lib/Dialect/Test/TestAttrDefs.td
===================================================================
--- mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -16,6 +16,7 @@
 // To get the test dialect definition.
 include "TestOps.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/SubElementInterfaces.td"
 
 // All of the attributes will extend this class.
 class Test_Attr<string name, list<Trait> traits = []>
@@ -101,4 +102,17 @@
   let genVerifyDecl = 1;
 }
 
+def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
+    DeclareAttrInterfaceMethods<SubElementAttrInterface, ["replaceImmediateSubAttribute"]>
+  ]> {
+
+  let mnemonic = "sub_elements_access";
+
+  let parameters = (ins
+    "::mlir::Attribute":$first,
+    "::mlir::Attribute":$second,
+    "::mlir::Attribute":$third
+  );
+}
+
 #endif // TEST_ATTRDEFS
Index: mlir/test/lib/Dialect/Test/TestAttributes.cpp
===================================================================
--- mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -127,6 +127,49 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestSubElementsAccessAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TestSubElementsAccessAttr::parse(::mlir::DialectAsmParser &parser,
+                                           ::mlir::Type type) {
+  Attribute first, second, third;
+  if (parser.parseLess() || parser.parseAttribute(first) ||
+      parser.parseComma() || parser.parseAttribute(second) ||
+      parser.parseComma() || parser.parseAttribute(third) ||
+      parser.parseGreater()) {
+    return {};
+  }
+  return get(parser.getContext(), first, second, third);
+}
+
+void TestSubElementsAccessAttr::print(
+    ::mlir::DialectAsmPrinter &printer) const {
+  printer << getMnemonic() << "<" << getFirst() << ", " << getSecond() << ", "
+          << getThird() << ">";
+}
+
+void TestSubElementsAccessAttr::walkImmediateSubElements(
+    llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
+    llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
+  walkAttrsFn(getFirst());
+  walkAttrsFn(getSecond());
+  walkAttrsFn(getThird());
+}
+
+SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
+    std::size_t index, ::mlir::Attribute value) const {
+  switch (index) {
+  case 0:
+    return get(getContext(), value, getSecond(), getThird());
+  case 1:
+    return get(getContext(), getFirst(), value, getThird());
+  case 2:
+    return get(getContext(), getFirst(), getSecond(), value);
+  }
+  llvm_unreachable("invalid sub-attribute index");
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//