Index: mlir/include/mlir/IR/BuiltinAttributes.td
===================================================================
--- mlir/include/mlir/IR/BuiltinAttributes.td
+++ mlir/include/mlir/IR/BuiltinAttributes.td
@@ -66,7 +66,8 @@
 //===----------------------------------------------------------------------===//
 
 def Builtin_ArrayAttr : Builtin_Attr<"Array", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    DeclareAttrInterfaceMethods<SubElementAttrInterface,
+        ["replaceImmediateSubAttribute"]>
   ]> {
   let summary = "A collection of other Attribute values";
   let description = [{
@@ -340,7 +341,8 @@
 //===----------------------------------------------------------------------===//
 
 def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    DeclareAttrInterfaceMethods<SubElementAttrInterface,
+        ["replaceImmediateSubAttribute"]>
   ]> {
   let summary = "An dictionary of named Attribute values";
   let description = [{
@@ -949,10 +951,10 @@
     symbol nested within a different symbol table.
 
     This attribute can only be held internally by
-    [array attributes](#array-attribute) and
+    [array attributes](#array-attribute),
     [dictionary attributes](#dictionary-attribute)(including the top-level
-    operation attribute dictionary), i.e. no other attribute kinds such as
-    Locations or extended attribute kinds.
+    operation attribute dictionary), and attributes that expose it through
+    the `SubElementAttrInterface`.
 
     **Rationale:** Identifying accesses to global data is critical to
     enabling efficient multi-threaded compilation. Restricting global
Index: mlir/include/mlir/IR/SubElementInterfaces.td
===================================================================
--- mlir/include/mlir/IR/SubElementInterfaces.td
+++ mlir/include/mlir/IR/SubElementInterfaces.td
@@ -33,6 +33,20 @@
       (ins "llvm::function_ref<void(mlir::Attribute)>":$walkAttrsFn,
            "llvm::function_ref<void(mlir::Type)>":$walkTypesFn)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Replace the attributes identified by the indices with the corresponding
+        value. The index is derived from the order of the attributes returned by
+        the attribute callback of `walkImmediateSubElements`. An index of 0 would
+        replace the very first attribute given by `walkImmediateSubElements`.
+        The new instance with the values replaced is returned.
+      }], cppNamespace # "::" # interfaceName, "replaceImmediateSubAttribute",
+      (ins "::llvm::ArrayRef<std::pair<size_t, ::mlir::Attribute>>":$replacements),
+      [{}],
+      /*defaultImplementation=*/[{
+        llvm_unreachable("Attribute or Type does not support replacing attributes");
+      }]
+    >,
   ];
 
   code extraClassDeclaration = [{
Index: mlir/lib/IR/BuiltinAttributes.cpp
===================================================================
--- mlir/lib/IR/BuiltinAttributes.cpp
+++ mlir/lib/IR/BuiltinAttributes.cpp
@@ -53,6 +53,15 @@
     walkAttrsFn(attr);
 }
 
+SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
+    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+  std::vector<Attribute> vector = getValue().vec();
+  for (auto &it : replacements) {
+    vector[it.first] = it.second;
+  }
+  return get(getContext(), vector);
+}
+
 //===----------------------------------------------------------------------===//
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
@@ -215,6 +224,17 @@
     walkAttrsFn(attr);
 }
 
+SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
+    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+  std::vector<NamedAttribute> vec = getValue().vec();
+  for (auto &it : replacements) {
+    vec[it.first].second = it.second;
+  }
+  // Only the values were replaced and the keys are untouched, so the element
+  // order is unchanged and the list is still sorted as getWithSorted requires.
+  return getWithSorted(getContext(), vec);
+}
+
 //===----------------------------------------------------------------------===//
 // StringAttr
 //===----------------------------------------------------------------------===//
Index: mlir/lib/IR/SymbolTable.cpp
===================================================================
--- mlir/lib/IR/SymbolTable.cpp
+++ mlir/lib/IR/SymbolTable.cpp
@@ -485,16 +485,30 @@
 
   // A worklist of a container attribute and the current index into the held
   // attribute list.
-  SmallVector<Attribute, 1> attrWorklist(1, attrDict);
+  struct WorklistItem {
+    SubElementAttrInterface container;
+    SmallVector<Attribute> immediateSubElements;
+
+    explicit WorklistItem(SubElementAttrInterface container)
+        : container(container) {
+      container.walkImmediateSubElements(
+          [&](Attribute attr) { immediateSubElements.push_back(attr); },
+          [](Type) {});
+    }
+  };
+
+  SmallVector<WorklistItem, 1> attrWorklist(1, WorklistItem(attrDict));
   SmallVector<int, 1> curAccessChain(1, /*Value=*/-1);
 
   // Process the symbol references within the given nested attribute range.
-  auto processAttrs = [&](int &index, auto attrRange) -> WalkResult {
-    for (Attribute attr : llvm::drop_begin(attrRange, index)) {
+  auto processAttrs = [&](int &index,
+                          WorklistItem &worklistItem) -> WalkResult {
+    for (Attribute attr :
+         llvm::drop_begin(worklistItem.immediateSubElements, index)) {
       /// Check for a nested container attribute, these will also need to be
       /// walked.
-      if (attr.isa<ArrayAttr, DictionaryAttr>()) {
-        attrWorklist.push_back(attr);
+      if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
+        attrWorklist.emplace_back(interface);
         curAccessChain.push_back(-1);
         return WalkResult::advance();
       }
@@ -517,15 +531,12 @@
 
   WalkResult result = WalkResult::advance();
   do {
-    Attribute attr = attrWorklist.back();
+    WorklistItem &item = attrWorklist.back();
     int &index = curAccessChain.back();
     ++index;
 
     // Process the given attribute, which is guaranteed to be a container.
-    if (auto dict = attr.dyn_cast<DictionaryAttr>())
-      result = processAttrs(index, make_second_range(dict.getValue()));
-    else
-      result = processAttrs(index, attr.cast<ArrayAttr>().getValue());
+    result = processAttrs(index, item);
   } while (!attrWorklist.empty() && !result.wasInterrupted());
   return result;
 }
@@ -811,48 +822,46 @@
 
 /// Rebuild the given attribute container after replacing all references to a
 /// symbol with the updated attribute in 'accesses'.
-static Attribute rebuildAttrAfterRAUW(
-    Attribute container,
+static SubElementAttrInterface rebuildAttrAfterRAUW(
+    SubElementAttrInterface container,
     ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses,
     unsigned depth) {
   // Given a range of Attributes, update the ones referred to by the given
   // access chains to point to the new symbol attribute.
-  auto updateAttrs = [&](auto &&attrRange) {
-    auto attrBegin = std::begin(attrRange);
-    for (unsigned i = 0, e = accesses.size(); i != e;) {
-      ArrayRef<int> access = accesses[i].first;
-      Attribute &attr = *std::next(attrBegin, access[depth]);
-
-      // Check to see if this is a leaf access, i.e. a SymbolRef.
-      if (access.size() == depth + 1) {
-        attr = accesses[i].second;
-        ++i;
-        continue;
-      }
 
-      // Otherwise, this is a container. Collect all of the accesses for this
-      // index and recurse. The recursion here is bounded by the size of the
-      // largest access array.
-      auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
-        ArrayRef<int> nextAccess = it.first;
-        return nextAccess.size() > depth + 1 &&
-               nextAccess[depth] == access[depth];
-      });
-      attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1);
-
-      // Skip over all of the accesses that refer to the nested container.
-      i += nestedAccesses.size();
+  SmallVector<std::pair<size_t, Attribute>> replacements;
+
+  SmallVector<Attribute> subElements;
+  container.walkImmediateSubElements(
+      [&](Attribute attribute) { subElements.push_back(attribute); },
+      [](Type) {});
+  for (unsigned i = 0, e = accesses.size(); i != e;) {
+    ArrayRef<int> access = accesses[i].first;
+
+    // Check to see if this is a leaf access, i.e. a SymbolRef.
+    if (access.size() == depth + 1) {
+      replacements.emplace_back(access.back(), accesses[i].second);
+      ++i;
+      continue;
     }
-  };
 
-  if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
-    auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
-    updateAttrs(make_second_range(newAttrs));
-    return DictionaryAttr::get(dictAttr.getContext(), newAttrs);
+    // Otherwise, this is a container. Collect all of the accesses for this
+    // index and recurse. The recursion here is bounded by the size of the
+    // largest access array.
+    auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
+      ArrayRef<int> nextAccess = it.first;
+      return nextAccess.size() > depth + 1 &&
+             nextAccess[depth] == access[depth];
+    });
+    auto result = rebuildAttrAfterRAUW(subElements[access[depth]],
+                                       nestedAccesses, depth + 1);
+    replacements.emplace_back(access[depth], result);
+
+    // Skip over all of the accesses that refer to the nested container.
+    i += nestedAccesses.size();
   }
-  auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
-  updateAttrs(newAttrs);
-  return ArrayAttr::get(container.getContext(), newAttrs);
+
+  return container.replaceImmediateSubAttribute(replacements);
 }
 
 /// Generates a new symbol reference attribute with a new leaf reference.
Index: mlir/test/IR/test-symbol-rauw.mlir
===================================================================
--- mlir/test/IR/test-symbol-rauw.mlir
+++ mlir/test/IR/test-symbol-rauw.mlir
@@ -73,3 +73,24 @@
   "foo.possibly_unknown_symbol_table"() ({
   }) : () -> ()
 }
