diff --git a/flang/test/Fir/affine-promotion.fir b/flang/test/Fir/affine-promotion.fir --- a/flang/test/Fir/affine-promotion.fir +++ b/flang/test/Fir/affine-promotion.fir @@ -50,21 +50,21 @@ // CHECK: %[[VAL_3:.*]] = arith.constant 1 : index // CHECK: %[[VAL_4:.*]] = arith.constant 100 : index // CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1> -// CHECK: %[[VAL_6:.*]] = affine.apply #map(){{\[}}%[[VAL_3]], %[[VAL_4]]] +// CHECK: %[[VAL_6:.*]] = affine.apply #{{.*}}(){{\[}}%[[VAL_3]], %[[VAL_4]]] // CHECK: %[[VAL_7:.*]] = fir.alloca !fir.array, %[[VAL_6]] // CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_0]] : (!fir.ref>) -> memref // CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_1]] : (!fir.ref>) -> memref // CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_7]] : (!fir.ref>) -> memref -// CHECK: affine.for %[[VAL_11:.*]] = %[[VAL_3]] to #map1(){{\[}}%[[VAL_4]]] { -// CHECK: %[[VAL_12:.*]] = affine.apply #map2(%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]] +// CHECK: affine.for %[[VAL_11:.*]] = %[[VAL_3]] to #{{.*}}(){{\[}}%[[VAL_4]]] { +// CHECK: %[[VAL_12:.*]] = affine.apply #{{.*}}(%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]] // CHECK: %[[VAL_13:.*]] = affine.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref // CHECK: %[[VAL_14:.*]] = affine.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref // CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 // CHECK: affine.store %[[VAL_15]], %[[VAL_10]]{{\[}}%[[VAL_12]]] : memref // CHECK: } // CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_2]] : (!fir.ref>) -> memref -// CHECK: affine.for %[[VAL_17:.*]] = %[[VAL_3]] to #map1(){{\[}}%[[VAL_4]]] { -// CHECK: %[[VAL_18:.*]] = affine.apply #map2(%[[VAL_17]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]] +// CHECK: affine.for %[[VAL_17:.*]] = %[[VAL_3]] to #{{.*}}(){{\[}}%[[VAL_4]]] { +// CHECK: %[[VAL_18:.*]] = affine.apply #{{.*}}(%[[VAL_17]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]] // CHECK: %[[VAL_19:.*]] = affine.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref 
// CHECK: %[[VAL_20:.*]] = affine.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref // CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f32 @@ -114,18 +114,18 @@ // CHECK: %[[VAL_5:.*]] = arith.constant 100 : index // CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1> // CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_0]] : (!fir.ref>) -> memref -// CHECK: affine.for %[[VAL_8:.*]] = %[[VAL_3]] to #map(){{\[}}%[[VAL_5]]] { -// CHECK: %[[VAL_9:.*]] = affine.apply #map1(%[[VAL_8]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]] +// CHECK: affine.for %[[VAL_8:.*]] = %[[VAL_3]] to #{{.*}}(){{\[}}%[[VAL_5]]] { +// CHECK: %[[VAL_9:.*]] = affine.apply #{{.*}}(%[[VAL_8]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]] // CHECK: affine.store %[[VAL_1]], %[[VAL_7]]{{\[}}%[[VAL_9]]] : memref // CHECK: } -// CHECK: affine.for %[[VAL_10:.*]] = %[[VAL_3]] to #map(){{\[}}%[[VAL_5]]] { -// CHECK: %[[VAL_11:.*]] = affine.apply #map1(%[[VAL_10]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]] +// CHECK: affine.for %[[VAL_10:.*]] = %[[VAL_3]] to #{{.*}}(){{\[}}%[[VAL_5]]] { +// CHECK: %[[VAL_11:.*]] = affine.apply #{{.*}}(%[[VAL_10]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]] // CHECK: affine.store %[[VAL_1]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref // CHECK: } -// CHECK: affine.for %[[VAL_12:.*]] = %[[VAL_3]] to #map(){{\[}}%[[VAL_5]]] { +// CHECK: affine.for %[[VAL_12:.*]] = %[[VAL_3]] to #{{.*}}(){{\[}}%[[VAL_5]]] { // CHECK: %[[VAL_13:.*]] = arith.subi %[[VAL_12]], %[[VAL_4]] : index // CHECK: affine.if #set(%[[VAL_12]]) { -// CHECK: %[[VAL_14:.*]] = affine.apply #map1(%[[VAL_12]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]] +// CHECK: %[[VAL_14:.*]] = affine.apply #{{.*}}(%[[VAL_12]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]] // CHECK: affine.store %[[VAL_1]], %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref // CHECK: } // CHECK: } diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md --- a/mlir/docs/AttributesAndTypes.md +++ b/mlir/docs/AttributesAndTypes.md @@ -959,6 
+959,8 @@ - Provide a method to hash an instance of the `KeyTy`. (Note: This is not necessary if an `llvm::DenseMapInfo` specialization exists) - `static llvm::hash_code hashKey(const KeyTy &)` +- Provide a method to generate the `KeyTy` from an instance of the storage class. + - `KeyTy getAsKey() const` Let's look at an example: @@ -997,6 +999,11 @@ ComplexTypeStorage(key.first, key.second); } + /// Construct an instance of the key from this storage class. + KeyTy getAsKey() const { + return KeyTy(nonZeroParam, integerType); + } + /// The parametric data held by the storage class. unsigned nonZeroParam; Type integerType; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -151,7 +151,7 @@ //===----------------------------------------------------------------------===// def LLVM_DICompileUnitAttr : LLVM_Attr<"DICompileUnit", "di_compile_unit", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ], "DIScopeAttr"> { let parameters = (ins LLVM_DILanguageParameter:$sourceLanguage, @@ -168,7 +168,7 @@ //===----------------------------------------------------------------------===// def LLVM_DICompositeTypeAttr : LLVM_Attr<"DICompositeType", "di_composite_type", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ], "DITypeAttr"> { let parameters = (ins LLVM_DITagParameter:$tag, @@ -188,7 +188,7 @@ //===----------------------------------------------------------------------===// def LLVM_DIDerivedTypeAttr : LLVM_Attr<"DIDerivedType", "di_derived_type", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ], "DITypeAttr"> { let parameters = (ins LLVM_DITagParameter:$tag, @@ -220,7 +220,7 @@ //===----------------------------------------------------------------------===// def LLVM_DILexicalBlockAttr : LLVM_Attr<"DILexicalBlock", "di_lexical_block", [ - DeclareAttrInterfaceMethods +
SubElementAttrInterface ], "DIScopeAttr"> { let parameters = (ins "DIScopeAttr":$scope, @@ -244,7 +244,7 @@ //===----------------------------------------------------------------------===// def LLVM_DILexicalBlockFile : LLVM_Attr<"DILexicalBlockFile", "di_lexical_block_file", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ], "DIScopeAttr"> { let parameters = (ins "DIScopeAttr":$scope, @@ -266,7 +266,7 @@ //===----------------------------------------------------------------------===// def LLVM_DILocalVariableAttr : LLVM_Attr<"DILocalVariable", "di_local_variable", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ], "DINodeAttr"> { let parameters = (ins "DIScopeAttr":$scope, @@ -296,7 +296,7 @@ //===----------------------------------------------------------------------===// def LLVM_DISubprogramAttr : LLVM_Attr<"DISubprogram", "di_subprogram", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ], "DIScopeAttr"> { let parameters = (ins "DICompileUnitAttr":$compileUnit, @@ -346,7 +346,7 @@ //===----------------------------------------------------------------------===// def LLVM_DISubroutineTypeAttr : LLVM_Attr<"DISubroutineType", "di_subroutine_type", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ], "DITypeAttr"> { let parameters = (ins LLVM_DICallingConventionParameter:$callingConvention, diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -72,7 +72,7 @@ //===----------------------------------------------------------------------===// def Builtin_ArrayAttr : Builtin_Attr<"Array", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ]> { let summary = "A collection of other Attribute values"; let description = [{ @@ -510,7 +510,7 @@ //===----------------------------------------------------------------------===// def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [ - 
DeclareAttrInterfaceMethods + SubElementAttrInterface ]> { let summary = "An dictionary of named Attribute values"; let description = [{ @@ -1115,7 +1115,7 @@ //===----------------------------------------------------------------------===// def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ]> { let summary = "An Attribute containing a symbolic reference to an Operation"; let description = [{ @@ -1190,7 +1190,7 @@ //===----------------------------------------------------------------------===// def Builtin_TypeAttr : Builtin_Attr<"Type", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ]> { let summary = "An Attribute containing a Type"; let description = [{ diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td --- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td +++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td @@ -29,7 +29,7 @@ //===----------------------------------------------------------------------===// def CallSiteLoc : Builtin_LocationAttr<"CallSiteLoc", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ]> { let summary = "A callsite source location"; let description = [{ @@ -108,7 +108,7 @@ //===----------------------------------------------------------------------===// def FusedLoc : Builtin_LocationAttr<"FusedLoc", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ]> { let summary = "A tuple of other source locations"; let description = [{ @@ -149,7 +149,7 @@ //===----------------------------------------------------------------------===// def NameLoc : Builtin_LocationAttr<"NameLoc", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ]> { let summary = "A named source location"; let description = [{ @@ -188,7 +188,7 @@ //===----------------------------------------------------------------------===// def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc", [ - DeclareAttrInterfaceMethods + 
SubElementAttrInterface ]> { let summary = "An opaque source location"; let description = [{ diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -107,6 +107,9 @@ return LocationAttr(reinterpret_cast(pointer)); } + /// Support llvm style casting. + static bool classof(Attribute attr) { return llvm::isa(attr); } + protected: /// The internal backing location attribute. LocationAttr impl; @@ -167,6 +170,23 @@ return get(reinterpret_cast(underlyingLocation), TypeID::get(), UnknownLoc::get(context)); } + +//===----------------------------------------------------------------------===// +// SubElementInterfaces +//===----------------------------------------------------------------------===// + +/// Enable locations to be introspected as sub-elements. +template <> +struct AttrTypeSubElementHandler { + static void walk(Location param, AttrTypeSubElementWalker &walker) { + walker.walk(param); + } + static Location replace(Location param, AttrSubElementReplacements &attrRepls, + TypeSubElementReplacements &typeRepls) { + return cast(attrRepls.take_front(1)[0]); + } +}; + } // namespace mlir //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -180,6 +180,9 @@ return ConcreteT((const typename BaseT::ImplType *)ptr); } + /// Utility for easy access to the storage instance. + ImplType *getImpl() const { return static_cast(this->impl); } + protected: /// Mutate the current storage instance. This will not change the unique key. /// The arguments are forwarded to 'ConcreteT::mutate'. @@ -199,9 +202,6 @@ return success(); } - /// Utility for easy access to the storage instance. 
- ImplType *getImpl() const { return static_cast(this->impl); } - private: /// Trait to check if T provides a 'ConcreteEntity' type alias. template diff --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h --- a/mlir/include/mlir/IR/SubElementInterfaces.h +++ b/mlir/include/mlir/IR/SubElementInterfaces.h @@ -23,6 +23,253 @@ using SubElementReplFn = function_ref; template using SubElementResultReplFn = function_ref(T)>; + +//===----------------------------------------------------------------------===// +/// AttrTypeSubElementHandler +//===----------------------------------------------------------------------===// + +/// This class is used by AttrTypeSubElementHandler instances to walk sub +/// attributes and types. +class AttrTypeSubElementWalker { +public: + AttrTypeSubElementWalker(function_ref walkAttrsFn, + function_ref walkTypesFn) + : walkAttrsFn(walkAttrsFn), walkTypesFn(walkTypesFn) {} + + /// Walk an attribute. + void walk(Attribute element) { + if (element) + walkAttrsFn(element); + } + /// Walk a type. + void walk(Type element) { + if (element) + walkTypesFn(element); + } + /// Walk a range of attributes or types. + template + void walkRange(RangeT &&elements) { + for (auto element : elements) + walk(element); + } + +private: + function_ref walkAttrsFn; + function_ref walkTypesFn; +}; + +/// This class is used by AttrTypeSubElementHandler instances to process sub +/// element replacements. +template +class AttrTypeSubElementReplacements { +public: + AttrTypeSubElementReplacements(ArrayRef repls) : repls(repls) {} + + /// Take the first N replacements as an ArrayRef, dropping them from + /// this replacement list. + ArrayRef take_front(unsigned n) { + ArrayRef elements = repls.take_front(n); + repls = repls.drop_front(n); + return elements; + } + +private: + /// The current set of replacements.
+ ArrayRef repls; +}; +using AttrSubElementReplacements = AttrTypeSubElementReplacements; +using TypeSubElementReplacements = AttrTypeSubElementReplacements; + +/// This class provides support for interacting with the +/// SubElementInterfaces for different types of parameters. An +/// implementation of this class should be provided for any parameter class +/// that may contain an attribute or type. There are two main methods of +/// this class that need to be implemented: +/// +/// - walk +/// +/// This method should traverse into any sub elements of the parameter +/// using the provided walker, or by invoking handlers for sub-types. +/// +/// - replace +/// +/// This method should extract any necessary sub elements using the +/// provided replacer, or by invoking handlers for sub-types. The new +/// post-replacement parameter value should be returned. +/// +template +struct AttrTypeSubElementHandler { + /// Default walk implementation that does nothing. + static inline void walk(const T ¶m, AttrTypeSubElementWalker &walker) {} + + /// Default replace implementation just forwards the parameter. + template + static inline decltype(auto) replace(ParamT &¶m, + AttrSubElementReplacements &attrRepls, + TypeSubElementReplacements &typeRepls) { + return std::forward(param); + } + + /// Tag indicating that this handler does not support sub-elements. + using DefaultHandlerTag = void; +}; + +/// Detect if any of the given parameter types has a sub-element handler. +namespace detail { +template +using has_default_sub_element_handler_t = decltype(T::DefaultHandlerTag); +} // namespace detail +template +inline constexpr bool has_sub_attr_or_type_v = + (!llvm::is_detected::value || + ...); + +/// Implementation for derived Attributes and Types. 
+template +struct AttrTypeSubElementHandler< + T, std::enable_if_t || + std::is_base_of_v>> { + static void walk(T param, AttrTypeSubElementWalker &walker) { + walker.walk(param); + } + static T replace(T param, AttrSubElementReplacements &attrRepls, + TypeSubElementReplacements &typeRepls) { + if (!param) + return T(); + if constexpr (std::is_base_of_v) { + return cast(attrRepls.take_front(1)[0]); + } else if constexpr (!detail::IsInterface::value && + std::is_base_of_v) { + return cast(typeRepls.take_front(1)[0]); + } + } +}; +template <> +struct AttrTypeSubElementHandler { + template + static void walk(T param, AttrTypeSubElementWalker &walker) { + walker.walk(param.getName()); + walker.walk(param.getValue()); + } + template + static T replace(T param, AttrSubElementReplacements &attrRepls, + TypeSubElementReplacements &typeRepls) { + ArrayRef paramRepls = attrRepls.take_front(2); + return T(cast(paramRepls[0]), paramRepls[1]); + } +}; +/// Implementation for derived ArrayRef. +template +struct AttrTypeSubElementHandler, + std::enable_if_t>> { + using EltHandler = AttrTypeSubElementHandler; + + static void walk(ArrayRef param, AttrTypeSubElementWalker &walker) { + for (const T &subElement : param) + EltHandler::walk(subElement, walker); + } + static auto replace(ArrayRef param, AttrSubElementReplacements &attrRepls, + TypeSubElementReplacements &typeRepls) { + // Normal attributes/types can extract using the replacer directly. + if constexpr (std::is_base_of_v && + sizeof(T) == sizeof(Attribute)) { + ArrayRef attrs = attrRepls.take_front(param.size()); + return ArrayRef((const T *)attrs.data(), attrs.size()); + } else if constexpr (std::is_base_of_v && + sizeof(T) == sizeof(Type)) { + ArrayRef types = typeRepls.take_front(param.size()); + return ArrayRef((const T *)types.data(), types.size()); + } else { + // Otherwise, we need to allocate storage for the new elements. 
+ SmallVector newElements; + for (const T &element : param) + newElements.emplace_back( + EltHandler::replace(element, attrRepls, typeRepls)); + return newElements; + } + } +}; +/// Implementation for Tuple. +template +struct AttrTypeSubElementHandler< + std::tuple, std::enable_if_t>> { + static void walk(const std::tuple ¶m, + AttrTypeSubElementWalker &walker) { + std::apply( + [&](auto &&...params) { + (AttrTypeSubElementHandler::walk(params, walker), ...); + }, + param); + } + static auto replace(const std::tuple ¶m, + AttrSubElementReplacements &attrRepls, + TypeSubElementReplacements &typeRepls) { + return std::apply( + [&](const Ts &...params) + -> std::tuple::replace( + params, attrRepls, typeRepls))...> { + return {AttrTypeSubElementHandler::replace(params, attrRepls, + typeRepls)...}; + }, + param); + } +}; + +namespace detail { +template +struct is_tuple : public std::false_type {}; +template +struct is_tuple> : public std::true_type {}; + +/// This function provides the underlying implementation for the +/// SubElementInterface walk method, using the key type of the derived +/// attribute/type to interact with the individual parameters. +template +void walkImmediateSubElementsImpl(T derived, + function_ref walkAttrsFn, + function_ref walkTypesFn) { + auto key = static_cast(derived.getImpl())->getAsKey(); + + // If we don't have any sub-elements, there is nothing to do. + if constexpr (!has_sub_attr_or_type_v) { + return; + } else { + AttrTypeSubElementWalker walker(walkAttrsFn, walkTypesFn); + AttrTypeSubElementHandler::walk(key, walker); + } +} + +/// This function provides the underlying implementation for the +/// SubElementInterface replace method, using the key type of the derived +/// attribute/type to interact with the individual parameters. 
+template +T replaceImmediateSubElementsImpl(T derived, ArrayRef &replAttrs, + ArrayRef &replTypes) { + auto key = static_cast(derived.getImpl())->getAsKey(); + + // If we don't have any sub-elements, we can just return the original. + if constexpr (!has_sub_attr_or_type_v) { + return derived; + + // Otherwise, we need to replace any necessary sub-elements. + } else { + AttrSubElementReplacements attrRepls(replAttrs); + TypeSubElementReplacements typeRepls(replTypes); + auto newKey = AttrTypeSubElementHandler::replace( + key, attrRepls, typeRepls); + if constexpr (is_tuple::value) { + return std::apply( + [&](auto &&...params) { + return T::Base::get(derived.getContext(), + std::forward(params)...); + }, + newKey); + } else { + return T::Base::get(derived.getContext(), newKey); + } + } +} +} // namespace detail } // namespace mlir /// Include the definitions of the sub elemnt interfaces. diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td --- a/mlir/include/mlir/IR/SubElementInterfaces.td +++ b/mlir/include/mlir/IR/SubElementInterfaces.td @@ -32,7 +32,11 @@ method does not recurse into sub elements. }], "void", "walkImmediateSubElements", (ins "llvm::function_ref":$walkAttrsFn, - "llvm::function_ref":$walkTypesFn) + "llvm::function_ref":$walkTypesFn), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ + ::mlir::detail::walkImmediateSubElementsImpl( + }] # derivedValue # [{, walkAttrsFn, walkTypesFn); + }] >, InterfaceMethod< /*desc=*/[{ @@ -47,10 +51,13 @@ not currently supported by the interface. If an implementing type or attribute is mutable, it should return `nullptr` if it has no mechanism for replacing sub elements. 
- }], attrOrType, "replaceImmediateSubElements", (ins - "::llvm::ArrayRef<::mlir::Attribute>":$replAttrs, - "::llvm::ArrayRef<::mlir::Type>":$replTypes - )>, + }], attrOrType, "replaceImmediateSubElements", + (ins "::llvm::ArrayRef<::mlir::Attribute>":$replAttrs, + "::llvm::ArrayRef<::mlir::Type>":$replTypes), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ + return ::mlir::detail::replaceImmediateSubElementsImpl( + }] # derivedValue # [{, replAttrs, replTypes); + }]>, ]; code extraClassDeclaration = [{ @@ -154,6 +161,9 @@ let description = [{ An interface used to query and manipulate sub-elements, such as sub-types and sub-attributes of a composite attribute. + + To support the introspection of custom parameters that hold sub-elements, + a specialization of the `AttrTypeSubElementHandler` class must be provided. }]; } @@ -168,6 +178,9 @@ let description = [{ An interface used to query and manipulate sub-elements, such as sub-types and sub-attributes of a composite type. + + To support the introspection of custom parameters that hold sub-elements, + a specialization of the `AttrTypeSubElementHandler` class must be provided. }]; } diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h --- a/mlir/include/mlir/IR/TypeRange.h +++ b/mlir/include/mlir/IR/TypeRange.h @@ -165,6 +165,23 @@ std::equal(lhs.begin(), lhs.end(), rhs.begin()); } +//===----------------------------------------------------------------------===// +// SubElementInterfaces +//===----------------------------------------------------------------------===// + +/// Enable TypeRange to be introspected for sub-elements. 
+template <> +struct AttrTypeSubElementHandler { + static void walk(TypeRange param, AttrTypeSubElementWalker &walker) { + walker.walkRange(param); + } + static TypeRange replace(TypeRange param, + AttrSubElementReplacements &attrRepls, + TypeSubElementReplacements &typeRepls) { + return typeRepls.take_front(param.size()); + } +}; + } // namespace mlir namespace llvm { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -64,169 +64,6 @@ return llvm::isa(attr); } -//===----------------------------------------------------------------------===// -// DICompileUnitAttr -//===----------------------------------------------------------------------===// - -void DICompileUnitAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkAttrsFn(getFile()); - walkAttrsFn(getProducer()); -} - -Attribute -DICompileUnitAttr::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - return get(getContext(), getSourceLanguage(), replAttrs[0].cast(), - replAttrs[1].cast(), getIsOptimized(), - getEmissionKind()); -} - -//===----------------------------------------------------------------------===// -// DICompositeTypeAttr -//===----------------------------------------------------------------------===// - -void DICompositeTypeAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkAttrsFn(getName()); - walkAttrsFn(getFile()); - walkAttrsFn(getScope()); - for (DINodeAttr element : getElements()) - walkAttrsFn(element); -} - -Attribute DICompositeTypeAttr::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - ArrayRef elements = replAttrs.drop_front(3); - return get( - getContext(), getTag(), replAttrs[0].cast(), - cast_or_null(replAttrs[1]), getLine(), - cast_or_null(replAttrs[2]), getSizeInBits(), - 
getAlignInBits(), - ArrayRef(static_cast(elements.data()), - elements.size())); -} - -//===----------------------------------------------------------------------===// -// DIDerivedTypeAttr -//===----------------------------------------------------------------------===// - -void DIDerivedTypeAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkAttrsFn(getName()); - walkAttrsFn(getBaseType()); -} - -Attribute -DIDerivedTypeAttr::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - return get(getContext(), getTag(), replAttrs[0].cast(), - replAttrs[1].cast(), getSizeInBits(), getAlignInBits(), - getOffsetInBits()); -} - -//===----------------------------------------------------------------------===// -// DILexicalBlockAttr -//===----------------------------------------------------------------------===// - -void DILexicalBlockAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkAttrsFn(getScope()); - walkAttrsFn(getFile()); -} - -Attribute DILexicalBlockAttr::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get(replAttrs[0].cast(), replAttrs[1].cast(), - getLine(), getColumn()); -} - -//===----------------------------------------------------------------------===// -// DILexicalBlockFileAttr -//===----------------------------------------------------------------------===// - -void DILexicalBlockFileAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkAttrsFn(getScope()); - walkAttrsFn(getFile()); -} - -Attribute DILexicalBlockFileAttr::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get(replAttrs[0].cast(), replAttrs[1].cast(), - getDescriminator()); -} - -//===----------------------------------------------------------------------===// -// DILocalVariableAttr 
-//===----------------------------------------------------------------------===// - -void DILocalVariableAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkAttrsFn(getScope()); - walkAttrsFn(getName()); - walkAttrsFn(getFile()); - walkAttrsFn(getType()); -} - -Attribute DILocalVariableAttr::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get(getContext(), replAttrs[0].cast(), - replAttrs[1].cast(), replAttrs[2].cast(), - getLine(), getArg(), getAlignInBits(), - replAttrs[3].cast()); -} - -//===----------------------------------------------------------------------===// -// DISubprogramAttr -//===----------------------------------------------------------------------===// - -void DISubprogramAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkAttrsFn(getCompileUnit()); - walkAttrsFn(getScope()); - walkAttrsFn(getName()); - walkAttrsFn(getLinkageName()); - walkAttrsFn(getFile()); - walkAttrsFn(getType()); -} - -Attribute -DISubprogramAttr::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - return get(getContext(), replAttrs[0].cast(), - replAttrs[1].cast(), replAttrs[2].cast(), - replAttrs[3].cast(), replAttrs[4].cast(), - getLine(), getScopeLine(), getSubprogramFlags(), - replAttrs[5].cast()); -} - -//===----------------------------------------------------------------------===// -// DISubroutineTypeAttr -//===----------------------------------------------------------------------===// - -void DISubroutineTypeAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - for (DITypeAttr type : getTypes()) - walkAttrsFn(type); -} - -Attribute DISubroutineTypeAttr::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get( - getContext(), getCallingConvention(), - ArrayRef(static_cast(replAttrs.data()), - replAttrs.size())); -} - 
//===----------------------------------------------------------------------===// // LoopOptionsAttrBuilder //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -183,20 +183,6 @@ return dataLayout.getTypePreferredAlignment(getElementType()); } -//===----------------------------------------------------------------------===// -// SubElementTypeInterface - -void LLVMArrayType::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkTypesFn(getElementType()); -} - -Type LLVMArrayType::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get(replTypes.front(), getNumElements()); -} - //===----------------------------------------------------------------------===// // Function type. //===----------------------------------------------------------------------===// @@ -247,22 +233,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// SubElementTypeInterface - -void LLVMFunctionType::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkTypesFn(getReturnType()); - for (Type type : getParams()) - walkTypesFn(type); -} - -Type LLVMFunctionType::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get(replTypes.front(), replTypes.drop_front(), isVarArg()); -} - //===----------------------------------------------------------------------===// // LLVMPointerType //===----------------------------------------------------------------------===// @@ -439,20 +409,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// SubElementTypeInterface - -void LLVMPointerType::walkImmediateSubElements( - function_ref 
walkAttrsFn, - function_ref walkTypesFn) const { - walkTypesFn(getElementType()); -} - -Type LLVMPointerType::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get(getContext(), replTypes.front(), getAddressSpace()); -} - //===----------------------------------------------------------------------===// // Struct type. //===----------------------------------------------------------------------===// @@ -749,17 +705,6 @@ emitError, elementType, numElements); } -void LLVMFixedVectorType::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkTypesFn(getElementType()); -} - -Type LLVMFixedVectorType::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get(replTypes[0], getNumElements()); -} - //===----------------------------------------------------------------------===// // LLVMScalableVectorType. //===----------------------------------------------------------------------===// @@ -792,17 +737,6 @@ emitError, elementType, numElements); } -void LLVMScalableVectorType::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkTypesFn(getElementType()); -} - -Type LLVMScalableVectorType::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get(replTypes[0], getMinNumElements()); -} - //===----------------------------------------------------------------------===// // Utility functions. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -43,23 +43,6 @@ >(); } -//===----------------------------------------------------------------------===// -// ArrayAttr -//===----------------------------------------------------------------------===// - -void ArrayAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - for (Attribute attr : getValue()) - walkAttrsFn(attr); -} - -Attribute -ArrayAttr::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - return get(getContext(), replAttrs); -} - //===----------------------------------------------------------------------===// // DictionaryAttr //===----------------------------------------------------------------------===// @@ -217,25 +200,6 @@ return Base::get(context, ArrayRef()); } -void DictionaryAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - for (const NamedAttribute &attr : getValue()) - walkAttrsFn(attr.getValue()); -} - -Attribute -DictionaryAttr::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - std::vector vec = getValue().vec(); - for (auto &it : llvm::enumerate(replAttrs)) - vec[it.index()].setValue(it.value()); - - // The above only modifies the mapped value, but not the key, and therefore - // not the order of the elements. It remains sorted - return getWithSorted(getContext(), vec); -} - //===----------------------------------------------------------------------===// // StridedLayoutAttr //===----------------------------------------------------------------------===// @@ -375,24 +339,6 @@ return nestedRefs.empty() ? 
getRootReference() : nestedRefs.back().getAttr(); } -void SymbolRefAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkAttrsFn(getRootReference()); - for (FlatSymbolRefAttr ref : getNestedReferences()) - walkAttrsFn(ref); -} - -Attribute -SymbolRefAttr::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - ArrayRef rawNestedRefs = replAttrs.drop_front(); - ArrayRef nestedRefs( - static_cast(rawNestedRefs.data()), - rawNestedRefs.size()); - return get(replAttrs[0].cast(), nestedRefs); -} - //===----------------------------------------------------------------------===// // IntegerAttr //===----------------------------------------------------------------------===// @@ -1812,22 +1758,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// TypeAttr -//===----------------------------------------------------------------------===// - -void TypeAttr::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkTypesFn(getValue()); -} - -Attribute -TypeAttr::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - return get(replTypes[0]); -} - //===----------------------------------------------------------------------===// // Attribute Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -187,20 +187,6 @@ return clone(newArgTypes, newResultTypes); } -void FunctionType::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - for (Type type : llvm::concat(getInputs(), getResults())) - walkTypesFn(type); -} - -Type FunctionType::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - unsigned numInputs = getNumInputs(); - return get(getContext(), 
replTypes.take_front(numInputs), - replTypes.drop_front(numInputs)); -} - //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===// @@ -258,17 +244,6 @@ return VectorType(); } -void VectorType::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkTypesFn(getElementType()); -} - -Type VectorType::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - return get(getShape(), replTypes.front(), getNumScalableDims()); -} - VectorType VectorType::cloneWith(Optional> shape, Type elementType) const { return VectorType::get(shape.value_or(getShape()), elementType, @@ -343,20 +318,6 @@ return checkTensorElementType(emitError, elementType); } -void RankedTensorType::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkTypesFn(getElementType()); - if (Attribute encoding = getEncoding()) - walkAttrsFn(encoding); -} - -Type RankedTensorType::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get(getShape(), replTypes.front(), - replAttrs.empty() ? 
Attribute() : replAttrs.back()); -} - //===----------------------------------------------------------------------===// // UnrankedTensorType //===----------------------------------------------------------------------===// @@ -367,17 +328,6 @@ return checkTensorElementType(emitError, elementType); } -void UnrankedTensorType::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkTypesFn(getElementType()); -} - -Type UnrankedTensorType::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get(replTypes.front()); -} - //===----------------------------------------------------------------------===// // BaseMemRefType //===----------------------------------------------------------------------===// @@ -671,24 +621,6 @@ return success(); } -void MemRefType::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkTypesFn(getElementType()); - if (!getLayout().isIdentity()) - walkAttrsFn(getLayout()); - walkAttrsFn(getMemorySpace()); -} - -Type MemRefType::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - bool hasLayout = replAttrs.size() > 1; - return get(getShape(), replTypes[0], - hasLayout ? replAttrs[0].dyn_cast() - : MemRefLayoutAttrInterface(), - hasLayout ? 
replAttrs[1] : replAttrs[0]); -} - //===----------------------------------------------------------------------===// // UnrankedMemRefType //===----------------------------------------------------------------------===// @@ -870,18 +802,6 @@ return success(); } -void UnrankedMemRefType::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkTypesFn(getElementType()); - walkAttrsFn(getMemorySpace()); -} - -Type UnrankedMemRefType::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - return get(replTypes.front(), replAttrs.front()); -} - //===----------------------------------------------------------------------===// /// TupleType //===----------------------------------------------------------------------===// @@ -905,18 +825,6 @@ /// Return the number of element types. size_t TupleType::size() const { return getImpl()->size(); } -void TupleType::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - for (Type type : getTypes()) - walkTypesFn(type); -} - -Type TupleType::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - return get(getContext(), replTypes); -} - //===----------------------------------------------------------------------===// // Type Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -80,20 +80,6 @@ return CallSiteLoc::get(name, caller); } -void CallSiteLoc::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkAttrsFn(getCallee()); - walkAttrsFn(getCaller()); -} - -Attribute -CallSiteLoc::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - return get(replAttrs[0].cast(), - replAttrs[1].cast()); -} - //===----------------------------------------------------------------------===// // 
FusedLoc //===----------------------------------------------------------------------===// @@ -135,55 +121,3 @@ return Base::get(context, locs, metadata); } - -void FusedLoc::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - for (Attribute attr : getLocations()) - walkAttrsFn(attr); - walkAttrsFn(getMetadata()); -} - -Attribute -FusedLoc::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - SmallVector newLocs; - newLocs.reserve(replAttrs.size() - 1); - for (Attribute attr : replAttrs.drop_back()) - newLocs.push_back(attr.cast()); - return get(getContext(), newLocs, replAttrs.back()); -} - -//===----------------------------------------------------------------------===// -// NameLoc -//===----------------------------------------------------------------------===// - -void NameLoc::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkAttrsFn(getName()); - walkAttrsFn(getChildLoc()); -} - -Attribute NameLoc::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - return get(replAttrs[0].cast(), - replAttrs[1].cast()); -} - -//===----------------------------------------------------------------------===// -// OpaqueLoc -//===----------------------------------------------------------------------===// - -void OpaqueLoc::walkImmediateSubElements( - function_ref walkAttrsFn, - function_ref walkTypesFn) const { - walkAttrsFn(getFallbackLocation()); -} - -Attribute -OpaqueLoc::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const { - return get(getUnderlyingLocation(), getUnderlyingTypeID(), - replAttrs[0].cast()); -} diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp --- a/mlir/lib/IR/SubElementInterfaces.cpp +++ b/mlir/lib/IR/SubElementInterfaces.cpp @@ -27,11 +27,6 @@ DenseSet &visitedTypes) { interface.walkImmediateSubElements( [&](Attribute attr) { - // Guard against potentially 
null inputs. This removes the need for the - // derived attribute/type to do it. - if (!attr) - return; - // Avoid infinite recursion when visiting sub attributes later, if this // is a mutable attribute. if (LLVM_UNLIKELY(attr.hasTrait())) { @@ -48,11 +43,6 @@ walkAttrsFn(attr); }, [&](Type type) { - // Guard against potentially null inputs. This removes the need for the - // derived attribute/type to do it. - if (!type) - return; - // Avoid infinite recursion when visiting sub types later, if this // is a mutable type. if (LLVM_UNLIKELY(type.hasTrait())) { @@ -103,10 +93,6 @@ return; newElements.push_back(element); - // Guard against potentially null inputs. We always map null to null. - if (!element) - return; - // Check for an existing mapping for this element, and walk it if we haven't // yet. T *mappedElement = &visited[element]; diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -47,6 +47,8 @@ IntegerTypeStorage(key.first, key.second); } + KeyTy getAsKey() const { return KeyTy(width, signedness); } + unsigned width : 30; IntegerType::SignednessSemantics signedness : 2; }; @@ -59,7 +61,7 @@ inputsAndResults(inputsAndResults) {} /// The hash key used for uniquing. - using KeyTy = std::pair; + using KeyTy = std::tuple; bool operator==(const KeyTy &key) const { if (std::get<0>(key) == getInputs()) return std::get<1>(key) == getResults(); @@ -69,7 +71,7 @@ /// Construction. static FunctionTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { - TypeRange inputs = key.first, results = key.second; + auto [inputs, results] = key; // Copy the inputs and results into the bump pointer. 
SmallVector types; @@ -90,6 +92,8 @@ return ArrayRef(inputsAndResults + numInputs, numResults); } + KeyTy getAsKey() const { return KeyTy(getInputs(), getResults()); } + unsigned numInputs; unsigned numResults; Type const *inputsAndResults; @@ -127,6 +131,8 @@ return {getTrailingObjects(), size()}; } + KeyTy getAsKey() const { return getTypes(); } + /// The number of tuple elements. unsigned numElements; }; diff --git a/mlir/test/Dialect/Affine/loop-tiling.mlir b/mlir/test/Dialect/Affine/loop-tiling.mlir --- a/mlir/test/Dialect/Affine/loop-tiling.mlir +++ b/mlir/test/Dialect/Affine/loop-tiling.mlir @@ -133,8 +133,8 @@ // CHECK: memref.dim %{{.*}}, %c0 : memref // CHECK-NEXT: affine.for %{{.*}} = 0 to %{{.*}} step 32 { // CHECK-NEXT: affine.for %{{.*}} = 0 to %{{.*}} step 32 { -// CHECK-NEXT: affine.for %{{.*}} = #map(%{{.*}}) to min [[$UBMAP]](%{{.*}})[%{{.*}}] { -// CHECK-NEXT: affine.for %{{.*}} = #map(%{{.*}}) to min [[$UBMAP]](%{{.*}})[%{{.*}}] { +// CHECK-NEXT: affine.for %{{.*}} = #[[$MAP:.*]](%{{.*}}) to min [[$UBMAP]](%{{.*}})[%{{.*}}] { +// CHECK-NEXT: affine.for %{{.*}} = #[[$MAP]](%{{.*}}) to min [[$UBMAP]](%{{.*}})[%{{.*}}] { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref // CHECK-NEXT: affine.for %{{.*}} = 0 to %{{.*}} { // CHECK-NEXT: affine.load diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -775,9 +775,9 @@ return %shape : memref } -// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)> -// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()> -// CHECK: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func 
@input_stays_same( // CHECK-SAME: %[[ARG0:.*]]: memref>, // CHECK-SAME: %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref) diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -416,7 +416,7 @@ // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 { // CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]]) // CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG3]], %[[TMP1]]) -// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #map2(%[[ARG5]], %[[ARG6]]) +// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #{{.*}}(%[[ARG5]], %[[ARG6]]) // CHECK-NEXT: affine.load %[[ARG0]][%[[TMP2]], %[[TMP3]]] : memref<1024x1024xf32> // ----- diff --git a/mlir/test/Dialect/SCF/for-loop-specialization.mlir b/mlir/test/Dialect/SCF/for-loop-specialization.mlir --- a/mlir/test/Dialect/SCF/for-loop-specialization.mlir +++ b/mlir/test/Dialect/SCF/for-loop-specialization.mlir @@ -23,7 +23,7 @@ // CHECK: [[CST_0:%.*]] = arith.constant 0 : index // CHECK: [[CST_1:%.*]] = arith.constant 1 : index // CHECK: [[DIM_0:%.*]] = memref.dim [[ARG1]], [[CST_0]] : memref -// CHECK: [[MIN:%.*]] = affine.min #map(){{\[}}[[DIM_0]], [[ARG0]]] +// CHECK: [[MIN:%.*]] = affine.min #{{.*}}(){{\[}}[[DIM_0]], [[ARG0]]] // CHECK: [[CST_1024:%.*]] = arith.constant 1024 : index // CHECK: [[PRED:%.*]] = arith.cmpi eq, [[MIN]], [[CST_1024]] : index // CHECK: scf.if [[PRED]] { diff --git a/mlir/test/Dialect/SCF/parallel-loop-specialization.mlir b/mlir/test/Dialect/SCF/parallel-loop-specialization.mlir --- a/mlir/test/Dialect/SCF/parallel-loop-specialization.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-specialization.mlir @@ -26,8 +26,8 @@ // CHECK: [[VAL_7:%.*]] = arith.constant 1 : index // CHECK: [[VAL_8:%.*]] = memref.dim [[VAL_2]], [[VAL_6]] : memref // CHECK: [[VAL_9:%.*]] = memref.dim [[VAL_2]], [[VAL_7]] : memref 
-// CHECK: [[VAL_10:%.*]] = affine.min #map(){{\[}}[[VAL_8]], [[VAL_0]]] -// CHECK: [[VAL_11:%.*]] = affine.min #map1(){{\[}}[[VAL_9]], [[VAL_1]]] +// CHECK: [[VAL_10:%.*]] = affine.min #{{.*}}(){{\[}}[[VAL_8]], [[VAL_0]]] +// CHECK: [[VAL_11:%.*]] = affine.min #{{.*}}(){{\[}}[[VAL_9]], [[VAL_1]]] // CHECK: [[VAL_12:%.*]] = arith.constant 1024 : index // CHECK: [[VAL_13:%.*]] = arith.cmpi eq, [[VAL_10]], [[VAL_12]] : index // CHECK: [[VAL_14:%.*]] = arith.constant 64 : index diff --git a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir --- a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir @@ -13,7 +13,7 @@ return } -// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1 - d2)> +// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)> // CHECK-LABEL: func @parallel_loop( // CHECK-SAME: [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: index, [[ARG6:%.*]]: index, [[ARG7:%.*]]: memref, [[ARG8:%.*]]: memref, [[ARG9:%.*]]: memref, [[ARG10:%.*]]: memref) { // CHECK: [[C0:%.*]] = arith.constant 0 : index @@ -22,8 +22,8 @@ // CHECK: [[V1:%.*]] = arith.muli [[ARG5]], [[C1]] : index // CHECK: [[V2:%.*]] = arith.muli [[ARG6]], [[C4]] : index // CHECK: scf.parallel ([[V3:%.*]], [[V4:%.*]]) = ([[ARG1]], [[ARG2]]) to ([[ARG3]], [[ARG4]]) step ([[V1]], [[V2]]) { -// CHECK: [[V5:%.*]] = affine.min #map([[V1]], [[ARG3]], [[V3]]) -// CHECK: [[V6:%.*]] = affine.min #map([[V2]], [[ARG4]], [[V4]]) +// CHECK: [[V5:%.*]] = affine.min #[[$MAP]]([[V1]], [[ARG3]], [[V3]]) +// CHECK: [[V6:%.*]] = affine.min #[[$MAP]]([[V2]], [[ARG4]], [[V4]]) // CHECK: scf.parallel ([[V7:%.*]], [[V8:%.*]]) = ([[C0]], [[C0]]) to ([[V5]], [[V6]]) step ([[ARG5]], [[ARG6]]) { // CHECK: [[V9:%.*]] = arith.addi [[V7]], [[V3]] : index // CHECK: [[V10:%.*]] = arith.addi [[V8]], [[V4]] : index @@ -91,7 +91,7 @@ // CHECK: [[V3:%.*]] = arith.muli [[C1]], [[C1_1]] 
: index // CHECK: [[V4:%.*]] = arith.muli [[C1]], [[C4]] : index // CHECK: scf.parallel ([[V5:%.*]], [[V6:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V3]], [[V4]]) { -// CHECK: [[V7:%.*]] = affine.min #map([[V4]], [[C2]], [[V6]]) +// CHECK: [[V7:%.*]] = affine.min #{{.*}}([[V4]], [[C2]], [[V6]]) // CHECK: scf.parallel ([[V8:%.*]], [[V9:%.*]]) = ([[C0_1]], [[C0_1]]) to ([[V3]], [[V7]]) step ([[C1]], [[C1]]) { // CHECK: = arith.addi [[V8]], [[V5]] : index // CHECK: = arith.addi [[V9]], [[V6]] : index @@ -104,7 +104,7 @@ // CHECK: [[V10:%.*]] = arith.muli [[C1]], [[C1_2]] : index // CHECK: [[V11:%.*]] = arith.muli [[C1]], [[C4_1]] : index // CHECK: scf.parallel ([[V12:%.*]], [[V13:%.*]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[V10]], [[V11]]) { -// CHECK: [[V14:%.*]] = affine.min #map([[V11]], [[C2]], [[V13]]) +// CHECK: [[V14:%.*]] = affine.min #{{.*}}([[V11]], [[C2]], [[V13]]) // CHECK: scf.parallel ([[V15:%.*]], [[V16:%.*]]) = ([[C0_2]], [[C0_2]]) to ([[V10]], [[V14]]) step ([[C1]], [[C1]]) { // CHECK: = arith.addi [[V15]], [[V12]] : index // CHECK: = arith.addi [[V16]], [[V13]] : index diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file | mlir-opt | FileCheck %s +// RUN: mlir-opt %s -split-input-file | mlir-opt -split-input-file | FileCheck %s #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir --- a/mlir/test/IR/affine-map.mlir +++ b/mlir/test/IR/affine-map.mlir @@ -1,10 +1,7 @@ // RUN: mlir-opt -allow-unregistered-dialect %s | FileCheck %s // Identity maps used in trivial compositions in MemRefs are optimized away. 
-// CHECK-NOT: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)> #map0 = affine_map<(i, j) -> (i, j)> - -// CHECK-NOT: #map{{[0-9]*}} = affine_map<(d0, d1)[s0] -> (d0, d1)> #map1 = affine_map<(i, j)[s0] -> (i, j)> // CHECK: #map{{[0-9]*}} = affine_map<() -> (0)> @@ -194,7 +191,6 @@ // Check if parser can parse affine_map with identifiers that collide with // integer types. -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)> #map60 = affine_map<(i0, i1) -> (i0, i1)> // Check if parser can parse affine_map with identifiers that collide with diff --git a/mlir/test/IR/memory-ops.mlir b/mlir/test/IR/memory-ops.mlir --- a/mlir/test/IR/memory-ops.mlir +++ b/mlir/test/IR/memory-ops.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s | FileCheck %s -// CHECK: #map = affine_map<(d0, d1)[s0] -> (d0 + s0, d1)> +// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (d0 + s0, d1)> // CHECK-LABEL: func @alloc() { func.func @alloc() { @@ -17,11 +17,11 @@ %1 = memref.alloc(%c0, %c1) : memref (d0, d1)>, 1> // Test alloc with no dynamic dimensions and one symbol. - // CHECK: %{{.*}} = memref.alloc()[%{{.*}}] : memref<2x4xf32, #map, 1> + // CHECK: %{{.*}} = memref.alloc()[%{{.*}}] : memref<2x4xf32, #[[$MAP]], 1> %2 = memref.alloc()[%c0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // Test alloc with dynamic dimensions and one symbol. - // CHECK: %{{.*}} = memref.alloc(%{{.*}})[%{{.*}}] : memref<2x?xf32, #map, 1> + // CHECK: %{{.*}} = memref.alloc(%{{.*}})[%{{.*}}] : memref<2x?xf32, #[[$MAP]], 1> %3 = memref.alloc(%c1)[%c0] : memref<2x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1> // Alloc with no mappings. @@ -48,11 +48,11 @@ %1 = memref.alloca(%c0, %c1) : memref (d0, d1)>, 1> // Test alloca with no dynamic dimensions and one symbol. 
- // CHECK: %{{.*}} = memref.alloca()[%{{.*}}] : memref<2x4xf32, #map, 1> + // CHECK: %{{.*}} = memref.alloca()[%{{.*}}] : memref<2x4xf32, #[[$MAP]], 1> %2 = memref.alloca()[%c0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // Test alloca with dynamic dimensions and one symbol. - // CHECK: %{{.*}} = memref.alloca(%{{.*}})[%{{.*}}] : memref<2x?xf32, #map, 1> + // CHECK: %{{.*}} = memref.alloca(%{{.*}})[%{{.*}}] : memref<2x?xf32, #[[$MAP]], 1> %3 = memref.alloca(%c1)[%c0] : memref<2x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>, 1> // Alloca with no mappings, but with alignment. diff --git a/mlir/test/Transforms/loop-fusion-2.mlir b/mlir/test/Transforms/loop-fusion-2.mlir --- a/mlir/test/Transforms/loop-fusion-2.mlir +++ b/mlir/test/Transforms/loop-fusion-2.mlir @@ -508,16 +508,16 @@ } return } -// MAXIMAL: #map = affine_map<(d0, d1) -> (d0 * 16 + d1)> +// MAXIMAL: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)> // MAXIMAL-LABEL: func @fuse_across_dim_mismatch // MAXIMAL: memref.alloc() : memref<1x1xf32> // MAXIMAL: affine.for %{{.*}} = 0 to 9 { // MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 9 { // MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 4 { // MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 16 { -// MAXIMAL-NEXT: affine.apply #map(%{{.*}}, %{{.*}}) +// MAXIMAL-NEXT: affine.apply #[[$MAP]](%{{.*}}, %{{.*}}) // MAXIMAL-NEXT: affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32> -// MAXIMAL-NEXT: affine.apply #map(%{{.*}}, %{{.*}}) +// MAXIMAL-NEXT: affine.apply #[[$MAP]](%{{.*}}, %{{.*}}) // MAXIMAL-NEXT: affine.load %{{.*}}[0, 0] : memref<1x1xf32> // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir --- a/mlir/test/Transforms/normalize-memrefs-ops.mlir +++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir @@ -29,15 +29,15 @@ // Same test with op_nonnorm, with maps in the arguments and the operations in the function. 
// CHECK-LABEL: test_nonnorm -// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32, #map>) +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32, #[[MAP:.*]]>) func.func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () { %0 = memref.alloc() : memref<1x16x14x14xf32, #map0> "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> () memref.dealloc %0 : memref<1x16x14x14xf32, #map0> - // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32, #map> - // CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #map>, memref<1x16x14x14xf32, #map>) -> () - // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32, #map> + // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32, #[[MAP]]> + // CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #[[MAP]]>, memref<1x16x14x14xf32, #[[MAP]]>) -> () + // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32, #[[MAP]]> return } diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -119,8 +119,7 @@ } def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [ - DeclareAttrInterfaceMethods + SubElementAttrInterface ]> { let mnemonic = "sub_elements_access"; diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -150,20 +150,6 @@ << ">"; } -void TestSubElementsAccessAttr::walkImmediateSubElements( - llvm::function_ref walkAttrsFn, - llvm::function_ref walkTypesFn) const { - walkAttrsFn(getFirst()); - walkAttrsFn(getSecond()); - walkAttrsFn(getThird()); -} - -Attribute TestSubElementsAccessAttr::replaceImmediateSubElements( - ArrayRef replAttrs, ArrayRef replTypes) const { - assert(replAttrs.size() == 3 && "invalid number of 
replacement attributes"); - return get(getContext(), replAttrs[0], replAttrs[1], replAttrs[2]); -} - //===----------------------------------------------------------------------===// // TestExtern1DI64ElementsAttr //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -457,6 +457,13 @@ [&](auto ¶m) { os << param.getCppType(); }); os << '>'; storageCls->declare("KeyTy", std::move(os.str())); + + // Add a method to construct the key type from the storage. + Method *m = storageCls->addConstMethod("KeyTy", "getAsKey"); + m->body().indent() << "return KeyTy("; + llvm::interleaveComma(params, m->body().indent(), + [&](auto ¶m) { m->body() << param.getName(); }); + m->body() << ");"; } void DefGen::emitEquals() { diff --git a/mlir/unittests/IR/SubElementInterfaceTest.cpp b/mlir/unittests/IR/SubElementInterfaceTest.cpp --- a/mlir/unittests/IR/SubElementInterfaceTest.cpp +++ b/mlir/unittests/IR/SubElementInterfaceTest.cpp @@ -23,13 +23,14 @@ BoolAttr trueAttr = builder.getBoolAttr(true); BoolAttr falseAttr = builder.getBoolAttr(false); ArrayAttr boolArrayAttr = builder.getArrayAttr({trueAttr, falseAttr}); + StringAttr strAttr = builder.getStringAttr("array"); DictionaryAttr dictAttr = - builder.getDictionaryAttr(builder.getNamedAttr("array", boolArrayAttr)); + builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr)); SmallVector subAttrs; dictAttr.walkSubAttrs([&](Attribute attr) { subAttrs.push_back(attr); }); EXPECT_EQ(llvm::makeArrayRef(subAttrs), - ArrayRef({trueAttr, falseAttr, boolArrayAttr})); + ArrayRef({strAttr, trueAttr, falseAttr, boolArrayAttr})); } } // namespace