diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -12,7 +12,6 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/Dialect/LLVMIR/LLVMEnums.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" -include "mlir/IR/SubElementInterfaces.td" // All of the attributes will extend this class. class LLVM_Attr { +def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit", + /*traits=*/[], "DIScopeAttr"> { let parameters = (ins LLVM_DILanguageParameter:$sourceLanguage, "DIFileAttr":$file, @@ -177,9 +175,8 @@ // DICompositeTypeAttr //===----------------------------------------------------------------------===// -def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type", [ - SubElementAttrInterface - ], "DITypeAttr"> { +def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type", + /*traits=*/[], "DITypeAttr"> { let parameters = (ins LLVM_DITagParameter:$tag, "StringAttr":$name, @@ -199,9 +196,8 @@ // DIDerivedTypeAttr //===----------------------------------------------------------------------===// -def LLVM_DIDerivedTypeAttr : LLVM_Attr<"DIDerivedType", "di_derived_type", [ - SubElementAttrInterface - ], "DITypeAttr"> { +def LLVM_DIDerivedTypeAttr : LLVM_Attr<"DIDerivedType", "di_derived_type", + /*traits=*/[], "DITypeAttr"> { let parameters = (ins LLVM_DITagParameter:$tag, OptionalParameter<"StringAttr">:$name, @@ -231,9 +227,8 @@ // DILexicalBlockAttr //===----------------------------------------------------------------------===// -def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block", [ - SubElementAttrInterface - ], "DIScopeAttr"> { +def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block", + /*traits=*/[], "DIScopeAttr"> { let parameters = (ins "DIScopeAttr":$scope, OptionalParameter<"DIFileAttr">:$file, @@ -255,9 +250,8 @@ // DILexicalBlockFileAttr //===----------------------------------------------------------------------===// -def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_file", [ - SubElementAttrInterface - ], "DIScopeAttr"> { +def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_file", + /*traits=*/[], "DIScopeAttr"> { let parameters = (ins "DIScopeAttr":$scope, OptionalParameter<"DIFileAttr">:$file, @@ -277,9 +271,8 @@ // DILocalVariableAttr //===----------------------------------------------------------------------===// -def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable", [ - SubElementAttrInterface - ], "DINodeAttr"> { +def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable", + /*traits=*/[], "DINodeAttr"> { let parameters = (ins "DIScopeAttr":$scope, "StringAttr":$name, @@ -307,9 +300,8 @@ // DISubprogramAttr //===----------------------------------------------------------------------===// -def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram", [ - SubElementAttrInterface - ], "DIScopeAttr"> { +def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram", + /*traits=*/[], "DIScopeAttr"> { let parameters = (ins "DICompileUnitAttr":$compileUnit, "DIScopeAttr":$scope, @@ -357,9 +349,8 @@ // DISubroutineTypeAttr //===----------------------------------------------------------------------===// -def LLVM_DISubroutineTypeAttr : LLVM_Attr<"DISubroutineType", "di_subroutine_type", [ - SubElementAttrInterface - ], "DITypeAttr"> { +def LLVM_DISubroutineTypeAttr : LLVM_Attr<"DISubroutineType", "di_subroutine_type", + /*traits=*/[], "DITypeAttr"> { let parameters = (ins LLVM_DICallingConventionParameter:$callingConvention, OptionalArrayRefParameter<"DITypeAttr">:$types @@ -377,9 +368,7 @@ // MemoryEffectsAttr //===----------------------------------------------------------------------===// -def LLVM_MemoryEffectsAttr : LLVM_Attr<"MemoryEffects", "memory_effects", [ - SubElementAttrInterface - ]> { +def LLVM_MemoryEffectsAttr : LLVM_Attr<"MemoryEffects", "memory_effects"> { let parameters = (ins "ModRefInfo":$other, "ModRefInfo":$argMem, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -14,7 +14,6 @@ #ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_ #define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_ -#include "mlir/IR/SubElementInterfaces.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include @@ -104,7 +103,6 @@ class LLVMStructType : public Type::TypeBase { public: /// Inherit base constructors. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td @@ -11,7 +11,6 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/SubElementInterfaces.td" include "mlir/Interfaces/DataLayoutInterfaces.td" /// Base class for all LLVM dialect types. @@ -25,8 +24,7 @@ //===----------------------------------------------------------------------===// def LLVMArrayType : LLVMType<"LLVMArray", "array", [ - DeclareTypeInterfaceMethods, - DeclareTypeInterfaceMethods]> { + DeclareTypeInterfaceMethods]> { let summary = "LLVM array type"; let description = [{ The `!llvm.array` type represents a fixed-size array of element types. @@ -62,8 +60,7 @@ // LLVMFunctionType //===----------------------------------------------------------------------===// -def LLVMFunctionType : LLVMType<"LLVMFunction", "func", [ - DeclareTypeInterfaceMethods]> { +def LLVMFunctionType : LLVMType<"LLVMFunction", "func"> { let summary = "LLVM function type"; let description = [{ The `!llvm.func` is a function type. It consists of a single return type @@ -124,8 +121,7 @@ def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [ DeclareTypeInterfaceMethods, - DeclareTypeInterfaceMethods]> { + "areCompatible", "verifyEntries"]>]> { let summary = "LLVM pointer type"; let description = [{ The `!llvm.ptr` type is an LLVM pointer type. This type typically represents @@ -171,8 +167,7 @@ // LLVMFixedVectorType //===----------------------------------------------------------------------===// -def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec", [ - DeclareTypeInterfaceMethods]> { +def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> { let summary = "LLVM fixed vector type"; let description = [{ LLVM dialect scalable vector type, represents a sequence of elements of @@ -202,8 +197,7 @@ // LLVMScalableVectorType //===----------------------------------------------------------------------===// -def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec", [ - DeclareTypeInterfaceMethods]> { +def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> { let summary = "LLVM scalable vector type"; let description = [{ LLVM dialect scalable vector type, represents a sequence of elements of diff --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/AttrTypeSubElements.h rename from mlir/include/mlir/IR/SubElementInterfaces.h rename to mlir/include/mlir/IR/AttrTypeSubElements.h --- a/mlir/include/mlir/IR/SubElementInterfaces.h +++ b/mlir/include/mlir/IR/AttrTypeSubElements.h @@ -1,4 +1,4 @@ -//===- SubElementInterfaces.h - Attr and Type SubElements -------*- C++ -*-===// +//===- AttrTypeSubElements.h - Attr and Type SubElements -------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,20 +6,112 @@ // //===----------------------------------------------------------------------===// // -// This file contains interfaces and utilities for querying the sub elements of -// an attribute or type. +// This file contains utilities for querying the sub elements of an attribute or +// type. // //===----------------------------------------------------------------------===// -#ifndef MLIR_IR_SUBELEMENTINTERFACES_H -#define MLIR_IR_SUBELEMENTINTERFACES_H +#ifndef MLIR_IR_ATTRTYPESUBELEMENTS_H +#define MLIR_IR_ATTRTYPESUBELEMENTS_H -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Types.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Visitors.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include namespace mlir { +class Attribute; +class Type; + +//===----------------------------------------------------------------------===// +/// AttrTypeWalker +//===----------------------------------------------------------------------===// + +/// This class provides a utility for walking attributes/types, and their sub +/// elements. Multiple walk functions may be registered. +class AttrTypeWalker { +public: + //===--------------------------------------------------------------------===// + // Application + //===--------------------------------------------------------------------===// + + /// Walk the given attribute/type, and recursively walk any sub elements. + template + WalkResult walk(T element) { + return walkImpl(element, Order); + } + template + WalkResult walk(T element) { + return walk(element); + } + + //===--------------------------------------------------------------------===// + // Registration + //===--------------------------------------------------------------------===// + + template + using WalkFn = std::function; + + /// Register a walk function for a given attribute or type. A walk function + /// must be convertible to any of the following forms(where `T` is a class + /// derived from `Type` or `Attribute`: + /// + /// * WalkResult(T) + /// - Returns a walk result, which can be used to control the walk + /// + /// * void(T) + /// - Returns void, i.e. the walk always continues. + /// + /// Note: When walking, the mostly recently added walk functions will be + /// invoked first. + void addWalk(WalkFn &&fn) { + attrWalkFns.emplace_back(std::move(fn)); + } + void addWalk(WalkFn &&fn) { typeWalkFns.push_back(std::move(fn)); } + + /// Register a replacement function that doesn't match the default signature, + /// either because it uses a derived parameter type, or it uses a simplified + /// result type. + template >::template arg_t<0>, + typename BaseT = std::conditional_t, + Attribute, Type>, + typename ResultT = std::invoke_result_t> + std::enable_if_t || std::is_same_v> + addWalk(FnT &&callback) { + addWalk([callback = std::forward(callback)](BaseT base) -> WalkResult { + if (auto derived = dyn_cast(base)) { + if constexpr (std::is_convertible_v) + return callback(derived); + else + callback(derived); + } + return WalkResult::advance(); + }); + } + +private: + WalkResult walkImpl(Attribute attr, WalkOrder order); + WalkResult walkImpl(Type type, WalkOrder order); + + /// Internal implementation of the `walk` methods above. + template + WalkResult walkImpl(T element, WalkFns &walkFns, WalkOrder order); + + /// Walk the sub elements of the given interface. + template + WalkResult walkSubElements(T interface, WalkOrder order); + + /// The set of walk functions that map sub elements. + std::vector> attrWalkFns; + std::vector> typeWalkFns; + + /// The set of visited attributes/types. + DenseMap, WalkResult> visitedAttrTypes; +}; + //===----------------------------------------------------------------------===// /// AttrTypeReplacer //===----------------------------------------------------------------------===// @@ -84,12 +176,8 @@ /// /// Note: When replacing, the mostly recently added replacement functions will /// be invoked first. - void addReplacement(ReplaceFn fn) { - attrReplacementFns.emplace_back(std::move(fn)); - } - void addReplacement(ReplaceFn fn) { - typeReplacementFns.push_back(std::move(fn)); - } + void addReplacement(ReplaceFn fn); + void addReplacement(ReplaceFn fn); /// Register a replacement function that doesn't match the default signature, /// either because it uses a derived parameter type, or it uses a simplified @@ -120,20 +208,19 @@ private: /// Internal implementation of the `replace` methods above. - template - T replaceImpl(T element, ReplaceFns &replaceFns, DenseMap &map); + template + T replaceImpl(T element, ReplaceFns &replaceFns); /// Replace the sub elements of the given interface. - template - T replaceSubElements(InterfaceT interface, DenseMap &interfaceMap); + template + T replaceSubElements(T interface); /// The set of replacement functions that map sub elements. std::vector> attrReplacementFns; std::vector> typeReplacementFns; /// The set of cached mappings for attributes/types. - DenseMap attrMap; - DenseMap typeMap; + DenseMap attrTypeMap; }; //===----------------------------------------------------------------------===// @@ -142,22 +229,16 @@ /// This class is used by AttrTypeSubElementHandler instances to walking sub /// attributes and types. -class AttrTypeSubElementWalker { +class AttrTypeImmediateSubElementWalker { public: - AttrTypeSubElementWalker(function_ref walkAttrsFn, - function_ref walkTypesFn) + AttrTypeImmediateSubElementWalker(function_ref walkAttrsFn, + function_ref walkTypesFn) : walkAttrsFn(walkAttrsFn), walkTypesFn(walkTypesFn) {} /// Walk an attribute. - void walk(Attribute element) { - if (element) - walkAttrsFn(element); - } + void walk(Attribute element); /// Walk a type. - void walk(Type element) { - if (element) - walkTypesFn(element); - } + void walk(Type element); /// Walk a range of attributes or types. template void walkRange(RangeT &&elements) { @@ -212,7 +293,8 @@ template struct AttrTypeSubElementHandler { /// Default walk implementation that does nothing. - static inline void walk(const T ¶m, AttrTypeSubElementWalker &walker) {} + static inline void walk(const T ¶m, + AttrTypeImmediateSubElementWalker &walker) {} /// Default replace implementation just forwards the parameter. template @@ -241,7 +323,7 @@ struct AttrTypeSubElementHandler< T, std::enable_if_t || std::is_base_of_v>> { - static void walk(T param, AttrTypeSubElementWalker &walker) { + static void walk(T param, AttrTypeImmediateSubElementWalker &walker) { walker.walk(param); } static T replace(T param, AttrSubElementReplacements &attrRepls, @@ -255,27 +337,14 @@ } } }; -template <> -struct AttrTypeSubElementHandler { - template - static void walk(T param, AttrTypeSubElementWalker &walker) { - walker.walk(param.getName()); - walker.walk(param.getValue()); - } - template - static T replace(T param, AttrSubElementReplacements &attrRepls, - TypeSubElementReplacements &typeRepls) { - ArrayRef paramRepls = attrRepls.take_front(2); - return T(cast(paramRepls[0]), paramRepls[1]); - } -}; /// Implementation for derived ArrayRef. template struct AttrTypeSubElementHandler, std::enable_if_t>> { using EltHandler = AttrTypeSubElementHandler; - static void walk(ArrayRef param, AttrTypeSubElementWalker &walker) { + static void walk(ArrayRef param, + AttrTypeImmediateSubElementWalker &walker) { for (const T &subElement : param) EltHandler::walk(subElement, walker); } @@ -283,11 +352,11 @@ TypeSubElementReplacements &typeRepls) { // Normal attributes/types can extract using the replacer directly. if constexpr (std::is_base_of_v && - sizeof(T) == sizeof(Attribute)) { + sizeof(T) == sizeof(void *)) { ArrayRef attrs = attrRepls.take_front(param.size()); return ArrayRef((const T *)attrs.data(), attrs.size()); } else if constexpr (std::is_base_of_v && - sizeof(T) == sizeof(Type)) { + sizeof(T) == sizeof(void *)) { ArrayRef types = typeRepls.take_front(param.size()); return ArrayRef((const T *)types.data(), types.size()); } else { @@ -305,7 +374,7 @@ struct AttrTypeSubElementHandler< std::tuple, std::enable_if_t>> { static void walk(const std::tuple ¶m, - AttrTypeSubElementWalker &walker) { + AttrTypeImmediateSubElementWalker &walker) { std::apply( [&](const Ts &...params) { (AttrTypeSubElementHandler::walk(params, walker), ...); @@ -333,6 +402,8 @@ struct is_tuple> : public std::true_type {}; template using has_get_method = decltype(T::get(std::declval()...)); +template +using has_get_as_key = decltype(std::declval().getAsKey()); /// This function provides the underlying implementation for the /// SubElementInterface walk method, using the key type of the derived @@ -341,21 +412,24 @@ void walkImmediateSubElementsImpl(T derived, function_ref walkAttrsFn, function_ref walkTypesFn) { - auto key = static_cast(derived.getImpl())->getAsKey(); + using ImplT = typename T::ImplType; + if constexpr (llvm::is_detected::value) { + auto key = static_cast(derived.getImpl())->getAsKey(); - // If we don't have any sub-elements, there is nothing to do. - if constexpr (!has_sub_attr_or_type_v) { - return; - } else { - AttrTypeSubElementWalker walker(walkAttrsFn, walkTypesFn); - AttrTypeSubElementHandler::walk(key, walker); + // If we don't have any sub-elements, there is nothing to do. + if constexpr (!has_sub_attr_or_type_v) { + return; + } else { + AttrTypeImmediateSubElementWalker walker(walkAttrsFn, walkTypesFn); + AttrTypeSubElementHandler::walk(key, walker); + } } } /// This function invokes the proper `get` method for a type `T` with the given /// values. template -T constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) { +auto constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) { // Prefer a direct `get` method if one exists. if constexpr (llvm::is_detected::value) { (void)ctx; @@ -373,38 +447,39 @@ /// SubElementInterface replace method, using the key type of the derived /// attribute/type to interact with the individual parameters. template -T replaceImmediateSubElementsImpl(T derived, ArrayRef &replAttrs, - ArrayRef &replTypes) { - auto key = static_cast(derived.getImpl())->getAsKey(); +auto replaceImmediateSubElementsImpl(T derived, ArrayRef &replAttrs, + ArrayRef &replTypes) { + using ImplT = typename T::ImplType; + if constexpr (llvm::is_detected::value) { + auto key = static_cast(derived.getImpl())->getAsKey(); - // If we don't have any sub-elements, we can just return the original. - if constexpr (!has_sub_attr_or_type_v) { - return derived; + // If we don't have any sub-elements, we can just return the original. + if constexpr (!has_sub_attr_or_type_v) { + return derived; - // Otherwise, we need to replace any necessary sub-elements. - } else { - AttrSubElementReplacements attrRepls(replAttrs); - TypeSubElementReplacements typeRepls(replTypes); - auto newKey = AttrTypeSubElementHandler::replace( - key, attrRepls, typeRepls); - if constexpr (is_tuple::value) { - return std::apply( - [&](auto &&...params) { - return constructSubElementReplacement( - derived.getContext(), - std::forward(params)...); - }, - newKey); + // Otherwise, we need to replace any necessary sub-elements. } else { - return constructSubElementReplacement(derived.getContext(), newKey); + AttrSubElementReplacements attrRepls(replAttrs); + TypeSubElementReplacements typeRepls(replTypes); + auto newKey = AttrTypeSubElementHandler::replace( + key, attrRepls, typeRepls); + if constexpr (is_tuple::value) { + return std::apply( + [&](auto &&...params) { + return constructSubElementReplacement( + derived.getContext(), + std::forward(params)...); + }, + newKey); + } else { + return constructSubElementReplacement(derived.getContext(), newKey); + } } + } else { + return derived; } } } // namespace detail } // namespace mlir -/// Include the definitions of the sub element interfaces. -#include "mlir/IR/SubElementAttrInterfaces.h.inc" -#include "mlir/IR/SubElementTypeInterfaces.h.inc" - -#endif // MLIR_IR_SUBELEMENTINTERFACES_H +#endif // MLIR_IR_ATTRTYPESUBELEMENTS_H diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -20,9 +20,6 @@ #include "llvm/ADT/Twine.h" namespace mlir { -class MLIRContext; -class Type; - //===----------------------------------------------------------------------===// // AbstractAttribute //===----------------------------------------------------------------------===// @@ -32,6 +29,10 @@ class AbstractAttribute { public: using HasTraitFn = llvm::unique_function; + using WalkImmediateSubElementsFn = function_ref, function_ref)>; + using ReplaceImmediateSubElementsFn = + function_ref, ArrayRef)>; /// Look up the specified abstract attribute in the MLIRContext and return a /// reference to it. @@ -42,6 +43,8 @@ template static AbstractAttribute get(Dialect &dialect) { return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(), + T::getWalkImmediateSubElementsFn(), + T::getReplaceImmediateSubElementsFn(), T::getTypeID()); } @@ -49,11 +52,15 @@ /// custom TypeIDs. /// The use of this method is in general discouraged in favor of /// 'get(dialect)'. - static AbstractAttribute get(Dialect &dialect, - detail::InterfaceMap &&interfaceMap, - HasTraitFn &&hasTrait, TypeID typeID) { + static AbstractAttribute + get(Dialect &dialect, detail::InterfaceMap &&interfaceMap, + HasTraitFn &&hasTrait, + WalkImmediateSubElementsFn walkImmediateSubElementsFn, + ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, + TypeID typeID) { return AbstractAttribute(dialect, std::move(interfaceMap), - std::move(hasTrait), typeID); + std::move(hasTrait), walkImmediateSubElementsFn, + replaceImmediateSubElementsFn, typeID); } /// Return the dialect this attribute was registered to. @@ -82,14 +89,30 @@ /// Returns true if the attribute has a particular trait. bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); } + /// Walk the immediate sub-elements of this attribute. + void walkImmediateSubElements(Attribute attr, + function_ref walkAttrsFn, + function_ref walkTypesFn) const; + + /// Replace the immediate sub-elements of this attribute. + Attribute replaceImmediateSubElements(Attribute attr, + ArrayRef replAttrs, + ArrayRef replTypes) const; + /// Return the unique identifier representing the concrete attribute class. TypeID getTypeID() const { return typeID; } private: AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap, - HasTraitFn &&hasTrait, TypeID typeID) + HasTraitFn &&hasTraitFn, + WalkImmediateSubElementsFn walkImmediateSubElementsFn, + ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, + TypeID typeID) : dialect(dialect), interfaceMap(std::move(interfaceMap)), - hasTraitFn(std::move(hasTrait)), typeID(typeID) {} + hasTraitFn(std::move(hasTraitFn)), + walkImmediateSubElementsFn(walkImmediateSubElementsFn), + replaceImmediateSubElementsFn(replaceImmediateSubElementsFn), + typeID(typeID) {} /// Give StorageUserBase access to the mutable lookup. template getAbstractAttribute(); } + /// Walk all of the immediately nested sub-attributes and sub-types. This + /// method does not recurse into sub elements. + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const { + getAbstractAttribute().walkImmediateSubElements(*this, walkAttrsFn, + walkTypesFn); + } + + /// Replace the immediately nested sub-attributes and sub-types with those + /// provided. The order of the provided elements is derived from the order of + /// the elements returned by the callbacks of `walkImmediateSubElements`. The + /// element at index 0 would replace the very first attribute given by + /// `walkImmediateSubElements`. On success, the new instance with the values + /// replaced is returned. If replacement fails, nullptr is returned. + auto replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const { + return getAbstractAttribute().replaceImmediateSubElements(*this, replAttrs, + replTypes); + } + + /// Walk this attribute and all attibutes/types nested within using the + /// provided walk functions. See `AttrTypeWalker` for information on the + /// supported walk function types. + template + auto walk(WalkFns &&...walkFns) { + AttrTypeWalker walker; + (walker.addWalk(std::forward(walkFns)), ...); + return walker.walk(*this); + } + + /// Recursively replace all of the nested sub-attributes and sub-types using + /// the provided map functions. Returns nullptr in the case of failure. See + /// `AttrTypeReplacer` for information on the support replacement function + /// types. + template + auto replace(ReplacementFns &&...replacementFns) { + AttrTypeReplacer replacer; + (replacer.addReplacement(std::forward(replacementFns)), + ...); + return replacer.replace(*this); + } + /// Return the internal Attribute implementation. ImplType *getImpl() const { return impl; } @@ -201,6 +243,22 @@ return DenseMapInfo::getHashValue(AttrPairT(arg.name, arg.value)); } +/// Allow walking and replacing the subelements of a NamedAttribute. +template <> +struct AttrTypeSubElementHandler { + template + static void walk(T param, AttrTypeImmediateSubElementWalker &walker) { + walker.walk(param.getName()); + walker.walk(param.getValue()); + } + template + static T replace(T param, AttrSubElementReplacements &attrRepls, + TypeSubElementReplacements &typeRepls) { + ArrayRef paramRepls = attrRepls.take_front(2); + return T(cast(paramRepls[0]), paramRepls[1]); + } +}; + //===----------------------------------------------------------------------===// // AttributeTraitBase //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -10,7 +10,6 @@ #define MLIR_IR_BUILTINATTRIBUTES_H #include "mlir/IR/BuiltinAttributeInterfaces.h" -#include "mlir/IR/SubElementInterfaces.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Sequence.h" #include diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -18,7 +18,6 @@ include "mlir/IR/BuiltinDialect.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/OpAsmInterface.td" -include "mlir/IR/SubElementInterfaces.td" // TODO: Currently the attributes defined in this file are prefixed with // `Builtin_`. This is to differentiate the attributes here with the ones in @@ -71,9 +70,7 @@ // ArrayAttr //===----------------------------------------------------------------------===// -def Builtin_ArrayAttr : Builtin_Attr<"Array", [ - SubElementAttrInterface - ]> { +def Builtin_ArrayAttr : Builtin_Attr<"Array"> { let summary = "A collection of other Attribute values"; let description = [{ Syntax: @@ -491,9 +488,7 @@ // DictionaryAttr //===----------------------------------------------------------------------===// -def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [ - SubElementAttrInterface - ]> { +def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> { let summary = "An dictionary of named Attribute values"; let description = [{ Syntax: @@ -1096,9 +1091,7 @@ // SymbolRefAttr //===----------------------------------------------------------------------===// -def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [ - SubElementAttrInterface - ]> { +def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> { let summary = "An Attribute containing a symbolic reference to an Operation"; let description = [{ Syntax: @@ -1114,13 +1107,6 @@ may optionally contain a set of nested references that further resolve to a symbol nested within a different symbol table. - This attribute can only be held internally by - [array attributes](#array-attribute), - [dictionary attributes](#dictionary-attribute)(including the top-level - operation attribute dictionary) as well as attributes exposing it via - the `SubElementAttrInterface` interface. Symbol reference attributes - nested in types are currently not supported. - **Rationale:** Identifying accesses to global data is critical to enabling efficient multi-threaded compilation. Restricting global data access to occur through symbols and limiting the places that can @@ -1171,9 +1157,7 @@ // TypeAttr //===----------------------------------------------------------------------===// -def Builtin_TypeAttr : Builtin_Attr<"Type", [ - SubElementAttrInterface - ]> { +def Builtin_TypeAttr : Builtin_Attr<"Type"> { let summary = "An Attribute containing a Type"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td --- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td +++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td @@ -15,7 +15,6 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinDialect.td" -include "mlir/IR/SubElementInterfaces.td" // Base class for Builtin dialect location attributes. class Builtin_LocationAttr traits = []> @@ -28,9 +27,7 @@ // CallSiteLoc //===----------------------------------------------------------------------===// -def CallSiteLoc : Builtin_LocationAttr<"CallSiteLoc", [ - SubElementAttrInterface - ]> { +def CallSiteLoc : Builtin_LocationAttr<"CallSiteLoc"> { let summary = "A callsite source location"; let description = [{ Syntax: @@ -107,9 +104,7 @@ // FusedLoc //===----------------------------------------------------------------------===// -def FusedLoc : Builtin_LocationAttr<"FusedLoc", [ - SubElementAttrInterface - ]> { +def FusedLoc : Builtin_LocationAttr<"FusedLoc"> { let summary = "A tuple of other source locations"; let description = [{ Syntax: @@ -148,9 +143,7 @@ // NameLoc //===----------------------------------------------------------------------===// -def NameLoc : Builtin_LocationAttr<"NameLoc", [ - SubElementAttrInterface - ]> { +def NameLoc : Builtin_LocationAttr<"NameLoc"> { let summary = "A named source location"; let description = [{ Syntax: @@ -187,9 +180,7 @@ // OpaqueLoc //===----------------------------------------------------------------------===// -def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc", [ - SubElementAttrInterface - ]> { +def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc"> { let summary = "An opaque source location"; let description = [{ An instance of this location essentially contains a pointer to some data diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -11,7 +11,6 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/SubElementInterfaces.h" namespace llvm { class BitVector; diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -17,7 +17,6 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinDialect.td" include "mlir/IR/BuiltinTypeInterfaces.td" -include "mlir/IR/SubElementInterfaces.td" // TODO: Currently the types defined in this file are prefixed with `Builtin_`. // This is to differentiate the types here with the ones in OpBase.td. We should @@ -165,9 +164,7 @@ // FunctionType //===----------------------------------------------------------------------===// -def Builtin_Function : Builtin_Type<"Function", [ - DeclareTypeInterfaceMethods - ]> { +def Builtin_Function : Builtin_Type<"Function"> { let summary = "Map from a list of inputs to a list of results"; let description = [{ Syntax: @@ -314,7 +311,7 @@ //===----------------------------------------------------------------------===// def Builtin_MemRef : Builtin_Type<"MemRef", [ - DeclareTypeInterfaceMethods, ShapedTypeInterface + ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference to a region of memory"; let description = [{ @@ -649,7 +646,7 @@ //===----------------------------------------------------------------------===// def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [ - DeclareTypeInterfaceMethods, ShapedTypeInterface + ShapedTypeInterface ], "TensorType"> { let summary = "Multi-dimensional array with a fixed number of dimensions"; let description = [{ @@ -753,9 +750,7 @@ // TupleType //===----------------------------------------------------------------------===// -def Builtin_Tuple : Builtin_Type<"Tuple", [ - DeclareTypeInterfaceMethods - ]> { +def Builtin_Tuple : Builtin_Type<"Tuple"> { let summary = "Fixed-sized collection of other types"; let description = [{ Syntax: @@ -823,7 +818,7 @@ //===----------------------------------------------------------------------===// def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [ - DeclareTypeInterfaceMethods, ShapedTypeInterface + ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference, with unknown rank, to a region of memory"; let description = [{ @@ -895,7 +890,7 @@ //===----------------------------------------------------------------------===// def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [ - DeclareTypeInterfaceMethods, ShapedTypeInterface + ShapedTypeInterface ], "TensorType"> { let summary = "Multi-dimensional array with unknown dimensions"; let description = [{ @@ -943,9 +938,7 @@ // VectorType //===----------------------------------------------------------------------===// -def Builtin_Vector : Builtin_Type<"Vector", [ - DeclareTypeInterfaceMethods, ShapedTypeInterface - ], "Type"> { +def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> { let summary = "Multi-dimensional SIMD vector type"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -43,13 +43,6 @@ add_public_tablegen_target(MLIRFunctionInterfacesIncGen) add_dependencies(mlir-generic-headers MLIRFunctionInterfacesIncGen) -set(LLVM_TARGET_DEFINITIONS SubElementInterfaces.td) -mlir_tablegen(SubElementAttrInterfaces.h.inc -gen-attr-interface-decls) -mlir_tablegen(SubElementAttrInterfaces.cpp.inc -gen-attr-interface-defs) -mlir_tablegen(SubElementTypeInterfaces.h.inc -gen-type-interface-decls) -mlir_tablegen(SubElementTypeInterfaces.cpp.inc -gen-type-interface-defs) -add_public_tablegen_target(MLIRSubElementInterfacesIncGen) - set(LLVM_TARGET_DEFINITIONS TensorEncoding.td) mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls) mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs) diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -15,7 +15,6 @@ #define MLIR_IR_LOCATION_H #include "mlir/IR/Attributes.h" -#include "mlir/IR/SubElementInterfaces.h" #include "llvm/Support/PointerLikeTypeTraits.h" namespace mlir { @@ -172,13 +171,13 @@ } //===----------------------------------------------------------------------===// -// SubElementInterfaces +// SubElements //===----------------------------------------------------------------------===// /// Enable locations to be introspected as sub-elements. template <> struct AttrTypeSubElementHandler { - static void walk(Location param, AttrTypeSubElementWalker &walker) { + static void walk(Location param, AttrTypeImmediateSubElementWalker &walker) { walker.walk(param); } static Location replace(Location param, AttrSubElementReplacements &attrRepls, diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -13,6 +13,7 @@ #ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H #define MLIR_IR_STORAGEUNIQUERSUPPORT_H +#include "mlir/IR/AttrTypeSubElements.h" #include "mlir/Support/InterfaceSupport.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StorageUniquer.h" @@ -126,6 +127,51 @@ }; } + /// Walk all of the immediately nested sub-attributes and sub-types. This + /// method does not recurse into sub elements. + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const { + ::mlir::detail::walkImmediateSubElementsImpl( + *static_cast(this), walkAttrsFn, walkTypesFn); + } + + /// Replace the immediately nested sub-attributes and sub-types with those + /// provided. The order of the provided elements is derived from the order of + /// the elements returned by the callbacks of `walkImmediateSubElements`. The + /// element at index 0 would replace the very first attribute given by + /// `walkImmediateSubElements`. On success, the new instance with the values + /// replaced is returned. If replacement fails, nullptr is returned. + /// + /// Note that replacing the sub-elements of mutable types or attributes is + /// not currently supported by the interface. If an implementing type or + /// attribute is mutable, it should return `nullptr` if it has no mechanism + /// for replacing sub elements. + auto replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const { + return ::mlir::detail::replaceImmediateSubElementsImpl( + *static_cast(this), replAttrs, replTypes); + } + + /// Returns a function that walks immediate sub elements of a given instance + /// of the storage user. + static auto getWalkImmediateSubElementsFn() { + return [](auto instance, function_ref walkAttrsFn, + function_ref walkTypesFn) { + cast(instance).walkImmediateSubElements(walkAttrsFn, + walkTypesFn); + }; + } + + /// Returns a function that replaces immediate sub elements of a given + /// instance of the storage user. + static auto getReplaceImmediateSubElementsFn() { + return [](auto instance, ArrayRef replAttrs, + ArrayRef replTypes) { + return cast(instance).replaceImmediateSubElements(replAttrs, + replTypes); + }; + } + /// Attach the given models as implementations of the corresponding interfaces /// for the concrete storage user class. The type must be registered with the /// context, i.e. the dialect to which the type belongs must be loaded. The diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td deleted file mode 100644 --- a/mlir/include/mlir/IR/SubElementInterfaces.td +++ /dev/null @@ -1,142 +0,0 @@ -//===-- SubElementInterfaces.td - Sub-Element Interfaces ---*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains a set of interfaces that can be used to interface with -// sub-elements, e.g. held attributes and types, of a composite attribute or -// type. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_SUBELEMENTINTERFACES_TD_ -#define MLIR_IR_SUBELEMENTINTERFACES_TD_ - -include "mlir/IR/OpBase.td" - -//===----------------------------------------------------------------------===// -// SubElementInterfaceBase -//===----------------------------------------------------------------------===// - -class SubElementInterfaceBase { - string cppNamespace = "::mlir"; - - list methods = [ - InterfaceMethod< - /*desc=*/[{ - Walk all of the immediately nested sub-attributes and sub-types. This - method does not recurse into sub elements. - }], "void", "walkImmediateSubElements", - (ins "llvm::function_ref":$walkAttrsFn, - "llvm::function_ref":$walkTypesFn), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ - ::mlir::detail::walkImmediateSubElementsImpl( - }] # derivedValue # [{, walkAttrsFn, walkTypesFn); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Replace the immediately nested sub-attributes and sub-types with those provided. - The order of the provided elements is derived from the order of the elements - returned by the callbacks of `walkImmediateSubElements`. The element at index 0 - would replace the very first attribute given by `walkImmediateSubElements`. - On success, the new instance with the values replaced is returned. If replacement - fails, nullptr is returned. - - Note that replacing the sub-elements of mutable types or attributes is - not currently supported by the interface. If an implementing type or - attribute is mutable, it should return `nullptr` if it has no mechanism - for replacing sub elements. - }], attrOrType, "replaceImmediateSubElements", - (ins "::llvm::ArrayRef<::mlir::Attribute>":$replAttrs, - "::llvm::ArrayRef<::mlir::Type>":$replTypes), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ - return ::mlir::detail::replaceImmediateSubElementsImpl( - }] # derivedValue # [{, replAttrs, replTypes); - }]>, - ]; - - code extraClassDeclaration = [{ - /// Walk all of the held sub-attributes and sub-types. - void walkSubElements(llvm::function_ref walkAttrsFn, - llvm::function_ref walkTypesFn); - - /// Recursively replace all of the nested sub-attributes and sub-types using the - /// provided map functions. Returns nullptr in the case of failure. See - /// `AttrTypeReplacer` for information on the support replacement function types. - template - }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... replacementFns) { - AttrTypeReplacer replacer; - (replacer.addReplacement(std::forward(replacementFns)), ...); - return replacer.replace(*this); - } - }]; - code extraTraitClassDeclaration = [{ - /// Walk all of the held sub-attributes and sub-types. - void walkSubElements(llvm::function_ref walkAttrsFn, - llvm::function_ref walkTypesFn) { - }] # interfaceName # " interface(" # derivedValue # [{); - interface.walkSubElements(walkAttrsFn, walkTypesFn); - } - - /// Recursively replace all of the nested sub-attributes and sub-types using the - /// provided map functions. Returns nullptr in the case of failure. See - /// `AttrTypeReplacer` for information on the support replacement function types. - template - }] # attrOrType # [{ replaceSubElements(ReplacementFns &&... replacementFns) { - AttrTypeReplacer replacer; - (replacer.addReplacement(std::forward(replacementFns)), ...); - return replacer.replace(}] # derivedValue # [{); - } - }]; - code extraSharedClassDeclaration = [{ - /// Walk all of the held sub-attributes. - void walkSubAttrs(llvm::function_ref walkFn) { - walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {}); - } - /// Walk all of the held sub-types. - void walkSubTypes(llvm::function_ref walkFn) { - walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn); - } - }]; -} - -//===----------------------------------------------------------------------===// -// SubElementAttrInterface -//===----------------------------------------------------------------------===// - -def SubElementAttrInterface - : AttrInterface<"SubElementAttrInterface">, - SubElementInterfaceBase<"SubElementAttrInterface", "::mlir::Attribute", - "$_attr"> { - let description = [{ - An interface used to query and manipulate sub-elements, such as sub-types - and sub-attributes of a composite attribute. - - To support the introspection of custom parameters that hold sub-elements, - a specialization of the `AttrTypeSubElementHandler` class must be provided. - }]; -} - -//===----------------------------------------------------------------------===// -// SubElementTypeInterface -//===----------------------------------------------------------------------===// - -def SubElementTypeInterface - : TypeInterface<"SubElementTypeInterface">, - SubElementInterfaceBase<"SubElementTypeInterface", "::mlir::Type", - "$_type"> { - let description = [{ - An interface used to query and manipulate sub-elements, such as sub-types - and sub-attributes of a composite type. - - To support the introspection of custom parameters that hold sub-elements, - a specialization of the `AttrTypeSubElementHandler` class must be provided. - }]; -} - -#endif // MLIR_IR_SUBELEMENTINTERFACES_TD_ diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h --- a/mlir/include/mlir/IR/TypeRange.h +++ b/mlir/include/mlir/IR/TypeRange.h @@ -166,13 +166,13 @@ } //===----------------------------------------------------------------------===// -// SubElementInterfaces +// SubElements //===----------------------------------------------------------------------===// /// Enable TypeRange to be introspected for sub-elements. template <> struct AttrTypeSubElementHandler { - static void walk(TypeRange param, AttrTypeSubElementWalker &walker) { + static void walk(TypeRange param, AttrTypeImmediateSubElementWalker &walker) { walker.walkRange(param); } static TypeRange replace(TypeRange param, diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -30,6 +30,10 @@ class AbstractType { public: using HasTraitFn = llvm::unique_function; + using WalkImmediateSubElementsFn = function_ref, function_ref)>; + using ReplaceImmediateSubElementsFn = + function_ref, ArrayRef)>; /// Look up the specified abstract type in the MLIRContext and return a /// reference to it. @@ -40,17 +44,23 @@ template static AbstractType get(Dialect &dialect) { return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(), - T::getTypeID()); + T::getWalkImmediateSubElementsFn(), + T::getReplaceImmediateSubElementsFn(), T::getTypeID()); } /// This method is used by Dialect objects to register types with /// custom TypeIDs. /// The use of this method is in general discouraged in favor of /// 'get(dialect)'; - static AbstractType get(Dialect &dialect, detail::InterfaceMap &&interfaceMap, - HasTraitFn &&hasTrait, TypeID typeID) { + static AbstractType + get(Dialect &dialect, detail::InterfaceMap &&interfaceMap, + HasTraitFn &&hasTrait, + WalkImmediateSubElementsFn walkImmediateSubElementsFn, + ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, + TypeID typeID) { return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait), - typeID); + walkImmediateSubElementsFn, + replaceImmediateSubElementsFn, typeID); } /// Return the dialect this type was registered to. @@ -78,14 +88,29 @@ /// Returns true if the type has a particular trait. bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); } + /// Walk the immediate sub-elements of the given type. + void walkImmediateSubElements(Type type, + function_ref walkAttrsFn, + function_ref walkTypesFn) const; + + /// Replace the immediate sub-elements of the given type. + Type replaceImmediateSubElements(Type type, ArrayRef replAttrs, + ArrayRef replTypes) const; + /// Return the unique identifier representing the concrete type class. TypeID getTypeID() const { return typeID; } private: AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap, - HasTraitFn &&hasTrait, TypeID typeID) + HasTraitFn &&hasTrait, + WalkImmediateSubElementsFn walkImmediateSubElementsFn, + ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, + TypeID typeID) : dialect(dialect), interfaceMap(std::move(interfaceMap)), - hasTraitFn(std::move(hasTrait)), typeID(typeID) {} + hasTraitFn(std::move(hasTrait)), + walkImmediateSubElementsFn(walkImmediateSubElementsFn), + replaceImmediateSubElementsFn(replaceImmediateSubElementsFn), + typeID(typeID) {} /// Give StorageUserBase access to the mutable lookup. template getAbstractType(); } + const AbstractTy &getAbstractType() const { return impl->getAbstractType(); } /// Return the Type implementation. ImplType *getImpl() const { return impl; } + /// Walk all of the immediately nested sub-attributes and sub-types. This + /// method does not recurse into sub elements. + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const { + getAbstractType().walkImmediateSubElements(*this, walkAttrsFn, walkTypesFn); + } + + /// Replace the immediately nested sub-attributes and sub-types with those + /// provided. The order of the provided elements is derived from the order of + /// the elements returned by the callbacks of `walkImmediateSubElements`. The + /// element at index 0 would replace the very first attribute given by + /// `walkImmediateSubElements`. On success, the new instance with the values + /// replaced is returned. If replacement fails, nullptr is returned. + auto replaceImmediateSubElements(ArrayRef replAttrs, + ArrayRef replTypes) const { + return getAbstractType().replaceImmediateSubElements(*this, replAttrs, + replTypes); + } + + /// Walk this type and all attibutes/types nested within using the + /// provided walk functions. See `AttrTypeWalker` for information on the + /// supported walk function types. + template + auto walk(WalkFns &&...walkFns) { + AttrTypeWalker walker; + (walker.addWalk(std::forward(walkFns)), ...); + return walker.walk(*this); + } + + /// Recursively replace all of the nested sub-attributes and sub-types using + /// the provided map functions. Returns nullptr in the case of failure. See + /// `AttrTypeReplacer` for information on the support replacement function + /// types. + template + auto replace(ReplacementFns &&...replacementFns) { + AttrTypeReplacer replacer; + (replacer.addReplacement(std::forward(replacementFns)), + ...); + return replacer.replace(*this); + } + protected: ImplType *impl{nullptr}; }; diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -35,7 +35,7 @@ enum ResultEnum { Interrupt, Advance, Skip } result; public: - WalkResult(ResultEnum result) : result(result) {} + WalkResult(ResultEnum result = Advance) : result(result) {} /// Allow LogicalResult to interrupt the walk on failure. WalkResult(LogicalResult result) diff --git a/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp b/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp --- a/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/LowerMemorySpaceAttributes.cpp @@ -80,20 +80,15 @@ void mlir::gpu::populateMemorySpaceAttributeTypeConversions( TypeConverter &typeConverter, const MemorySpaceMapping &mapping) { - typeConverter.addConversion([mapping](Type type) -> std::optional { - auto subElementType = type.dyn_cast_or_null(); - if (!subElementType) - return type; - Type newType = subElementType.replaceSubElements( - [mapping](Attribute attr) -> std::optional { - auto memorySpaceAttr = attr.dyn_cast_or_null(); - if (!memorySpaceAttr) - return std::nullopt; - auto newValue = wrapNumericMemorySpace( - attr.getContext(), mapping(memorySpaceAttr.getValue())); - return newValue; - }); - return newType; + typeConverter.addConversion([mapping](Type type) { + return type.replace([mapping](Attribute attr) -> std::optional { + auto memorySpaceAttr = attr.dyn_cast_or_null(); + if (!memorySpaceAttr) + return std::nullopt; + auto newValue = wrapNumericMemorySpace( + attr.getContext(), mapping(memorySpaceAttr.getValue())); + return newValue; + }); }); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -25,7 +25,6 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" -#include "mlir/IR/SubElementInterfaces.h" #include "mlir/IR/Verifier.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/DenseMap.h" @@ -841,13 +840,11 @@ } // For most builtin types, we can simply walk the sub elements. - if (auto subElementInterface = dyn_cast(type)) { - auto visitFn = [&](auto element) { - if (element) - (void)printAlias(element); - }; - subElementInterface.walkImmediateSubElements(visitFn, visitFn); - } + auto visitFn = [&](auto element) { + if (element) + (void)printAlias(element); + }; + type.walkImmediateSubElements(visitFn, visitFn); } /// Consider the given type to be printed for an alias. diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/AttrTypeSubElements.cpp rename from mlir/lib/IR/SubElementInterfaces.cpp rename to mlir/lib/IR/AttrTypeSubElements.cpp --- a/mlir/lib/IR/SubElementInterfaces.cpp +++ b/mlir/lib/IR/AttrTypeSubElements.cpp @@ -1,4 +1,4 @@ -//===- SubElementInterfaces.cpp - Attr and Type SubElement Interfaces -----===// +//===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,96 +6,77 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/SubElementInterfaces.h" #include "mlir/IR/Operation.h" - -#include "llvm/ADT/DenseSet.h" #include using namespace mlir; //===----------------------------------------------------------------------===// -// SubElementInterface +// AttrTypeWalker //===----------------------------------------------------------------------===// -//===----------------------------------------------------------------------===// -// WalkSubElements - -template -static void walkSubElementsImpl(InterfaceT interface, - function_ref walkAttrsFn, - function_ref walkTypesFn, - DenseSet &visitedAttrs, - DenseSet &visitedTypes) { - interface.walkImmediateSubElements( - [&](Attribute attr) { - // Guard against potentially null inputs. This removes the need for the - // derived attribute/type to do it. - if (!attr) - return; - - // Avoid infinite recursion when visiting sub attributes later, if this - // is a mutable attribute. - if (LLVM_UNLIKELY(attr.hasTrait())) { - if (!visitedAttrs.insert(attr).second) - return; - } - - // Walk any sub elements first. - if (auto interface = attr.dyn_cast()) - walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, - visitedTypes); +WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) { + return walkImpl(attr, attrWalkFns, order); +} +WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) { + return walkImpl(type, typeWalkFns, order); +} - // Walk this attribute. - walkAttrsFn(attr); - }, - [&](Type type) { - // Guard against potentially null inputs. This removes the need for the - // derived attribute/type to do it. - if (!type) - return; - - // Avoid infinite recursion when visiting sub types later, if this - // is a mutable type. - if (LLVM_UNLIKELY(type.hasTrait())) { - if (!visitedTypes.insert(type).second) - return; - } +template +WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns, + WalkOrder order) { + // Check if we've already walk this element before. + auto key = std::make_pair(element.getAsOpaquePointer(), (int)order); + auto it = visitedAttrTypes.find(key); + if (it != visitedAttrTypes.end()) + return it->second; + visitedAttrTypes.try_emplace(key, WalkResult::advance()); - // Walk any sub elements first. - if (auto interface = type.dyn_cast()) - walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, - visitedTypes); + // If we are walking in post order, walk the sub elements first. + if (order == WalkOrder::PostOrder) { + if (walkSubElements(element, order).wasInterrupted()) + return visitedAttrTypes[key] = WalkResult::interrupt(); + } - // Walk this type. - walkTypesFn(type); - }); -} + // Walk this element, bailing if skipped or interrupted. + for (auto &walkFn : llvm::reverse(walkFns)) { + WalkResult walkResult = walkFn(element); + if (walkResult.wasInterrupted()) + return visitedAttrTypes[key] = WalkResult::interrupt(); + if (walkResult.wasSkipped()) + return WalkResult::advance(); + } -void SubElementAttrInterface::walkSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) { - assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); - DenseSet visitedAttrs; - DenseSet visitedTypes; - walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, - visitedTypes); + // If we are walking in pre-order, walk the sub elements last. + if (order == WalkOrder::PreOrder) { + if (walkSubElements(element, order).wasInterrupted()) + return WalkResult::interrupt(); + } + return WalkResult::advance(); } -void SubElementTypeInterface::walkSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) { - assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); - DenseSet visitedAttrs; - DenseSet visitedTypes; - walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, - visitedTypes); +template +WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) { + WalkResult result = WalkResult::advance(); + auto walkFn = [&](auto element) { + if (element && !result.wasInterrupted()) + result = walkImpl(element, order); + }; + interface.walkImmediateSubElements(walkFn, walkFn); + return result.wasInterrupted() ? result : WalkResult::advance(); } //===----------------------------------------------------------------------===// /// AttrTypeReplacer //===----------------------------------------------------------------------===// +void AttrTypeReplacer::addReplacement(ReplaceFn fn) { + attrReplacementFns.emplace_back(std::move(fn)); +} +void AttrTypeReplacer::addReplacement(ReplaceFn fn) { + typeReplacementFns.push_back(std::move(fn)); +} + void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { // Functor that replaces the given element if the new value is different, @@ -157,7 +138,6 @@ template static void updateSubElementImpl(T element, AttrTypeReplacer &replacer, - DenseMap &elementMap, SmallVectorImpl &newElements, FailureOr &changed) { // Bail early if we failed at any point. @@ -180,19 +160,18 @@ } } -template -T AttrTypeReplacer::replaceSubElements(InterfaceT interface, - DenseMap &interfaceMap) { +template +T AttrTypeReplacer::replaceSubElements(T interface) { // Walk the current sub-elements, replacing them as necessary. SmallVector newAttrs; SmallVector newTypes; FailureOr changed = false; interface.walkImmediateSubElements( [&](Attribute element) { - updateSubElementImpl(element, *this, attrMap, newAttrs, changed); + updateSubElementImpl(element, *this, newAttrs, changed); }, [&](Type element) { - updateSubElementImpl(element, *this, typeMap, newTypes, changed); + updateSubElementImpl(element, *this, newTypes, changed); }); if (failed(changed)) return nullptr; @@ -205,12 +184,12 @@ } /// Shared implementation of replacing a given attribute or type element. -template -T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns, - DenseMap &map) { - auto [it, inserted] = map.try_emplace(element, element); +template +T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) { + const void *opaqueElement = element.getAsOpaquePointer(); + auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement); if (!inserted) - return it->second; + return T::getFromOpaquePointer(it->second); T result = element; WalkResult walkResult = WalkResult::advance(); @@ -222,34 +201,42 @@ } // If an error occurred, return nullptr to indicate failure. - if (walkResult.wasInterrupted() || !result) - return map[element] = nullptr; + if (walkResult.wasInterrupted() || !result) { + attrTypeMap[opaqueElement] = nullptr; + return nullptr; + } // Handle replacing sub-elements if this element is also a container. if (!walkResult.wasSkipped()) { - if (auto interface = dyn_cast(result)) { - // Replace the sub elements of this element, bailing if we fail. - if (!(result = replaceSubElements(interface, map))) - return map[element] = nullptr; + // Replace the sub elements of this element, bailing if we fail. + if (!(result = replaceSubElements(result))) { + attrTypeMap[opaqueElement] = nullptr; + return nullptr; } } - return map[element] = result; + attrTypeMap[opaqueElement] = result.getAsOpaquePointer(); + return result; } Attribute AttrTypeReplacer::replace(Attribute attr) { - return replaceImpl(attr, attrReplacementFns, - attrMap); + return replaceImpl(attr, attrReplacementFns); } Type AttrTypeReplacer::replace(Type type) { - return replaceImpl(type, typeReplacementFns, - typeMap); + return replaceImpl(type, typeReplacementFns); } //===----------------------------------------------------------------------===// -// SubElementInterface Tablegen definitions +// AttrTypeImmediateSubElementWalker //===----------------------------------------------------------------------===// -#include "mlir/IR/SubElementAttrInterfaces.cpp.inc" -#include "mlir/IR/SubElementTypeInterfaces.cpp.inc" +void AttrTypeImmediateSubElementWalker::walk(Attribute element) { + if (element) + walkAttrsFn(element); +} + +void AttrTypeImmediateSubElementWalker::walk(Type element) { + if (element) + walkTypesFn(element); +} diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -12,6 +12,23 @@ using namespace mlir; using namespace mlir::detail; +//===----------------------------------------------------------------------===// +// AbstractAttribute +//===----------------------------------------------------------------------===// + +void AbstractAttribute::walkImmediateSubElements( + Attribute attr, function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkImmediateSubElementsFn(attr, walkAttrsFn, walkTypesFn); +} + +Attribute +AbstractAttribute::replaceImmediateSubElements(Attribute attr, + ArrayRef replAttrs, + ArrayRef replTypes) const { + return replaceImmediateSubElementsFn(attr, replAttrs, replTypes); +} + //===----------------------------------------------------------------------===// // Attribute //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -3,6 +3,7 @@ AffineMap.cpp AsmPrinter.cpp Attributes.cpp + AttrTypeSubElements.cpp Block.cpp Builders.cpp BuiltinAttributeInterfaces.cpp @@ -26,7 +27,6 @@ PatternMatch.cpp Region.cpp RegionKindInterface.cpp - SubElementInterfaces.cpp SymbolTable.cpp TensorEncoding.cpp Types.cpp @@ -54,7 +54,6 @@ MLIROpAsmInterfaceIncGen MLIRRegionKindInterfaceIncGen MLIRSideEffectInterfacesIncGen - MLIRSubElementInterfacesIncGen MLIRSymbolInterfacesIncGen MLIRTensorEncodingIncGen diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp --- a/mlir/lib/IR/ExtensibleDialect.cpp +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -407,9 +407,10 @@ assert(registered && "Trying to create a new dynamic type with an existing name"); - auto abstractType = - AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(), - DynamicType::getHasTraitFn(), typeID); + auto abstractType = AbstractType::get( + *dialect, DynamicAttr::getInterfaceMap(), DynamicType::getHasTraitFn(), + DynamicType::getWalkImmediateSubElementsFn(), + DynamicType::getReplaceImmediateSubElementsFn(), typeID); /// Add the type to the dialect and the type uniquer. addType(typeID, std::move(abstractType)); @@ -436,9 +437,10 @@ assert(registered && "Trying to create a new dynamic attribute with an existing name"); - auto abstractAttr = - AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(), - DynamicAttr::getHasTraitFn(), typeID); + auto abstractAttr = AbstractAttribute::get( + *dialect, DynamicAttr::getInterfaceMap(), DynamicAttr::getHasTraitFn(), + DynamicAttr::getWalkImmediateSubElementsFn(), + DynamicAttr::getReplaceImmediateSubElementsFn(), typeID); /// Add the type to the dialect and the type uniquer. addAttribute(typeID, std::move(abstractAttr)); diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -485,66 +485,14 @@ static WalkResult walkSymbolRefs(Operation *op, function_ref callback) { - // Check to see if the operation has any attributes. - DictionaryAttr attrDict = op->getAttrDictionary(); - if (attrDict.empty()) - return WalkResult::advance(); - - // A worklist of a container attribute and the current index into the held - // attribute list. - struct WorklistItem { - SubElementAttrInterface container; - SmallVector immediateSubElements; - - explicit WorklistItem(SubElementAttrInterface container) { - SmallVector subElements; - container.walkImmediateSubElements( - [&](Attribute attr) { subElements.push_back(attr); }, [](Type) {}); - immediateSubElements = std::move(subElements); - } - }; - - SmallVector attrWorklist(1, WorklistItem(attrDict)); - SmallVector curAccessChain(1, /*Value=*/-1); - - // Process the symbol references within the given nested attribute range. - auto processAttrs = [&](int &index, - WorklistItem &worklistItem) -> WalkResult { - for (Attribute attr : - llvm::drop_begin(worklistItem.immediateSubElements, index)) { - // Invoke the provided callback if we find a symbol use and check for a - // requested interrupt. - if (auto symbolRef = attr.dyn_cast()) { + return op->getAttrDictionary().walk( + [&](SymbolRefAttr symbolRef) { if (callback({op, symbolRef}).wasInterrupted()) return WalkResult::interrupt(); - /// Check for a nested container attribute, these will also need to be - /// walked. - } else if (auto interface = attr.dyn_cast()) { - attrWorklist.emplace_back(interface); - curAccessChain.push_back(-1); - return WalkResult::advance(); - } - // Make sure to keep the index counter in sync. - ++index; - } - - // Pop this container attribute from the worklist. - attrWorklist.pop_back(); - curAccessChain.pop_back(); - return WalkResult::advance(); - }; - - WalkResult result = WalkResult::advance(); - do { - WorklistItem &item = attrWorklist.back(); - int &index = curAccessChain.back(); - ++index; - - // Process the given attribute, which is guaranteed to be a container. - result = processAttrs(index, item); - } while (!attrWorklist.empty() && !result.wasInterrupted()); - return result; + // Don't walk nested references. + return WalkResult::skip(); + }); } /// Walk all of the uses, for any symbol, that are nested within the given diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -31,7 +31,7 @@ : width(width), signedness(signedness) {} /// The hash key used for uniquing. - using KeyTy = std::pair; + using KeyTy = std::tuple; static llvm::hash_code hashKey(const KeyTy &key) { return llvm::hash_value(key); @@ -44,7 +44,7 @@ static IntegerTypeStorage *construct(TypeStorageAllocator &allocator, KeyTy key) { return new (allocator.allocate()) - IntegerTypeStorage(key.first, key.second); + IntegerTypeStorage(std::get<0>(key), std::get<1>(key)); } KeyTy getAsKey() const { return KeyTy(width, signedness); } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -12,6 +12,22 @@ using namespace mlir; using namespace mlir::detail; +//===----------------------------------------------------------------------===// +// AbstractType +//===----------------------------------------------------------------------===// + +void AbstractType::walkImmediateSubElements( + Type type, function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkImmediateSubElementsFn(type, walkAttrsFn, walkTypesFn); +} + +Type AbstractType::replaceImmediateSubElements(Type type, + ArrayRef replAttrs, + ArrayRef replTypes) const { + return replaceImmediateSubElementsFn(type, replAttrs, replTypes); +} + //===----------------------------------------------------------------------===// // Type //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir --- a/mlir/test/IR/test-symbol-rauw.mlir +++ b/mlir/test/IR/test-symbol-rauw.mlir @@ -76,7 +76,7 @@ // ----- -// Check that replacement works in any implementations of SubElementsAttrInterface +// Check that replacement works in any implementations of SubElements. module { // CHECK: func private @replaced_foo func.func private @symbol_foo() attributes {sym.new_name = "replaced_foo" } diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -21,7 +21,6 @@ include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpAsmInterface.td" -include "mlir/IR/SubElementInterfaces.td" // All of the attributes will extend this class. class Test_Attr traits = []> @@ -120,9 +119,7 @@ let hasCustomAssemblyFormat = 1; } -def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [ - SubElementAttrInterface - ]> { +def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess"> { let mnemonic = "sub_elements_access"; let parameters = (ins diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -22,7 +22,6 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Operation.h" -#include "mlir/IR/SubElementInterfaces.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" @@ -132,7 +131,6 @@ class TestRecursiveType : public ::mlir::Type::TypeBase { public: using Base::Base; diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp --- a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp +++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp @@ -8,7 +8,6 @@ #include "LLVMTestBase.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/SubElementInterfaces.h" using namespace mlir; using namespace mlir::LLVM; @@ -31,33 +30,24 @@ Type barBody[] = {LLVMPointerType::get(fooStructTy)}; ASSERT_TRUE(succeeded(barStructTy.setBody(barBody, /*isPacked=*/false))); - auto subElementInterface = fooStructTy.dyn_cast(); - ASSERT_TRUE(bool(subElementInterface)); // Test if walkSubElements goes into infinite loops. SmallVector subElementTypes; - subElementInterface.walkSubElements( - [](Attribute attr) {}, - [&](Type type) { subElementTypes.push_back(type); }); - // We don't record LLVMPointerType (because it's immutable), thus - // !llvm.ptr> will be visited twice. - ASSERT_EQ(subElementTypes.size(), 5U); + fooStructTy.walk([&](Type type) { subElementTypes.push_back(type); }); + ASSERT_EQ(subElementTypes.size(), 4U); - // !llvm.ptr> + // !llvm.ptr> ASSERT_TRUE(subElementTypes[0].isa()); - // !llvm.struct<"foo",...> + // !llvm.struct<"bar",...> auto structType = subElementTypes[1].dyn_cast(); ASSERT_TRUE(bool(structType)); - ASSERT_TRUE(structType.getName().equals("foo")); + ASSERT_TRUE(structType.getName().equals("bar")); - // !llvm.ptr> + // !llvm.ptr> ASSERT_TRUE(subElementTypes[2].isa()); - // !llvm.struct<"bar",...> + // !llvm.struct<"foo",...> structType = subElementTypes[3].dyn_cast(); ASSERT_TRUE(bool(structType)); - ASSERT_TRUE(structType.getName().equals("bar")); - - // !llvm.ptr> - ASSERT_TRUE(subElementTypes[4].isa()); + ASSERT_TRUE(structType.getName().equals("foo")); } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -422,4 +422,25 @@ EXPECT_TRUE(zeroStringValue.getType() == stringTy); } +//===----------------------------------------------------------------------===// +// SubElements +//===----------------------------------------------------------------------===// + +TEST(SubElementTest, Nested) { + MLIRContext context; + Builder builder(&context); + + BoolAttr trueAttr = builder.getBoolAttr(true); + BoolAttr falseAttr = builder.getBoolAttr(false); + ArrayAttr boolArrayAttr = builder.getArrayAttr({trueAttr, falseAttr}); + StringAttr strAttr = builder.getStringAttr("array"); + DictionaryAttr dictAttr = + builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr)); + + SmallVector subAttrs; + dictAttr.walk([&](Attribute attr) { subAttrs.push_back(attr); }); + EXPECT_EQ(llvm::ArrayRef(subAttrs), + ArrayRef( + {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr})); +} } // namespace diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -8,7 +8,6 @@ OperationSupportTest.cpp PatternMatchTest.cpp ShapedTypeTest.cpp - SubElementInterfaceTest.cpp TypeTest.cpp DEPENDS diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp --- a/mlir/unittests/IR/InterfaceTest.cpp +++ b/mlir/unittests/IR/InterfaceTest.cpp @@ -41,24 +41,6 @@ EXPECT_FALSE(opSet.contains(op3)); } -TEST(InterfaceTest, AttrInterfaceDenseMapKey) { - MLIRContext context; - context.loadDialect(); - - OpBuilder builder(&context); - - DenseSet attrSet; - auto attr1 = builder.getArrayAttr({}); - auto attr2 = builder.getI32ArrayAttr({0}); - auto attr3 = builder.getI32ArrayAttr({1}); - attrSet.insert(attr1); - attrSet.insert(attr2); - attrSet.erase(attr1); - EXPECT_FALSE(attrSet.contains(attr1)); - EXPECT_TRUE(attrSet.contains(attr2)); - EXPECT_FALSE(attrSet.contains(attr3)); -} - TEST(InterfaceTest, TypeInterfaceDenseMapKey) { MLIRContext context; context.loadDialect(); diff --git a/mlir/unittests/IR/SubElementInterfaceTest.cpp b/mlir/unittests/IR/SubElementInterfaceTest.cpp deleted file mode 100644 --- a/mlir/unittests/IR/SubElementInterfaceTest.cpp +++ /dev/null @@ -1,36 +0,0 @@ -//===- SubElementInterfaceTest.cpp - SubElementInterface unit tests -------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/SubElementInterfaces.h" -#include "gtest/gtest.h" -#include - -using namespace mlir; -using namespace mlir::detail; - -namespace { -TEST(SubElementInterfaceTest, Nested) { - MLIRContext context; - Builder builder(&context); - - BoolAttr trueAttr = builder.getBoolAttr(true); - BoolAttr falseAttr = builder.getBoolAttr(false); - ArrayAttr boolArrayAttr = builder.getArrayAttr({trueAttr, falseAttr}); - StringAttr strAttr = builder.getStringAttr("array"); - DictionaryAttr dictAttr = - builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr)); - - SmallVector subAttrs; - dictAttr.walkSubAttrs([&](Attribute attr) { subAttrs.push_back(attr); }); - EXPECT_EQ(llvm::ArrayRef(subAttrs), - ArrayRef({strAttr, trueAttr, falseAttr, boolArrayAttr})); -} - -} // namespace