+
+// -----
+
+// Check that replacement works in any implementation of SubElementAttrInterface
+module {
+  // CHECK: func private @replaced_foo
+  func private @symbol_foo() attributes {sym.new_name = "replaced_foo" }
+
+  // CHECK: func @symbol_bar
+  func @symbol_bar() {
+    // CHECK: foo.op
+    // CHECK-SAME: non_symbol_attr,
+    // CHECK-SAME: use = [#test.sub_elements_access<[@replaced_foo], @symbol_bar, @replaced_foo>],
+    // CHECK-SAME: z_non_symbol_attr_3
+    "foo.op"() {
+      non_symbol_attr,
+      use = [#test.sub_elements_access<[@symbol_foo],@symbol_bar,@symbol_foo>],
+      z_non_symbol_attr_3
+    } : () -> ()
+  }
+}
Index: mlir/test/lib/Dialect/Test/TestAttrDefs.td
===================================================================
--- mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -16,6 +16,7 @@
 // To get the test dialect definition.
 include "TestOps.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/SubElementInterfaces.td"
 
 // All of the attributes will extend this class.
 class Test_Attr<string name, list<Trait> traits = []>
@@ -101,4 +102,18 @@
   let genVerifyDecl = 1;
 }
 
+def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
+    DeclareAttrInterfaceMethods<SubElementAttrInterface,
+        ["replaceImmediateSubAttribute"]>
+  ]> {
+
+  let mnemonic = "sub_elements_access";
+
+  let parameters = (ins
+    "::mlir::Attribute":$first,
+    "::mlir::Attribute":$second,
+    "::mlir::Attribute":$third
+  );
+}
+
 #endif // TEST_ATTRDEFS
