diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -116,6 +116,8 @@ void walkImmediateSubElements(function_ref walkAttrsFn, function_ref walkTypesFn) const; + Type replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const; }; //===----------------------------------------------------------------------===// @@ -177,6 +179,8 @@ void walkImmediateSubElements(function_ref walkAttrsFn, function_ref walkTypesFn) const; + Type replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const; }; //===----------------------------------------------------------------------===// @@ -244,6 +248,8 @@ void walkImmediateSubElements(function_ref walkAttrsFn, function_ref walkTypesFn) const; + Type replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const; }; //===----------------------------------------------------------------------===// @@ -375,6 +381,8 @@ void walkImmediateSubElements(function_ref walkAttrsFn, function_ref walkTypesFn) const; + Type replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const; }; //===----------------------------------------------------------------------===// @@ -408,7 +416,7 @@ Type getElementType() const; /// Returns the number of elements in the fixed vector. - unsigned getNumElements(); + unsigned getNumElements() const; /// Verifies that the type about to be constructed is well-formed. static LogicalResult verify(function_ref emitError, @@ -416,6 +424,8 @@ void walkImmediateSubElements(function_ref walkAttrsFn, function_ref walkTypesFn) const; + Type replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const; }; //===----------------------------------------------------------------------===// @@ -450,7 +460,7 @@ /// Returns the scaling factor of the number of elements in the vector. The /// vector contains at least the resulting number of elements, or any non-zero /// multiple of this number. - unsigned getMinNumElements(); + unsigned getMinNumElements() const; /// Verifies that the type about to be constructed is well-formed. static LogicalResult verify(function_ref emitError, @@ -458,6 +468,8 @@ void walkImmediateSubElements(function_ref walkAttrsFn, function_ref walkTypesFn) const; + Type replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -72,8 +72,7 @@ //===----------------------------------------------------------------------===// def Builtin_ArrayAttr : Builtin_Attr<"Array", [ - DeclareAttrInterfaceMethods + DeclareAttrInterfaceMethods ]> { let summary = "A collection of other Attribute values"; let description = [{ @@ -425,8 +424,7 @@ //===----------------------------------------------------------------------===// def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [ - DeclareAttrInterfaceMethods + DeclareAttrInterfaceMethods ]> { let summary = "An dictionary of named Attribute values"; let description = [{ @@ -1046,7 +1044,9 @@ // SymbolRefAttr //===----------------------------------------------------------------------===// -def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> { +def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [ + DeclareAttrInterfaceMethods + ]> { let summary = "An Attribute containing a symbolic reference to an Operation"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td --- a/mlir/include/mlir/IR/SubElementInterfaces.td +++ b/mlir/include/mlir/IR/SubElementInterfaces.td @@ -21,7 +21,8 @@ // SubElementInterfaceBase //===----------------------------------------------------------------------===// -class SubElementInterfaceBase { +class SubElementInterfaceBase { string cppNamespace = "::mlir"; list methods = [ @@ -35,52 +36,78 @@ >, InterfaceMethod< /*desc=*/[{ - Replace the attributes identified by the indices with the corresponding - value. The index is derived from the order of the attributes returned by - the attribute callback of `walkImmediateSubElements`. An index of 0 would - replace the very first attribute given by `walkImmediateSubElements`. - The new instance with the values replaced is returned. - }], cppNamespace # "::" # interfaceName, "replaceImmediateSubAttribute", - (ins "::llvm::ArrayRef>":$replacements), - [{}], - /*defaultImplementation=*/[{ - llvm_unreachable("Attribute or Type does not support replacing attributes"); - }] - >, + Replace the immediately nested sub-attributes and sub-types with those provided. + The order of the provided elements is derived from the order of the elements + returned by the callbacks of `walkImmediateSubElements`. The element at index 0 + would replace the very first attribute given by `walkImmediateSubElements`. + On success, the new instance with the values replaced is returned. If replacement + fails, nullptr is returned. + }], attrOrType, "replaceImmediateSubElements", (ins + "::llvm::ArrayRef<::mlir::Attribute>":$replAttrs, + "::llvm::ArrayRef<::mlir::Type>":$replTypes + )>, ]; code extraClassDeclaration = [{ - /// Walk all of the held sub-attributes. - void walkSubAttrs(llvm::function_ref walkFn) { - walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {}); - } - - /// Walk all of the held sub-types. - void walkSubTypes(llvm::function_ref walkFn) { - walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn); - } - /// Walk all of the held sub-attributes and sub-types. void walkSubElements(llvm::function_ref walkAttrsFn, llvm::function_ref walkTypesFn); - }]; + /// Recursively replace all of the nested sub-attributes and sub-types using the + /// provided map functions. Returns nullptr in the case of failure. + }] # attrOrType # [{ replaceSubElements( + llvm::function_ref replaceAttrFn, + llvm::function_ref replaceTypeFn + ); + }]; code extraTraitClassDeclaration = [{ + /// Walk all of the held sub-attributes and sub-types. + void walkSubElements(llvm::function_ref walkAttrsFn, + llvm::function_ref walkTypesFn) { + }] # interfaceName # " interface(" # derivedValue # [{); + interface.walkSubElements(walkAttrsFn, walkTypesFn); + } + + /// Recursively replace all of the nested sub-attributes and sub-types using the + /// provided map functions. Returns nullptr in the case of failure. + }] # attrOrType # [{ replaceSubElements( + llvm::function_ref replaceAttrFn, + llvm::function_ref replaceTypeFn) { + }] # interfaceName # " interface(" # derivedValue # [{); + return interface.replaceSubElements(replaceAttrFn, replaceTypeFn); + } + + /// Recursively replace all of the nested sub-attributes and sub-types using the + /// provided map functions. Returns nullptr in the case of failure. + }] # attrOrType # [{ replaceImmediateSubElements( + llvm::ArrayRef replAttrs, + llvm::function_ref replTypes) { + return nullptr; + } + }]; + code extraSharedClassDeclaration = [{ /// Walk all of the held sub-attributes. void walkSubAttrs(llvm::function_ref walkFn) { walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {}); } - /// Walk all of the held sub-types. void walkSubTypes(llvm::function_ref walkFn) { walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn); } - - /// Walk all of the held sub-attributes and sub-types. - void walkSubElements(llvm::function_ref walkAttrsFn, - llvm::function_ref walkTypesFn) { - }] # interfaceName # " interface(" # derivedValue # [{); - interface.walkSubElements(walkAttrsFn, walkTypesFn); + + /// Recursively replace all of the nested sub-attributes using the provided + /// map function. Returns nullptr in the case of failure. + }] # attrOrType # [{ replaceSubElements( + llvm::function_ref replaceAttrFn) { + return replaceSubElements( + replaceAttrFn, [](mlir::Type type) { return type; }); + } + /// Recursively replace all of the nested sub-types using the provided map + /// function. Returns nullptr in the case of failure. + }] # attrOrType # [{ replaceSubElements( + llvm::function_ref replaceTypeFn) { + return replaceSubElements( + [](mlir::Attribute attr) { return attr; }, replaceTypeFn); } }]; } @@ -91,7 +118,8 @@ def SubElementAttrInterface : AttrInterface<"SubElementAttrInterface">, - SubElementInterfaceBase<"SubElementAttrInterface", "$_attr"> { + SubElementInterfaceBase<"SubElementAttrInterface", "::mlir::Attribute", + "$_attr"> { let description = [{ An interface used to query and manipulate sub-elements, such as sub-types and sub-attributes of a composite attribute. @@ -104,7 +132,8 @@ def SubElementTypeInterface : TypeInterface<"SubElementTypeInterface">, - SubElementInterfaceBase<"SubElementTypeInterface", "$_type"> { + SubElementInterfaceBase<"SubElementTypeInterface", "::mlir::Type", + "$_type"> { let description = [{ An interface used to query and manipulate sub-elements, such as sub-types and sub-attributes of a composite type. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -92,6 +92,11 @@ walkTypesFn(getElementType()); } +Type LLVMArrayType::replaceImmediateSubElements( + ArrayRef replAttrs, ArrayRef replTypes) const { + return get(replTypes.front(), getNumElements()); +} + //===----------------------------------------------------------------------===// // Function type. //===----------------------------------------------------------------------===// @@ -166,6 +171,11 @@ walkTypesFn(type); } +Type LLVMFunctionType::replaceImmediateSubElements( + ArrayRef replAttrs, ArrayRef replTypes) const { + return get(replTypes.front(), replTypes.drop_front(), isVarArg()); +} + //===----------------------------------------------------------------------===// // Pointer type. //===----------------------------------------------------------------------===// @@ -374,6 +384,11 @@ walkTypesFn(getElementType()); } +Type LLVMPointerType::replaceImmediateSubElements( + ArrayRef replAttrs, ArrayRef replTypes) const { + return get(replTypes.front(), getAddressSpace()); +} + //===----------------------------------------------------------------------===// // Struct type. //===----------------------------------------------------------------------===// @@ -617,6 +632,13 @@ walkTypesFn(type); } +Type LLVMStructType::replaceImmediateSubElements( + ArrayRef replAttrs, ArrayRef replTypes) const { + // TODO: It's not clear how we support replacing sub-elements of mutable + // types. + return nullptr; +} + //===----------------------------------------------------------------------===// // Vector types. //===----------------------------------------------------------------------===// @@ -653,7 +675,7 @@ return static_cast(impl)->elementType; } -unsigned LLVMFixedVectorType::getNumElements() { +unsigned LLVMFixedVectorType::getNumElements() const { return getImpl()->numElements; } @@ -674,6 +696,11 @@ walkTypesFn(getElementType()); } +Type LLVMFixedVectorType::replaceImmediateSubElements( + ArrayRef replAttrs, ArrayRef replTypes) const { + return get(replTypes[0], getNumElements()); +} + //===----------------------------------------------------------------------===// // LLVMScalableVectorType. //===----------------------------------------------------------------------===// @@ -696,7 +723,7 @@ return static_cast(impl)->elementType; } -unsigned LLVMScalableVectorType::getMinNumElements() { +unsigned LLVMScalableVectorType::getMinNumElements() const { return getImpl()->numElements; } @@ -720,6 +747,11 @@ walkTypesFn(getElementType()); } +Type LLVMScalableVectorType::replaceImmediateSubElements( + ArrayRef replAttrs, ArrayRef replTypes) const { + return get(replTypes[0], getMinNumElements()); +} + //===----------------------------------------------------------------------===// // Utility functions. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -54,13 +54,10 @@ walkAttrsFn(attr); } -SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute( - ArrayRef> replacements) const { - std::vector vector = getValue().vec(); - for (auto &it : replacements) { - vector[it.first] = it.second; - } - return get(getContext(), vector); +Attribute +ArrayAttr::replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const { + return get(getContext(), replAttrs); } //===----------------------------------------------------------------------===// @@ -227,11 +224,12 @@ walkAttrsFn(attr.getValue()); } -SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute( - ArrayRef> replacements) const { +Attribute +DictionaryAttr::replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const { std::vector vec = getValue().vec(); - for (auto &it : replacements) - vec[it.first].setValue(it.second); + for (auto &it : llvm::enumerate(replAttrs)) + vec[it.index()].setValue(it.value()); // The above only modifies the mapped value, but not the key, and therefore // not the order of the elements. It remains sorted @@ -326,6 +324,24 @@ return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr(); } +void SymbolRefAttr::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkAttrsFn(getRootReference()); + for (FlatSymbolRefAttr ref : getNestedReferences()) + walkAttrsFn(ref); +} + +Attribute +SymbolRefAttr::replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const { + ArrayRef rawNestedRefs = replAttrs.drop_front(); + ArrayRef nestedRefs( + static_cast(rawNestedRefs.data()), + rawNestedRefs.size()); + return get(replAttrs[0].cast(), nestedRefs); +} + //===----------------------------------------------------------------------===// // IntegerAttr //===----------------------------------------------------------------------===// @@ -1711,3 +1727,9 @@ function_ref walkTypesFn) const { walkTypesFn(getValue()); } + +Attribute +TypeAttr::replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const { + return get(replTypes[0]); +} diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -193,6 +193,13 @@ walkTypesFn(type); } +Type FunctionType::replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const { + unsigned numInputs = getNumInputs(); + return get(getContext(), replTypes.take_front(numInputs), + replTypes.drop_front(numInputs)); +} + //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===// @@ -256,6 +263,11 @@ walkTypesFn(getElementType()); } +Type VectorType::replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const { + return get(getShape(), replTypes.front(), getNumScalableDims()); +} + VectorType VectorType::cloneWith(Optional> shape, Type elementType) const { return VectorType::get(shape.value_or(getShape()), elementType, @@ -338,6 +350,12 @@ walkAttrsFn(encoding); } +Type RankedTensorType::replaceImmediateSubElements( + ArrayRef replAttrs, ArrayRef replTypes) const { + return get(getShape(), replTypes.front(), + replAttrs.empty() ? Attribute() : replAttrs.back()); +} + //===----------------------------------------------------------------------===// // UnrankedTensorType //===----------------------------------------------------------------------===// @@ -354,6 +372,11 @@ walkTypesFn(getElementType()); } +Type UnrankedTensorType::replaceImmediateSubElements( + ArrayRef replAttrs, ArrayRef replTypes) const { + return get(replTypes.front()); +} + //===----------------------------------------------------------------------===// // BaseMemRefType //===----------------------------------------------------------------------===// @@ -663,6 +686,15 @@ walkAttrsFn(getMemorySpace()); } +Type MemRefType::replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const { + bool hasLayout = replAttrs.size() > 1; + return get(getShape(), replTypes[0], + hasLayout ? replAttrs[0].dyn_cast() + : MemRefLayoutAttrInterface(), + hasLayout ? replAttrs[1] : replAttrs[0]); +} + //===----------------------------------------------------------------------===// // UnrankedMemRefType //===----------------------------------------------------------------------===// @@ -829,6 +861,11 @@ walkAttrsFn(getMemorySpace()); } +Type UnrankedMemRefType::replaceImmediateSubElements( + ArrayRef replAttrs, ArrayRef replTypes) const { + return get(replTypes.front(), replAttrs.front()); +} + //===----------------------------------------------------------------------===// /// TupleType //===----------------------------------------------------------------------===// @@ -859,6 +896,11 @@ walkTypesFn(type); } +Type TupleType::replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const { + return get(getContext(), replTypes); +} + //===----------------------------------------------------------------------===// // Type Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp --- a/mlir/lib/IR/SubElementInterfaces.cpp +++ b/mlir/lib/IR/SubElementInterfaces.cpp @@ -12,6 +12,13 @@ using namespace mlir; +//===----------------------------------------------------------------------===// +// SubElementInterface +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// WalkSubElements + template static void walkSubElementsImpl(InterfaceT interface, function_ref walkAttrsFn, @@ -83,6 +90,121 @@ visitedTypes); } +//===----------------------------------------------------------------------===// +// ReplaceSubElements + +/// Return if the given element is mutable. +static bool isMutable(Attribute attr) { + return attr.hasTrait(); +} +static bool isMutable(Type type) { + return type.hasTrait(); +} + +template +static void updateSubElementImpl(T element, function_ref walkFn, + DenseMap &visited, + SmallVectorImpl &newElements, + FailureOr &changed, + ReplaceSubElementFnT &&replaceSubElementFn) { + // Bail early if we failed at any point. + if (failed(changed)) + return; + newElements.push_back(element); + + // Guard against potentially null inputs. We always map null to null. + if (!element) + return; + + // Check for an existing mapping for this element, and walk it if we haven't + // yet. + T &mappedElement = visited[element]; + if (!mappedElement) { + // Try walking this element. + if (!(mappedElement = walkFn(element))) { + changed = failure(); + return; + } + + // Handle replacing sub-elements if this element is also a container. + if (auto interface = mappedElement.template dyn_cast()) { + if (!(mappedElement = replaceSubElementFn(interface))) { + changed = failure(); + return; + } + } + } + + // Update to the mapped element. + if (mappedElement != element) { + newElements.back() = mappedElement; + changed = true; + } +} + +template +static typename InterfaceT::ValueType +replaceSubElementsImpl(InterfaceT interface, + function_ref walkAttrsFn, + function_ref walkTypesFn, + DenseMap &visitedAttrs, + DenseMap &visitedTypes) { + // Walk the current sub-elements, replacing them as necessary. + SmallVector newAttrs; + SmallVector newTypes; + FailureOr changed = false; + auto replaceSubElementFn = [&](auto subInterface) { + return replaceSubElementsImpl(subInterface, walkAttrsFn, walkTypesFn, + visitedAttrs, visitedTypes); + }; + interface.walkImmediateSubElements( + [&](Attribute element) { + updateSubElementImpl( + element, walkAttrsFn, visitedAttrs, newAttrs, changed, + replaceSubElementFn); + }, + [&](Type element) { + updateSubElementImpl( + element, walkTypesFn, visitedTypes, newTypes, changed, + replaceSubElementFn); + }); + if (failed(changed)) + return {}; + + // If the sub-elements didn't change, just return the original value. + if (!*changed) + return interface; + + // If this element is mutable, we don't support changing its sub elements, the + // sub element walk doesn't give us a valid ordering for what we need here. If + // we want to support mutable elements, we'll need something more. + if (isMutable(interface)) + return {}; + + // Use the new elements during the replacement. + return interface.replaceImmediateSubElements(newAttrs, newTypes); +} + +Attribute SubElementAttrInterface::replaceSubElements( + function_ref replaceAttrFn, + function_ref replaceTypeFn) { + assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions"); + DenseMap visitedAttrs; + DenseMap visitedTypes; + return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn, + visitedAttrs, visitedTypes); +} + +Type SubElementTypeInterface::replaceSubElements( + function_ref replaceAttrFn, + function_ref replaceTypeFn) { + assert(replaceAttrFn && replaceTypeFn && "expected valid replace functions"); + DenseMap visitedAttrs; + DenseMap visitedTypes; + return replaceSubElementsImpl(*this, replaceAttrFn, replaceTypeFn, + visitedAttrs, visitedTypes); +} + //===----------------------------------------------------------------------===// // SubElementInterface Tablegen definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -97,6 +97,18 @@ return WalkResult::advance(); } +/// Walk all of the operations nested under, and including, the given operation, +/// without traversing into any nested symbol tables. Stops walking if the +/// result of the callback is anything other than `WalkResult::advance`. +static Optional +walkSymbolTable(Operation *op, + function_ref(Operation *)> callback) { + Optional result = callback(op); + if (result != WalkResult::advance() || op->hasTrait()) + return result; + return walkSymbolTable(op->getRegions(), callback); +} + //===----------------------------------------------------------------------===// // SymbolTable //===----------------------------------------------------------------------===// @@ -465,21 +477,11 @@ //===----------------------------------------------------------------------===// /// Walk all of the symbol references within the given operation, invoking the -/// provided callback for each found use. The callbacks takes as arguments: the -/// use of the symbol, and the nested access chain to the attribute within the -/// operation dictionary. An access chain is a set of indices into nested -/// container attributes. For example, a symbol use in an attribute dictionary -/// that looks like the following: -/// -/// {use = [{other_attr, @symbol}]} -/// -/// May have the following access chain: -/// -/// [0, 0, 1] -/// -static WalkResult walkSymbolRefs( - Operation *op, - function_ref)> callback) { +/// provided callback for each found use. The callbacks takes the use of the +/// symbol. +static WalkResult +walkSymbolRefs(Operation *op, + function_ref callback) { // Check to see if the operation has any attributes. DictionaryAttr attrDict = op->getAttrDictionary(); if (attrDict.empty()) @@ -507,20 +509,19 @@ WorklistItem &worklistItem) -> WalkResult { for (Attribute attr : llvm::drop_begin(worklistItem.immediateSubElements, index)) { - /// Check for a nested container attribute, these will also need to be - /// walked. - if (auto interface = attr.dyn_cast()) { - attrWorklist.emplace_back(interface); - curAccessChain.push_back(-1); - return WalkResult::advance(); - } - // Invoke the provided callback if we find a symbol use and check for a // requested interrupt. - if (auto symbolRef = attr.dyn_cast()) - if (callback({op, symbolRef}, curAccessChain).wasInterrupted()) + if (auto symbolRef = attr.dyn_cast()) { + if (callback({op, symbolRef}).wasInterrupted()) return WalkResult::interrupt(); + /// Check for a nested container attribute, these will also need to be + /// walked. + } else if (auto interface = attr.dyn_cast()) { + attrWorklist.emplace_back(interface); + curAccessChain.push_back(-1); + return WalkResult::advance(); + } // Make sure to keep the index counter in sync. ++index; } @@ -546,9 +547,9 @@ /// Walk all of the uses, for any symbol, that are nested within the given /// regions, invoking the provided callback for each. This does not traverse /// into any nested symbol tables. -static Optional walkSymbolUses( - MutableArrayRef regions, - function_ref)> callback) { +static Optional +walkSymbolUses(MutableArrayRef regions, + function_ref callback) { return walkSymbolTable(regions, [&](Operation *op) -> Optional { // Check that this isn't a potentially unknown symbol table. if (isPotentiallyUnknownSymbolTable(op)) @@ -560,9 +561,9 @@ /// Walk all of the uses, for any symbol, that are nested within the given /// operation 'from', invoking the provided callback for each. This does not /// traverse into any nested symbol tables. -static Optional walkSymbolUses( - Operation *from, - function_ref)> callback) { +static Optional +walkSymbolUses(Operation *from, + function_ref callback) { // If this operation has regions, and it, as well as its dialect, isn't // registered then conservatively fail. The operation may define a // symbol table, so we can't opaquely know if we should traverse to find @@ -608,11 +609,20 @@ typename llvm::function_traits::result_t, void>::value> * = nullptr> Optional walk(CallbackT cback) { - return walk([=](SymbolTable::SymbolUse use, ArrayRef) { + return walk([=](SymbolTable::SymbolUse use) { return cback(use), WalkResult::advance(); }); } + /// Walk all of the operations nested under the current scope without + /// traversing into any nested symbol tables. + template + Optional walkSymbolTable(CallbackT &&cback) { + if (Region *region = limit.dyn_cast()) + return ::walkSymbolTable(*region, cback); + return ::walkSymbolTable(limit.get(), cback); + } + /// The representation of the symbol within this scope. SymbolRefAttr symbol; @@ -723,7 +733,7 @@ template static Optional getSymbolUsesImpl(FromT from) { std::vector uses; - auto walkFn = [&](SymbolTable::SymbolUse symbolUse, ArrayRef) { + auto walkFn = [&](SymbolTable::SymbolUse symbolUse) { uses.push_back(symbolUse); return WalkResult::advance(); }; @@ -792,7 +802,7 @@ static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) { for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { // Walk all of the symbol uses looking for a reference to 'symbol'. - if (scope.walk([&](SymbolTable::SymbolUse symbolUse, ArrayRef) { + if (scope.walk([&](SymbolTable::SymbolUse symbolUse) { return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()) ? WalkResult::interrupt() : WalkResult::advance(); @@ -822,50 +832,6 @@ //===----------------------------------------------------------------------===// // SymbolTable::replaceAllSymbolUses -/// Rebuild the given attribute container after replacing all references to a -/// symbol with the updated attribute in 'accesses'. -static SubElementAttrInterface rebuildAttrAfterRAUW( - SubElementAttrInterface container, - ArrayRef, SymbolRefAttr>> accesses, - unsigned depth) { - // Given a range of Attributes, update the ones referred to by the given - // access chains to point to the new symbol attribute. - - SmallVector> replacements; - - SmallVector subElements; - container.walkImmediateSubElements( - [&](Attribute attribute) { subElements.push_back(attribute); }, - [](Type) {}); - for (unsigned i = 0, e = accesses.size(); i != e;) { - ArrayRef access = accesses[i].first; - - // Check to see if this is a leaf access, i.e. a SymbolRef. - if (access.size() == depth + 1) { - replacements.emplace_back(access.back(), accesses[i].second); - ++i; - continue; - } - - // Otherwise, this is a container. Collect all of the accesses for this - // index and recurse. The recursion here is bounded by the size of the - // largest access array. - auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) { - ArrayRef nextAccess = it.first; - return nextAccess.size() > depth + 1 && - nextAccess[depth] == access[depth]; - }); - auto result = rebuildAttrAfterRAUW(subElements[access[depth]], - nestedAccesses, depth + 1); - replacements.emplace_back(access[depth], result); - - // Skip over all of the accesses that refer to the nested container. - i += nestedAccesses.size(); - } - - return container.replaceImmediateSubAttribute(replacements); -} - /// Generates a new symbol reference attribute with a new leaf reference. static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, FlatSymbolRefAttr newLeafAttr) { @@ -880,77 +846,43 @@ template static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) { - // A collection of operations along with their new attribute dictionary. - std::vector> updatedAttrDicts; - - // The current operation being processed. - Operation *curOp = nullptr; - - // The set of access chains into the attribute dictionary of the current - // operation, as well as the replacement attribute to use. - SmallVector, SymbolRefAttr>, 1> accessChains; - - // Generate a new attribute dictionary for the current operation by replacing - // references to the old symbol. - auto generateNewAttrDict = [&] { - auto oldDict = curOp->getAttrDictionary(); - auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0); - return newDict.cast(); - }; - // Generate a new attribute to replace the given attribute. FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol); for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { + SymbolRefAttr oldAttr = scope.symbol; SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); - auto walkFn = [&](SymbolTable::SymbolUse symbolUse, - ArrayRef accessChain) { - SymbolRefAttr useRef = symbolUse.getSymbolRef(); - if (!isReferencePrefixOf(scope.symbol, useRef)) - return WalkResult::advance(); - // If we have a valid match, check to see if this is a proper - // subreference. If it is, then we will need to generate a different new - // attribute specifically for this use. - SymbolRefAttr replacementRef = newAttr; - if (useRef != scope.symbol) { - if (scope.symbol.isa()) { - replacementRef = - SymbolRefAttr::get(newSymbol, useRef.getNestedReferences()); - } else { - auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences()); - nestedRefs[scope.symbol.getNestedReferences().size() - 1] = - newLeafAttr; - replacementRef = - SymbolRefAttr::get(useRef.getRootReference(), nestedRefs); + auto walkFn = [&](Operation *op) -> Optional { + auto remapAttrFn = [&](Attribute attr) -> Attribute { + if (attr == oldAttr) + return newAttr; + // Handle prefix matches. + if (SymbolRefAttr symRef = attr.dyn_cast()) { + if (isReferencePrefixOf(oldAttr, symRef)) { + auto oldNestedRefs = oldAttr.getNestedReferences(); + auto nestedRefs = symRef.getNestedReferences(); + if (oldNestedRefs.empty()) + return SymbolRefAttr::get(newSymbol, nestedRefs); + + auto newNestedRefs = llvm::to_vector<4>(nestedRefs); + newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr; + return SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs); + } } - } - - // If there was a previous operation, generate a new attribute dict - // for it. This means that we've finished processing the current - // operation, so generate a new dictionary for it. - if (curOp && symbolUse.getUser() != curOp) { - updatedAttrDicts.push_back({curOp, generateNewAttrDict()}); - accessChains.clear(); - } - - // Record this access. - curOp = symbolUse.getUser(); - accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef}); + return attr; + }; + // Generate a new attribute dictionary by replacing references to the old + // symbol. + auto newDict = op->getAttrDictionary().replaceSubElements(remapAttrFn); + if (!newDict) + return WalkResult::interrupt(); + + op->setAttrs(newDict.template cast()); return WalkResult::advance(); }; - if (!scope.walk(walkFn)) + if (!scope.walkSymbolTable(walkFn)) return failure(); - - // Check to see if we have a dangling op that needs to be processed. - if (curOp) { - updatedAttrDicts.push_back({curOp, generateNewAttrDict()}); - curOp = nullptr; - } } - - // Update the attribute dictionaries as necessary. - for (auto &it : updatedAttrDicts) - it.first->setAttrs(it.second); return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -114,9 +114,8 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [ DeclareAttrInterfaceMethods + ["replaceImmediateSubElements"]> ]> { - let mnemonic = "sub_elements_access"; let parameters = (ins diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -173,25 +173,10 @@ walkAttrsFn(getThird()); } -SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute( - ArrayRef> replacements) const { - Attribute first = getFirst(); - Attribute second = getSecond(); - Attribute third = getThird(); - for (auto &it : replacements) { - switch (it.first) { - case 0: - first = it.second; - break; - case 1: - second = it.second; - break; - case 2: - third = it.second; - break; - } - } - return get(getContext(), first, second, third); +Attribute TestSubElementsAccessAttr::replaceImmediateSubElements( + ArrayRef replAttrs, ArrayRef replTypes) const { + assert(replAttrs.size() == 3 && "invalid number of replacement attributes"); + return get(getContext(), replAttrs[0], replAttrs[1], replAttrs[2]); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -154,6 +154,12 @@ ::llvm::function_ref walkTypesFn) const { walkTypesFn(getBody()); } + Type replaceImmediateSubElements(llvm::ArrayRef replAttrs, + llvm::ArrayRef replTypes) const { + // TODO: It's not clear how we support replacing sub-elements of mutable + // types. + return nullptr; + } }; } // namespace test