Index: mlir/test/lib/Dialect/Test/TestAttributes.cpp
===================================================================
--- mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -127,6 +127,59 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestSubElementsAccessAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TestSubElementsAccessAttr::parse(::mlir::DialectAsmParser &parser,
+                                           ::mlir::Type type) {
+  Attribute first, second, third;
+  if (parser.parseLess() || parser.parseAttribute(first) ||
+      parser.parseComma() || parser.parseAttribute(second) ||
+      parser.parseComma() || parser.parseAttribute(third) ||
+      parser.parseGreater()) {
+    return {};
+  }
+  return get(parser.getContext(), first, second, third);
+}
+
+void TestSubElementsAccessAttr::print(
+    ::mlir::DialectAsmPrinter &printer) const {
+  printer << getMnemonic() << "<" << getFirst() << ", " << getSecond() << ", "
+          << getThird() << ">";
+}
+
+void TestSubElementsAccessAttr::walkImmediateSubElements(
+    llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
+    llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
+  walkAttrsFn(getFirst());
+  walkAttrsFn(getSecond());
+  walkAttrsFn(getThird());
+}
+
+SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
+    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+  Attribute first = getFirst();
+  Attribute second = getSecond();
+  Attribute third = getThird();
+  for (auto &it : replacements) {
+    switch (it.first) {
+    case 0:
+      first = it.second;
+      break;
+    case 1:
+      second = it.second;
+      break;
+    case 2:
+      third = it.second;
+      break;
+    default:
+      llvm_unreachable("invalid TestSubElementsAccessAttr sub-element index");
+    }
+  }
+  return get(getContext(), first, second, third);
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//