diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md --- a/mlir/docs/AttributesAndTypes.md +++ b/mlir/docs/AttributesAndTypes.md @@ -126,9 +126,9 @@ An integer attribute is a literal attribute that represents an integral value of the specified integer type. }]; - /// Here we've defined two parameters, one is the `self` type of the attribute - /// (i.e. the type of the Attribute itself), and the other is the integer value - /// of the attribute. + /// Here we've defined two parameters, one is a "self" type parameter, and the + /// other is the integer value of the attribute. The self type parameter is + /// specially handled by the assembly format. let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value); /// Here we've defined a custom builder for the type, that removes the need to pass @@ -146,6 +146,8 @@ /// /// #my.int<50> : !my.int<32> // a 32-bit integer of value 50. /// + /// Note that the self type parameter is not included in the assembly format. + /// Its value is derived from the optional trailing type on all attributes. let assemblyFormat = "`<` $value `>`"; /// Indicate that our attribute will add additional verification to the parameters. @@ -271,9 +273,8 @@ - `ArrayRefOfSelfAllocationParameter` for arrays of objects which self-allocate as per the last specialization. -- `AttributeSelfTypeParameter` is a special AttrParameter that corresponds to - the `Type` of the attribute. Only one parameter of the attribute may be of - this parameter type. +- `AttributeSelfTypeParameter` is a special `AttrParameter` that represents a + parameter whose value is derived from the optional trailing type on attributes. ### Traits @@ -702,6 +703,54 @@ DefaultValuedParameter<"IntegerType", "IntegerType::get($_ctxt, 32)"> ``` +The values of parameters that appear __before__ the default-valued parameter in +the parameter declaration list are available as substitutions. For example:
+ +```tablegen +let parameters = (ins + "IntegerAttr":$value, + DefaultValuedParameter<"Type", "$value.getType()">:$type +); +``` + +###### Attribute Self Type Parameter + +An attribute optionally has a trailing type after the assembly format of the +attribute value itself. MLIR parses over the attribute value and optionally +parses a colon-type before passing the `Type` into the dialect parser hook. + +``` +dialect-attribute ::= `#` dialect-namespace `<` attr-data `>` + (`:` type)? + | `#` alias-name pretty-dialect-sym-body? (`:` type)? +``` + +`AttributeSelfTypeParameter` is an attribute parameter specially handled by the +assembly format generator. Only one such parameter can be specified, and its +value is derived from the trailing type. This parameter's default value is +`NoneType::get($_ctxt)`. + +For MLIR to print the trailing type, however, the attribute must implement +`TypedAttrInterface`. For example, + +```tablegen +// This attribute has only a self type parameter. +def MyExternAttr : AttrDef { + let parameters = (ins AttributeSelfTypeParameter<"">:$type); + let mnemonic = "extern"; + let assemblyFormat = ""; +} +``` + +This attribute can look like: + +```mlir +#my_dialect.extern // none +#my_dialect.extern : i32 +#my_dialect.extern : tensor<4xi32> +#my_dialect.extern : !my_dialect.my_type +``` + ##### Assembly Format Directives Attribute and type assembly formats have the following directives: diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -15,6 +15,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/OpAsmInterface.td" // Base class for Arithmetic dialect ops.
Ops in this dialect have no side @@ -147,7 +148,7 @@ ``` }]; - let arguments = (ins AnyAttr:$value); + let arguments = (ins TypedAttrInterface:$value); // TODO: Disallow arith.constant to return anything other than a signless // integer or float like. Downstream users of Arithmetic should only be // working with signless integers, floats, or vectors/tensors thereof. diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -32,12 +32,12 @@ assert(operands.size() == 2 && "binary op takes two operands"); if (!operands[0] || !operands[1]) return {}; - if (operands[0].getType() != operands[1].getType()) - return {}; if (operands[0].isa() && operands[1].isa()) { auto lhs = operands[0].cast(); auto rhs = operands[1].cast(); + if (lhs.getType() != rhs.getType()) + return {}; auto calRes = calculate(lhs.getValue(), rhs.getValue()); @@ -53,6 +53,8 @@ // just fold based on the splat value. auto lhs = operands[0].cast(); auto rhs = operands[1].cast(); + if (lhs.getType() != rhs.getType()) + return {}; auto elementResult = calculate(lhs.getSplatValue(), rhs.getSplatValue()); @@ -66,6 +68,8 @@ // expanding the values. auto lhs = operands[0].cast(); auto rhs = operands[1].cast(); + if (lhs.getType() != rhs.getType()) + return {}; auto lhsIt = lhs.value_begin(); auto rhsIt = rhs.value_begin(); diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexAttributes.td @@ -10,18 +10,21 @@ #define COMPLEX_ATTRIBUTE include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/Dialect/Complex/IR/ComplexBase.td" //===----------------------------------------------------------------------===// // Complex Attributes. 
//===----------------------------------------------------------------------===// -class Complex_Attr - : AttrDef { +class Complex_Attr traits = []> + : AttrDef { let mnemonic = attrMnemonic; } -def Complex_NumberAttr : Complex_Attr<"Number", "number"> { +def Complex_NumberAttr : Complex_Attr<"Number", "number", + [TypedAttrInterface]> { let summary = "A complex number attribute"; let description = [{ diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -139,7 +139,7 @@ ``` }]; - let arguments = (ins AnyAttr:$value); + let arguments = (ins TypedAttrInterface:$value); let results = (outs AnyType); let hasFolder = 1; @@ -212,7 +212,7 @@ ``` }]; - let arguments = (ins AnyAttr:$value); + let arguments = (ins TypedAttrInterface:$value); let results = (outs AnyType); let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td @@ -14,18 +14,19 @@ #define MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/Dialect/EmitC/IR/EmitCBase.td" //===----------------------------------------------------------------------===// // EmitC attribute definitions //===----------------------------------------------------------------------===// -class EmitC_Attr - : AttrDef { +class EmitC_Attr traits = []> + : AttrDef { let mnemonic = attrMnemonic; } -def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> { +def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque", [TypedAttrInterface]> { let summary = "An opaque attribute"; let description = [{ @@ -40,8 +41,9 @@ ``` }]; - let parameters = (ins StringRefParameter<"the opaque value">:$value); - + let parameters = 
(ins "Type":$type, + StringRefParameter<"the opaque value">:$value); + let hasCustomAssemblyFormat = 1; } diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_ #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" //===----------------------------------------------------------------------===// // Tablegen Attribute Declarations diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td @@ -10,6 +10,7 @@ #define MLPROGRAM_ATTRIBUTES include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/Dialect/MLProgram/IR/MLProgramBase.td" // Base class for MLProgram dialect attributes. 
@@ -22,7 +23,7 @@ // ExternAttr //===----------------------------------------------------------------------===// -def MLProgram_ExternAttr : MLProgram_Attr<"Extern"> { +def MLProgram_ExternAttr : MLProgram_Attr<"Extern", [TypedAttrInterface]> { let summary = "Value used for a global signalling external resolution"; let description = [{ When used as the value for a GlobalOp, this indicates that the actual diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -16,6 +16,7 @@ #define MLIR_DIALECT_SPIRV_IR_STRUCTURE_OPS include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" @@ -600,7 +601,7 @@ let arguments = (ins StrAttr:$sym_name, - AnyAttr:$default_value + TypedAttrInterface:$default_value ); let results = (outs); diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -257,14 +257,6 @@ let convertFromStorage = "$_self.cast<" # dialect.cppNamespace # "::" # cppClassName # ">()"; - // A code block used to build the value 'Type' of an Attribute when - // initializing its storage instance. This field is optional, and if not - // present the attribute will have its value type set to `NoneType`. This code - // block may reference any of the attributes parameters via - // `$_()">; @@ -334,7 +326,7 @@ // which by default is the C++ equality operator. The current MLIR context is // made available through `$_ctxt`, e.g., for constructing default values for // attributes and types. 
- string defaultValue = ?; + string defaultValue = ""; } class AttrParameter : AttrOrTypeParameter; @@ -392,11 +384,21 @@ let defaultValue = value; } -// This is a special parameter used for AttrDefs that represents a `mlir::Type` -// that is also used as the value `Type` of the attribute. Only one parameter -// of the attribute may be of this type. +// This is a special attribute parameter that represents the "self" type of the +// attribute. It is specially handled by the assembly format generator to derive +// its value from the optional trailing type after each attribute. +// +// By default, the self type parameter is optional and has a default value of +// `none`. If a derived type other than `::mlir::Type` is specified, the +// parameter loses its default value unless another one is specified by +// `typeBuilder`. class AttributeSelfTypeParameter : - AttrOrTypeParameter {} + string derivedType = "::mlir::Type", + string typeBuilder = ""> : + AttrOrTypeParameter { + let defaultValue = !if(!and(!empty(typeBuilder), + !eq(derivedType, "::mlir::Type")), + "::mlir::NoneType::get($_ctxt)", typeBuilder); +} #endif // ATTRTYPEBASE_TD diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -129,9 +129,6 @@ friend StorageUniquer; public: - /// Get the type of this attribute. - Type getType() const { return type; } - /// Return the abstract descriptor for this attribute. const AbstractAttribute &getAbstractAttribute() const { assert(abstractAttribute && "Malformed attribute storage object."); @@ -139,15 +136,6 @@ } protected: - /// Construct a new attribute storage instance with the given type. - /// Note: All attributes require a valid type. If no type is provided here, - /// the type of the attribute will automatically default to NoneType - /// upon initialization in the uniquer. 
- AttributeStorage(Type type = nullptr) : type(type) {} - - /// Set the type of this attribute. - void setType(Type newType) { type = newType; } - /// Set the abstract attribute for this storage instance. This is used by the /// AttributeUniquer when initializing a newly constructed storage object. void initializeAbstractAttribute(const AbstractAttribute &abstractAttr) { @@ -159,9 +147,6 @@ void initialize(MLIRContext *context) {} private: - /// The type of the attribute value. - Type type; - /// The abstract descriptor for this attribute. const AbstractAttribute *abstractAttribute = nullptr; }; diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -66,9 +66,6 @@ /// to support dynamic type casting. TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); } - /// Return the type of this attribute. - Type getType() const { return impl->getType(); } - /// Return the context this attribute belongs to. 
MLIRContext *getContext() const; diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h @@ -11,6 +11,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Types.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/Any.h" @@ -18,7 +19,6 @@ #include namespace mlir { -class ShapedType; //===----------------------------------------------------------------------===// // ElementsAttr @@ -237,10 +237,10 @@ public: using reference = typename IteratorT::reference; - ElementsAttrRange(Type shapeType, + ElementsAttrRange(ShapedType shapeType, const llvm::iterator_range &range) : llvm::iterator_range(range), shapeType(shapeType) {} - ElementsAttrRange(Type shapeType, IteratorT beginIt, IteratorT endIt) + ElementsAttrRange(ShapedType shapeType, IteratorT beginIt, IteratorT endIt) : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {} /// Return the value at the given index. @@ -254,7 +254,7 @@ private: /// The shaped type of the parent ElementsAttr. - Type shapeType; + ShapedType shapeType; }; } // namespace detail diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -154,7 +154,10 @@ }], "bool", "isSplat", (ins), /*defaultImplementation=*/[{}], [{ // By default, only check for a single element splat. return $_attr.getNumElements() == 1; - }]> + }]>, + InterfaceMethod<[{ + Returns the shaped type of the elements attribute. 
+ }], "::mlir::ShapedType", "getType"> ]; string ElementsAttrInterfaceAccessors = [{ @@ -280,7 +283,7 @@ auto getValues() const { auto beginIt = $_attr.template value_begin(); return detail::ElementsAttrRange( - Attribute($_attr).getType(), beginIt, std::next(beginIt, size())); + $_attr.getType(), beginIt, std::next(beginIt, size())); } }] # ElementsAttrInterfaceAccessors; @@ -294,19 +297,17 @@ // Accessors //===------------------------------------------------------------------===// - /// Return the type of this attribute. - ShapedType getType() const; - /// Return the element type of this ElementsAttr. Type getElementType() const { return getElementType(*this); } - static Type getElementType(Attribute elementsAttr); + static Type getElementType(ElementsAttr elementsAttr); /// Return if the given 'index' refers to a valid element in this attribute. bool isValidIndex(ArrayRef index) const { return isValidIndex(*this, index); } static bool isValidIndex(ShapedType type, ArrayRef index); - static bool isValidIndex(Attribute elementsAttr, ArrayRef index); + static bool isValidIndex(ElementsAttr elementsAttr, + ArrayRef index); /// Return the 1 dimensional flattened row-major index from the given /// multi-dimensional index. @@ -315,14 +316,14 @@ } static uint64_t getFlattenedIndex(Type type, ArrayRef index); - static uint64_t getFlattenedIndex(Attribute elementsAttr, + static uint64_t getFlattenedIndex(ElementsAttr elementsAttr, ArrayRef index) { return getFlattenedIndex(elementsAttr.getType(), index); } /// Returns the number of elements held by this attribute. int64_t getNumElements() const { return getNumElements(*this); } - static int64_t getNumElements(Attribute elementsAttr); + static int64_t getNumElements(ElementsAttr elementsAttr); //===------------------------------------------------------------------===// // Value Iteration @@ -349,7 +350,7 @@ /// Return the elements of this attribute as a value of type 'T'. 
template DefaultValueCheckT> getValues() const { - return {Attribute::getType(), value_begin(), value_end()}; + return {getType(), value_begin(), value_end()}; } template DefaultValueCheckT> value_begin() const; @@ -369,8 +370,8 @@ template > DerivedAttrValueIteratorRange getValues() const { auto castFn = [](Attribute attr) { return attr.template cast(); }; - return {Attribute::getType(), llvm::map_range(getValues(), - static_cast(castFn))}; + return {getType(), llvm::map_range(getValues(), + static_cast(castFn))}; } template > DerivedAttrValueIterator value_begin() const { @@ -388,10 +389,8 @@ /// return the iterable range. Otherwise, return llvm::None. template DefaultValueCheckT>> tryGetValues() const { - if (Optional> beginIt = try_value_begin()) { - return iterator_range(Attribute::getType(), *beginIt, - value_end()); - } + if (Optional> beginIt = try_value_begin()) + return iterator_range(getType(), *beginIt, value_end()); return llvm::None; } template @@ -407,7 +406,7 @@ auto castFn = [](Attribute attr) { return attr.template cast(); }; return DerivedAttrValueIteratorRange( - Attribute::getType(), + getType(), llvm::map_range(*values, static_cast(castFn)) ); } @@ -468,4 +467,23 @@ ]; } +//===----------------------------------------------------------------------===// +// TypedAttrInterface +//===----------------------------------------------------------------------===// + +def TypedAttrInterface : AttrInterface<"TypedAttr"> { + let cppNamespace = "::mlir"; + + let description = [{ + This interface is used for attributes that have a type. The type of an + attribute is understood to represent the type of the data contained in the + attribute and is often used as the type of a value with this data. 
+ }]; + + let methods = [InterfaceMethod< + "Get the attribute's type", + "::mlir::Type", "getType" + >]; +} + #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_ diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -25,7 +25,6 @@ class IntegerType; class Location; class Operation; -class ShapedType; //===----------------------------------------------------------------------===// // Elements Attributes @@ -402,7 +401,7 @@ std::numeric_limits::is_signed)); const char *rawData = getRawData().data(); bool splat = isSplat(); - return {Attribute::getType(), ElementIterator(rawData, splat, 0), + return {getType(), ElementIterator(rawData, splat, 0), ElementIterator(rawData, splat, getNumElements())}; } template > @@ -431,7 +430,7 @@ std::numeric_limits::is_signed)); const char *rawData = getRawData().data(); bool splat = isSplat(); - return {Attribute::getType(), ElementIterator(rawData, splat, 0), + return {getType(), ElementIterator(rawData, splat, 0), ElementIterator(rawData, splat, getNumElements())}; } template (stringRefs.data()); bool splat = isSplat(); - return {Attribute::getType(), ElementIterator(ptr, splat, 0), + return {getType(), ElementIterator(ptr, splat, 0), ElementIterator(ptr, splat, getNumElements())}; } template > @@ -478,8 +477,7 @@ typename std::enable_if::value>::type; template > iterator_range_impl getValues() const { - return {Attribute::getType(), value_begin(), - value_end()}; + return {getType(), value_begin(), value_end()}; } template > AttributeElementIterator value_begin() const { @@ -510,7 +508,7 @@ template > iterator_range_impl> getValues() const { using DerivedIterT = DerivedAttributeElementIterator; - return {Attribute::getType(), DerivedIterT(value_begin()), + return {getType(), DerivedIterT(value_begin()), DerivedIterT(value_end())}; } template > @@ -530,7 +528,7 @@ template > iterator_range_impl 
getValues() const { assert(isValidBool() && "bool is not the value of this elements attribute"); - return {Attribute::getType(), BoolElementIterator(*this, 0), + return {getType(), BoolElementIterator(*this, 0), BoolElementIterator(*this, getNumElements())}; } template > @@ -552,7 +550,7 @@ template > iterator_range_impl getValues() const { assert(getElementType().isIntOrIndex() && "expected integral type"); - return {Attribute::getType(), raw_int_begin(), raw_int_end()}; + return {getType(), raw_int_begin(), raw_int_end()}; } template > IntElementIterator value_begin() const { @@ -991,8 +989,6 @@ } inline bool operator!=(StringRef lhs, StringAttr rhs) { return !(lhs == rhs); } -inline Type StringAttr::getType() const { return Attribute::getType(); } - } // namespace mlir //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -64,7 +64,6 @@ AffineMap getAffineMap() const { return getValue(); } }]; let skipDefaultBuilders = 1; - let typeBuilder = "IndexType::get($_value.getContext())"; } //===----------------------------------------------------------------------===// @@ -140,11 +139,11 @@ } //===----------------------------------------------------------------------===// -// DenseIntOrFPElementsAttr +// DenseArrayBaseAttr //===----------------------------------------------------------------------===// def Builtin_DenseArrayBase : Builtin_Attr< - "DenseArrayBase", [ElementsAttrInterface]> { + "DenseArrayBase", [ElementsAttrInterface, TypedAttrInterface]> { let summary = "A dense array of i8, i16, i32, i64, f32, or f64."; let description = [{ A dense array attribute is an attribute that represents a dense array of @@ -197,8 +196,12 @@ const float *value_begin_impl(OverloadToken) const; const double *value_begin_impl(OverloadToken) const; - /// 
Methods to support type inquiry through isa, cast, and dyn_cast. + /// Returns the shaped type, containing the number of elements in the array + /// and the array element type. + ShapedType getType() const; + /// Returns the element type. EltType getElementType() const; + /// Printer for the short form: will dispatch to the appropriate subclass. void print(AsmPrinter &printer) const; void print(raw_ostream &os) const; @@ -216,7 +219,8 @@ //===----------------------------------------------------------------------===// def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< - "DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr" + "DenseIntOrFPElements", [ElementsAttrInterface, TypedAttrInterface], + "DenseElementsAttr" > { let summary = "An Attribute containing a dense multi-dimensional array of " "integer or floating-point values"; @@ -355,7 +359,8 @@ //===----------------------------------------------------------------------===// def Builtin_DenseStringElementsAttr : Builtin_Attr< - "DenseStringElements", [ElementsAttrInterface], "DenseElementsAttr" + "DenseStringElements", [ElementsAttrInterface, TypedAttrInterface], + "DenseElementsAttr" > { let summary = "An Attribute containing a dense multi-dimensional array of " "strings"; @@ -523,7 +528,7 @@ // FloatAttr //===----------------------------------------------------------------------===// -def Builtin_FloatAttr : Builtin_Attr<"Float"> { +def Builtin_FloatAttr : Builtin_Attr<"Float", [TypedAttrInterface]> { let summary = "An Attribute containing a floating-point value"; let description = [{ Syntax: @@ -586,7 +591,7 @@ // IntegerAttr //===----------------------------------------------------------------------===// -def Builtin_IntegerAttr : Builtin_Attr<"Integer"> { +def Builtin_IntegerAttr : Builtin_Attr<"Integer", [TypedAttrInterface]> { let summary = "An Attribute containing a integer value"; let description = [{ Syntax: @@ -703,7 +708,7 @@ // OpaqueAttr 
//===----------------------------------------------------------------------===// -def Builtin_OpaqueAttr : Builtin_Attr<"Opaque"> { +def Builtin_OpaqueAttr : Builtin_Attr<"Opaque", [TypedAttrInterface]> { let summary = "An opaque representation of another Attribute"; let description = [{ Syntax: @@ -741,7 +746,7 @@ //===----------------------------------------------------------------------===// def Builtin_OpaqueElementsAttr : Builtin_Attr< - "OpaqueElements", [ElementsAttrInterface] + "OpaqueElements", [ElementsAttrInterface, TypedAttrInterface] > { let summary = "An opaque representation of a multi-dimensional array"; let description = [{ @@ -803,7 +808,7 @@ //===----------------------------------------------------------------------===// def Builtin_SparseElementsAttr : Builtin_Attr< - "SparseElements", [ElementsAttrInterface] + "SparseElements", [ElementsAttrInterface, TypedAttrInterface] > { let summary = "An opaque representation of a multi-dimensional array"; let description = [{ @@ -966,7 +971,7 @@ // StringAttr //===----------------------------------------------------------------------===// -def Builtin_StringAttr : Builtin_Attr<"String"> { +def Builtin_StringAttr : Builtin_Attr<"String", [TypedAttrInterface]> { let summary = "An Attribute containing a string"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h @@ -0,0 +1,14 @@ +//===- BuiltinTypeInterfaces.h - Builtin Type Interfaces --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_BUILTINTYPEINTERFACES_H +#define MLIR_IR_BUILTINTYPEINTERFACES_H + +#include "mlir/IR/BuiltinTypeInterfaces.h.inc" + +#endif // MLIR_IR_BUILTINTYPEINTERFACES_H diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -9,8 +9,9 @@ #ifndef MLIR_IR_BUILTINTYPES_H #define MLIR_IR_BUILTINTYPES_H -#include "BuiltinAttributeInterfaces.h" -#include "SubElementInterfaces.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/SubElementInterfaces.h" namespace llvm { class BitVector; @@ -21,8 +22,6 @@ // Tablegen Interface Declarations //===----------------------------------------------------------------------===// -#include "mlir/IR/BuiltinTypeInterfaces.h.inc" - namespace mlir { class AffineExpr; class AffineMap; diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -215,8 +215,9 @@ /// Parse an extended attribute. /// /// extended-attribute ::= (dialect-attribute | attribute-alias) -/// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` -/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? +/// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>` +/// (`:` type)? +/// | `#` alias-name pretty-dialect-sym-body? (`:` type)? /// attribute-alias ::= `#` alias-name /// Attribute Parser::parseExtendedAttr(Type type) { @@ -250,9 +251,10 @@ }); // Ensure that the attribute has the same type as requested. 
- if (attr && type && attr.getType() != type) { + auto typedAttr = attr.dyn_cast_or_null(); + if (type && typedAttr && typedAttr.getType() != type) { emitError("attribute type different than expected: expected ") - << type << ", but got " << attr.getType(); + << type << ", but got " << typedAttr.getType(); return nullptr; } return attr; diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -753,7 +753,10 @@ } MlirType mlirAttributeGetType(MlirAttribute attribute) { - return wrap(unwrap(attribute).getType()); + Attribute attr = unwrap(attribute); + if (auto typedAttr = attr.dyn_cast()) + return wrap(typedAttr.getType()); + return wrap(NoneType::get(attr.getContext())); } MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -395,8 +395,8 @@ return failure(); Attribute cstAttr = constOp.getValue(); - if (cstAttr.getType().isa()) - cstAttr = cstAttr.cast().getSplatValue(); + if (auto elementsAttr = cstAttr.dyn_cast()) + cstAttr = elementsAttr.getSplatValue(); Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -698,8 +698,8 @@ llvm::DenseMap &valueMapping) { assert(constantSupportsMMAMatrixType(op)); OpBuilder b(op); - Attribute splat = - op.getValue().cast().getSplatValue(); + auto splat = + op.getValue().cast().getSplatValue(); auto scalarConstant = b.create(op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); diff --git 
a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -128,7 +128,8 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { // The value's type must be the same as the provided type. - if (value.getType() != type) + auto typedAttr = value.dyn_cast(); + if (!typedAttr || typedAttr.getType() != type) return false; // Integer values must be signless. if (type.isa() && !type.cast().isSignless()) diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -30,11 +30,13 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) { if (auto arrAttr = value.dyn_cast()) { auto complexTy = type.dyn_cast(); - if (!complexTy) + if (!complexTy || arrAttr.size() != 2) return false; auto complexEltTy = complexTy.getElementType(); - return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy && - arrAttr[1].getType() == complexEltTy; + auto re = arrAttr[0].dyn_cast(); + auto im = arrAttr[1].dyn_cast(); + return re && im && re.getType() == complexEltTy && + im.getType() == complexEltTy; } return false; } @@ -48,11 +50,14 @@ } auto complexEltTy = getType().getElementType(); - if (complexEltTy != arrayAttr[0].getType() || - complexEltTy != arrayAttr[1].getType()) { + auto re = arrayAttr[0].dyn_cast(); + auto im = arrayAttr[1].dyn_cast(); + if (!re || !im) + return emitOpError("requires attribute's elements to be float attributes"); + if (complexEltTy != re.getType() || complexEltTy != im.getType()) { return emitOpError() - << "requires attribute's element types (" << arrayAttr[0].getType() - << ", " << arrayAttr[1].getType() + << "requires attribute's element types (" << re.getType() << ", " + << im.getType() << ") to match the element type of the op's return 
type (" << complexEltTy << ")"; } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -86,15 +86,17 @@ if (Optional argsAttr = getArgs()) { for (Attribute arg : *argsAttr) { - if (arg.getType().isa()) { - int64_t index = arg.cast().getInt(); + auto intAttr = arg.dyn_cast(); + if (intAttr && intAttr.getType().isa()) { + int64_t index = intAttr.getInt(); // Args with elements of type index must be in range // [0..operands.size). if ((index < 0) || (index >= static_cast(getNumOperands()))) return emitOpError("index argument is out of range"); // Args with elements of type ArrayAttr must have a type. - } else if (arg.isa() && arg.getType().isa()) { + } else if (arg.isa() /*&& arg.getType().isa()*/) { + // FIXME: Array attributes never have types return emitOpError("array argument has no type"); } } @@ -102,8 +104,7 @@ if (Optional templateArgsAttr = getTemplateArgs()) { for (Attribute tArg : *templateArgsAttr) { - if (!tArg.isa() && !tArg.isa() && - !tArg.isa() && !tArg.isa()) + if (!tArg.isa()) return emitOpError("template argument has invalid type"); } } @@ -117,7 +118,7 @@ /// The constant op requires that the attribute's type matches the return type. LogicalResult emitc::ConstantOp::verify() { - Attribute value = getValueAttr(); + TypedAttr value = getValueAttr(); Type type = getType(); if (!value.getType().isa() && type != value.getType()) return emitOpError() << "requires attribute's type (" << value.getType() @@ -171,7 +172,7 @@ /// The variable op requires that the attribute's type matches the return type. 
LogicalResult emitc::VariableOp::verify() { - Attribute value = getValueAttr(); + TypedAttr value = getValueAttr(); Type type = getType(); if (!value.getType().isa() && type != value.getType()) return emitOpError() << "requires attribute's type (" << value.getType() @@ -204,7 +205,9 @@ } if (parser.parseGreater()) return Attribute(); - return get(parser.getContext(), value); + + return get(parser.getContext(), + type ? type : NoneType::get(parser.getContext()), value); } void emitc::OpaqueAttr::print(AsmPrinter &printer) const { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2409,11 +2409,16 @@ } auto arrayAttr = getValue().dyn_cast(); - if (!arrayAttr || arrayAttr.size() != 2 || - arrayAttr[0].getType() != arrayAttr[1].getType()) { + if (!arrayAttr || arrayAttr.size() != 2) { return emitOpError() << "expected array attribute with two elements, " "representing a complex constant"; } + auto re = arrayAttr[0].dyn_cast(); + auto im = arrayAttr[1].dyn_cast(); + if (!re || !im || re.getType() != im.getType()) { + return emitOpError() + << "expected array attribute with two elements of the same type"; + } Type elementType = structType.getBody()[0]; if (!elementType diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -400,8 +400,10 @@ OpBuilder builder = getBuilder(); Location loc = builder.getUnknownLoc(); Attribute valueAttr = parseAttribute(value, builder.getContext()); - return builder.create(loc, valueAttr.getType(), - valueAttr); + Type type = NoneType::get(builder.getContext()); + if (auto typedAttr = valueAttr.dyn_cast()) + type = typedAttr.getType(); + return builder.create(loc, type, valueAttr); } Value index(int64_t dim) { diff --git 
a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -530,7 +530,11 @@ SmallVector paddingValues; for (auto const &it : llvm::zip(getPaddingValues(), target->getOperandTypes())) { - Attribute attr = std::get<0>(it); + auto attr = std::get<0>(it).dyn_cast(); + if (!attr) { + emitOpError("expects padding values to be typed attributes"); + return DiagnosedSilenceableFailure::definiteFailure(); + } Type elementType = getElementTypeOrSelf(std::get<1>(it)); // Try to parse string attributes to obtain an attribute of element type. if (auto stringAttr = attr.dyn_cast()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1509,14 +1509,14 @@ return failure(); for (OpOperand *opOperand : genericOp.getInputOperands()) { Operation *def = opOperand->get().getDefiningOp(); - Attribute constantAttr; + TypedAttr constantAttr; auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool { { DenseElementsAttr splatAttr; if (matchPattern(def, m_Constant(&splatAttr)) && splatAttr.isSplat() && splatAttr.getType().getElementType().isIntOrFloat()) { - constantAttr = splatAttr.getSplatValue(); + constantAttr = splatAttr.getSplatValue(); return true; } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -198,8 +198,11 @@ if (opOperand->getOperandNumber() >= paddingValues.size()) return failure(); Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()]; - Value paddingValue 
= b.create( - opToPad.getLoc(), paddingAttr.getType(), paddingAttr); + Type paddingType = b.getType(); + if (auto typedAttr = paddingAttr.dyn_cast()) + paddingType = typedAttr.getType(); + Value paddingValue = + b.create(opToPad.getLoc(), paddingType, paddingAttr); // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp. OpOperand *currOpOperand = opOperand; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1309,8 +1309,8 @@ // Check that the type of the initial value is compatible with the type of // the global variable. - if (initValue.isa()) { - Type initType = initValue.getType(); + if (auto elementsAttr = initValue.dyn_cast()) { + Type initType = elementsAttr.getType(); Type tensorType = getTensorTypeFromMemRefType(memrefType); if (initType != tensorType) return emitOpError("initial value expected to be of type ") diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -28,20 +28,15 @@ /// Returns the boolean value under the hood if the given `boolAttr` is a scalar /// or splat vector bool constant. 
-static Optional getScalarOrSplatBoolAttr(Attribute boolAttr) { - if (!boolAttr) +static Optional getScalarOrSplatBoolAttr(Attribute attr) { + if (!attr) return llvm::None; - auto type = boolAttr.getType(); - if (type.isInteger(1)) { - auto attr = boolAttr.cast(); - return attr.getValue(); - } - if (auto vecType = type.cast()) { - if (vecType.getElementType().isInteger(1)) - if (auto attr = boolAttr.dyn_cast()) - return attr.getSplatValue(); - } + if (auto boolAttr = attr.dyn_cast()) + return boolAttr.getValue(); + if (auto splatAttr = attr.dyn_cast()) + if (splatAttr.getElementType().isInteger(1)) + return splatAttr.getSplatValue(); return llvm::None; } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1803,7 +1803,9 @@ if (parser.parseAttribute(value, kValueAttrName, state.attributes)) return failure(); - Type type = value.getType(); + Type type = NoneType::get(parser.getContext()); + if (auto typedAttr = value.dyn_cast()) + type = typedAttr.getType(); if (type.isa()) { if (parser.parseColonType(type)) return failure(); @@ -1820,15 +1822,15 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType) { - auto valueType = value.getType(); - if (value.isa()) { + auto valueType = value.cast().getType(); if (valueType != opType) return op.emitOpError("result type (") << opType << ") does not match value type (" << valueType << ")"; return success(); } if (value.isa()) { + auto valueType = value.cast().getType(); if (valueType == opType) return success(); auto arrayType = opType.dyn_cast(); @@ -1873,7 +1875,7 @@ } return success(); } - return op.emitOpError("cannot have value of type ") << valueType; + return op.emitOpError("cannot have attribute: ") << value; } LogicalResult spirv::ConstantOp::verify() { diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp 
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1737,7 +1737,7 @@ if (!operands[0]) return {}; auto vectorType = getVectorType(); - if (operands[0].getType().isIntOrIndexOrFloat()) + if (operands[0].isa()) return DenseElementsAttr::get(vectorType, operands[0]); if (auto attr = operands[0].dyn_cast()) return DenseElementsAttr::get(vectorType, attr.getSplatValue()); @@ -1855,7 +1855,7 @@ if (!lhs || !rhs) return {}; - auto lhsType = lhs.getType().cast(); + auto lhsType = lhs.cast().getType().cast(); // Only support 1-D for now to avoid complicated n-D DenseElementsAttr // manipulation. if (lhsType.getRank() != 1) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1752,7 +1752,6 @@ if (succeeded(printAlias(attr))) return; - auto attrType = attr.getType(); if (!isa(attr.getDialect())) { printDialectAttribute(attr); } else if (auto opaqueAttr = attr.dyn_cast()) { @@ -1768,7 +1767,8 @@ os << '}'; } else if (auto intAttr = attr.dyn_cast()) { - if (attrType.isSignlessInteger(1)) { + Type intType = intAttr.getType(); + if (intType.isSignlessInteger(1)) { os << (intAttr.getValue().getBoolValue() ? "true" : "false"); // Boolean integer attributes always elides the type. @@ -1779,18 +1779,18 @@ // signless 1-bit values. Indexes, signed values, and multi-bit signless // values print as signed. bool isUnsigned = - attrType.isUnsignedInteger() || attrType.isSignlessInteger(1); + intType.isUnsignedInteger() || intType.isSignlessInteger(1); intAttr.getValue().print(os, !isUnsigned); // IntegerAttr elides the type if I64. - if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64)) + if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64)) return; } else if (auto floatAttr = attr.dyn_cast()) { printFloatValue(floatAttr.getValue(), os); // FloatAttr elides the type if F64. 
- if (typeElision == AttrTypeElision::May && attrType.isF64()) + if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64()) return; } else if (auto strAttr = attr.dyn_cast()) { @@ -1892,7 +1892,7 @@ os << "[:f64"; break; } - if (denseArrayAttr.getType().cast().getRank()) + if (denseArrayAttr.getType().getRank()) os << " "; denseArrayAttr.printWithoutBraces(os); os << "]"; @@ -1902,9 +1902,14 @@ llvm::report_fatal_error("Unknown builtin attribute"); } // Don't print the type if we must elide it, or if it is a None type. - if (typeElision != AttrTypeElision::Must && !attrType.isa()) { - os << " : "; - printType(attrType); + if (typeElision != AttrTypeElision::Must) { + if (auto typedAttr = attr.dyn_cast()) { + Type attrType = typedAttr.getType(); + if (!attrType.isa()) { + os << " : "; + printType(attrType); + } + } } } diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -43,9 +43,10 @@ /// An attribute representing a reference to a dense vector or tensor object. struct DenseElementsAttributeStorage : public AttributeStorage { public: - DenseElementsAttributeStorage(ShapedType ty, bool isSplat) - : AttributeStorage(ty), isSplat(isSplat) {} + DenseElementsAttributeStorage(ShapedType type, bool isSplat) + : type(type), isSplat(isSplat) {} + ShapedType type; bool isSplat; }; @@ -75,7 +76,7 @@ /// Compare this storage instance with the provided key. bool operator==(const KeyTy &key) const { - if (key.type != getType()) + if (key.type != type) return false; // For boolean splats we need to explicitly check that the first bit is the @@ -228,7 +229,7 @@ /// Compare this storage instance with the provided key. bool operator==(const KeyTy &key) const { - if (key.type != getType()) + if (key.type != type) return false; // Otherwise, we can default to just checking the data. 
StringRefs compare @@ -324,12 +325,12 @@ struct StringAttrStorage : public AttributeStorage { StringAttrStorage(StringRef value, Type type) - : AttributeStorage(type), value(value), referencedDialect(nullptr) {} + : type(type), value(value), referencedDialect(nullptr) {} /// The hash key is a tuple of the parameter types. using KeyTy = std::pair; bool operator==(const KeyTy &key) const { - return value == key.first && getType() == key.second; + return value == key.first && type == key.second; } static ::llvm::hash_code hashKey(const KeyTy &key) { return DenseMapInfo::getHashValue(key); @@ -346,6 +347,8 @@ /// Initialize the storage given an MLIRContext. void initialize(MLIRContext *context); + /// The type of the string. + Type type; /// The raw string value. StringRef value; /// If the string value contains a dialect namespace prefix (e.g. diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -24,16 +24,12 @@ // ElementsAttr //===----------------------------------------------------------------------===// -ShapedType ElementsAttr::getType() const { - return Attribute::getType().cast(); +Type ElementsAttr::getElementType(ElementsAttr elementsAttr) { + return elementsAttr.getType().getElementType(); } -Type ElementsAttr::getElementType(Attribute elementsAttr) { - return elementsAttr.getType().cast().getElementType(); -} - -int64_t ElementsAttr::getNumElements(Attribute elementsAttr) { - return elementsAttr.getType().cast().getNumElements(); +int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) { + return elementsAttr.getType().getNumElements(); } bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef index) { @@ -51,9 +47,9 @@ return 0 <= dim && dim < shape[i]; }); } -bool ElementsAttr::isValidIndex(Attribute elementsAttr, +bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr, ArrayRef index) { - return 
isValidIndex(elementsAttr.getType().cast(), index); + return isValidIndex(elementsAttr.getType(), index); } uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef index) { diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -261,6 +261,8 @@ StringRef StringAttr::getValue() const { return getImpl()->value; } +Type StringAttr::getType() const { return getImpl()->type; } + Dialect *StringAttr::getReferencedDialect() const { return getImpl()->referencedDialect; } @@ -688,29 +690,28 @@ /// Custom storage to ensure proper memory alignment for the allocation of /// DenseArray of any element type. struct mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage { - using KeyTy = std::tuple>; + using KeyTy = + std::tuple>; DenseArrayBaseAttrStorage(ShapedType type, DenseArrayBaseAttr::EltType eltType, - ::llvm::ArrayRef elements) - : AttributeStorage(type), eltType(eltType), elements(elements) {} + ArrayRef elements) + : type(type), eltType(eltType), elements(elements) {} - bool operator==(const KeyTy &tblgenKey) const { - return (getType() == std::get<0>(tblgenKey)) && - (eltType == std::get<1>(tblgenKey)) && - (elements == std::get<2>(tblgenKey)); + bool operator==(const KeyTy &key) const { + return (type == std::get<0>(key)) && (eltType == std::get<1>(key)) && + (elements == std::get<2>(key)); } - static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { - return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey), - std::get<2>(tblgenKey)); + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_combine(std::get<0>(key), std::get<1>(key), + std::get<2>(key)); } static DenseArrayBaseAttrStorage * - construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) { - auto type = std::get<0>(tblgenKey); - auto eltType = std::get<1>(tblgenKey); - auto elements = std::get<2>(tblgenKey); + 
construct(AttributeStorageAllocator &allocator, const KeyTy &key) { + auto type = std::get<0>(key); + auto eltType = std::get<1>(key); + auto elements = std::get<2>(key); if (!elements.empty()) { char *alloc = static_cast( allocator.allocate(elements.size(), alignof(uint64_t))); @@ -721,14 +722,17 @@ DenseArrayBaseAttrStorage(type, eltType, elements); } + ShapedType type; DenseArrayBaseAttr::EltType eltType; - ::llvm::ArrayRef elements; + ArrayRef elements; }; DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const { return getImpl()->eltType; } +ShapedType DenseArrayBaseAttr::getType() const { return getImpl()->type; } + const int8_t * DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { return cast().asArrayRef().begin(); @@ -974,8 +978,8 @@ // If the element type is not based on int/float/index, assume it is a string // type. - auto eltType = type.getElementType(); - if (!type.getElementType().isIntOrIndexOrFloat()) { + Type eltType = type.getElementType(); + if (!eltType.isIntOrIndexOrFloat()) { SmallVector stringValues; stringValues.reserve(values.size()); for (Attribute attr : values) { @@ -995,14 +999,16 @@ llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT)); APInt intVal; for (unsigned i = 0, e = values.size(); i < e; ++i) { - assert(eltType == values[i].getType() && - "expected attribute value to have element type"); - if (eltType.isa()) - intVal = values[i].cast().getValue().bitcastToAPInt(); - else if (eltType.isa()) - intVal = values[i].cast().getValue(); - else - llvm_unreachable("unexpected element type"); + if (auto floatAttr = values[i].dyn_cast()) { + assert(floatAttr.getType() == eltType && + "expected float attribute type to equal element type"); + intVal = floatAttr.getValue().bitcastToAPInt(); + } else { + auto intAttr = values[i].cast(); + assert(intAttr.getType() == eltType && + "expected integer attribute type to equal element type"); + intVal = intAttr.getValue(); + } assert(intVal.getBitWidth() == bitWidth && 
"expected value to have same bitwidth as element type"); @@ -1010,7 +1016,7 @@ } // Handle the special encoding of splat of bool. - if (values.size() == 1 && values[0].getType().isInteger(1)) + if (values.size() == 1 && eltType.isInteger(1)) data[0] = data[0] ? -1 : 0; return DenseIntOrFPElementsAttr::getRaw(type, data); @@ -1326,7 +1332,7 @@ } ShapedType DenseElementsAttr::getType() const { - return Attribute::getType().cast(); + return static_cast(impl)->type; } Type DenseElementsAttr::getElementType() const { @@ -1546,8 +1552,9 @@ /// Method for supporting type inquiry through isa, cast and dyn_cast. bool DenseFPElementsAttr::classof(Attribute attr) { - return attr.isa() && - attr.getType().cast().getElementType().isa(); + if (auto denseAttr = attr.dyn_cast()) + return denseAttr.getType().getElementType().isa(); + return false; } //===----------------------------------------------------------------------===// @@ -1564,8 +1571,9 @@ /// Method for supporting type inquiry through isa, cast and dyn_cast. bool DenseIntElementsAttr::classof(Attribute attr) { - return attr.isa() && - attr.getType().cast().getElementType().isIntOrIndex(); + if (auto denseAttr = attr.dyn_cast()) + return denseAttr.getType().getElementType().isIntOrIndex(); + return false; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -896,10 +896,6 @@ MLIRContext *ctx, TypeID attrID) { storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx)); - - // If the attribute did not provide a type, then default to NoneType. 
- if (!storage->getType()) - storage->setType(NoneType::get(ctx)); } BoolAttr BoolAttr::get(MLIRContext *context, bool value) { diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -32,7 +32,9 @@ } Type mlir::getElementTypeOrSelf(Attribute attr) { - return getElementTypeOrSelf(attr.getType()); + if (auto typedAttr = attr.dyn_cast()) + return getElementTypeOrSelf(typedAttr.getType()); + return {}; } SmallVector mlir::getFlattenedTypes(TupleType t) { diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -1652,7 +1652,9 @@ LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); unsigned memIndex = read(); Attribute attr = read(); - Type type = attr ? attr.getType() : Type(); + Type type; + if (auto typedAttr = attr.dyn_cast()) + type = typedAttr.getType(); LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" << " * Result: " << type << "\n"); diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -283,7 +283,8 @@ } Optional AttrOrTypeParameter::getDefaultValue() const { - return getDefValue("defaultValue"); + Optional result = getDefValue("defaultValue"); + return result && !result->empty() ? 
result : llvm::None; } llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -823,7 +823,7 @@ if (auto type = attr.dyn_cast()) return emitType(loc, type.getValue()); - return emitError(loc, "cannot emit attribute of type ") << attr.getType(); + return emitError(loc, "cannot emit attribute: ") << attr; } LogicalResult CppEmitter::emitOperands(Operation &op) { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -769,7 +769,8 @@ // Process the type for this bool literal uint32_t typeID = 0; - if (failed(processType(loc, boolAttr.getType(), typeID))) { + if (failed( + processType(loc, boolAttr.cast().getType(), typeID))) { return 0; } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -332,7 +332,7 @@ // ----- llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> { - // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}} + // expected-error @+1 {{expected array attribute with two elements of the same type}} %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)> llvm.return %0 : !llvm.struct<(f64, f64)> } @@ -547,7 +547,7 @@ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) { // expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}} - %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, 
%c2, %c3, %c4, %c5, %c6, %c7] {layoutA=#nvvm.mma_layout, layoutB=#nvvm.mma_layout, shape = #nvvm.shape} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } @@ -571,7 +571,7 @@ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) { // expected-error@+1 {{op requires attribute 'layoutA'}} - %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7] {shape = #nvvm.shape}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } @@ -594,7 +594,7 @@ // expected-error@+1 {{op requires b1Op attribute}} %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, - multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, + multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> } diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -79,7 +79,7 @@ // ----- func.func @unaccepted_std_attr() -> () { - // expected-error @+1 {{cannot have value of type 'none'}} + // expected-error @+1 {{cannot have attribute: unit}} %0 = spv.Constant unit : none return } diff --git a/mlir/test/IR/file-metadata-resources.mlir b/mlir/test/IR/file-metadata-resources.mlir --- a/mlir/test/IR/file-metadata-resources.mlir +++ b/mlir/test/IR/file-metadata-resources.mlir @@ -5,7 +5,7 @@ // CHECK-NEXT: blob1: "0x08000000010000000000000002000000000000000300000000000000" 
// CHECK-NEXT: } -module attributes { test.blob_ref = #test.e1di64_elements } {} +module attributes { test.blob_ref = #test.e1di64_elements : tensor<*xi1>} {} {-# dialect_resources: { diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -53,18 +53,22 @@ } // An attribute testing AttributeSelfTypeParameter. -def AttrWithSelfTypeParam : Test_Attr<"AttrWithSelfTypeParam"> { +def AttrWithSelfTypeParam + : Test_Attr<"AttrWithSelfTypeParam", [TypedAttrInterface]> { let mnemonic = "attr_with_self_type_param"; let parameters = (ins AttributeSelfTypeParameter<"">:$type); let assemblyFormat = ""; } // An attribute testing AttributeSelfTypeParameter. -def AttrWithTypeBuilder : Test_Attr<"AttrWithTypeBuilder"> { +def AttrWithTypeBuilder + : Test_Attr<"AttrWithTypeBuilder", [TypedAttrInterface]> { let mnemonic = "attr_with_type_builder"; - let parameters = (ins "::mlir::IntegerAttr":$attr); - let typeBuilder = "$_attr.getType()"; - let hasCustomAssemblyFormat = 1; + let parameters = (ins + "::mlir::IntegerAttr":$attr, + AttributeSelfTypeParameter<"", "mlir::Type", "$attr.getType()">:$type + ); + let assemblyFormat = "$attr"; } def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">; @@ -76,7 +80,7 @@ // Test support for ElementsAttrInterface. def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ - ElementsAttrInterface + ElementsAttrInterface, TypedAttrInterface ]> { let mnemonic = "i64_elements"; let parameters = (ins @@ -215,7 +219,7 @@ // Test self type parameter with assembly format. 
def TestAttrSelfTypeParameterFormat - : Test_Attr<"TestAttrSelfTypeParameterFormat"> { + : Test_Attr<"TestAttrSelfTypeParameterFormat", [TypedAttrInterface]> { let parameters = (ins "int":$a, AttributeSelfTypeParameter<"">:$type); let mnemonic = "attr_self_type_format"; @@ -237,7 +241,7 @@ // Test simple extern 1D vector using ElementsAttrInterface. def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [ - ElementsAttrInterface + ElementsAttrInterface, TypedAttrInterface ]> { let mnemonic = "e1di64_elements"; let parameters = (ins diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -27,21 +27,6 @@ using namespace mlir; using namespace test; -//===----------------------------------------------------------------------===// -// AttrWithTypeBuilderAttr -//===----------------------------------------------------------------------===// - -Attribute AttrWithTypeBuilderAttr::parse(AsmParser &parser, Type type) { - IntegerAttr element; - if (parser.parseAttribute(element)) - return Attribute(); - return get(parser.getContext(), element); -} - -void AttrWithTypeBuilderAttr::print(AsmPrinter &printer) const { - printer << " " << getAttr(); -} - //===----------------------------------------------------------------------===// // CompoundAAttr //===----------------------------------------------------------------------===// @@ -114,10 +99,11 @@ return success(); } -LogicalResult TestAttrWithFormatAttr::verify( - function_ref emitError, int64_t one, std::string two, - IntegerAttr three, ArrayRef four, - ArrayRef arrayOfAttrWithTypeBuilderAttr) { +LogicalResult +TestAttrWithFormatAttr::verify(function_ref emitError, + int64_t one, std::string two, IntegerAttr three, + ArrayRef four, + ArrayRef arrayOfAttrs) { if (four.size() != static_cast(one)) return emitError() << "expected 'one' to equal 'four.size()'"; 
return success(); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -554,7 +554,7 @@ def ResultHasSameTypeAsAttr : TEST_Op<"result_has_same_type_as_attr", [AllTypesMatch<["attr", "result"]>]> { - let arguments = (ins AnyAttr:$attr); + let arguments = (ins TypedAttrInterface:$attr); let results = (outs AnyType:$result); let assemblyFormat = "$attr `->` type($result) attr-dict"; } @@ -2310,7 +2310,7 @@ def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [ AllTypesMatch<["value1", "value2", "result"]> ]> { - let arguments = (ins AnyAttr:$value1, AnyType:$value2); + let arguments = (ins TypedAttrInterface:$value1, AnyType:$value2); let results = (outs AnyType:$result); let assemblyFormat = "attr-dict $value1 `,` $value2"; } @@ -2338,7 +2338,7 @@ def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [ TypesMatchWith<"result type matches constant", "value", "result", "$_self"> ]> { - let arguments = (ins AnyAttr:$value); + let arguments = (ins TypedAttrInterface:$value); let results = (outs AnyType:$result); let assemblyFormat = "attr-dict $value"; } diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -164,9 +164,13 @@ /// Test attribute with self type parameter -// ATTR: TestGAttr::parse -// ATTR: return TestGAttr::get -// ATTR: odsType +// ATTR-LABEL: Attribute TestGAttr::parse +// ATTR: if (odsType) +// ATTR: if (auto reqType = odsType.dyn_cast<::mlir::Type>()) +// ATTR: _result_type = reqType +// ATTR: TestGAttr::get +// ATTR-NEXT: *_result_a +// ATTR-NEXT: _result_type.value_or(::mlir::NoneType::get( def AttrD : TestAttr<"TestG"> { let parameters = (ins "int":$a, AttributeSelfTypeParameter<"">:$type); let mnemonic = "attr_d"; diff --git 
a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -77,11 +77,12 @@ // DECL: int getWidthOfSomething() const; // DECL: ::test::SimpleTypeA getExampleTdType() const; // DECL: ::llvm::APFloat getApFloat() const; +// DECL: ::mlir::Type getInner() const; // Check that AttributeSelfTypeParameter is handled properly. // DEF-LABEL: struct CompoundAAttrStorage // DEF: CompoundAAttrStorage( -// DEF-SAME: : ::mlir::AttributeStorage(inner), +// DEF-SAME: inner(inner) // DEF: bool operator==(const KeyTy &tblgenKey) const { // DEF-NEXT: return @@ -89,14 +90,14 @@ // DEF-SAME: (exampleTdType == std::get<1>(tblgenKey)) && // DEF-SAME: (apFloat.bitwiseIsEqual(std::get<2>(tblgenKey))) && // DEF-SAME: (dims == std::get<3>(tblgenKey)) && -// DEF-SAME: (getType() == std::get<4>(tblgenKey)); +// DEF-SAME: (inner == std::get<4>(tblgenKey)); // DEF: static CompoundAAttrStorage *construct // DEF: return new (allocator.allocate()) // DEF-SAME: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner); // DEF: ::mlir::Type CompoundAAttr::getInner() const { -// DEF-NEXT: return getImpl()->getType().cast<::mlir::Type>(); +// DEF-NEXT: return getImpl()->inner; } def C_IndexAttr : TestAttr<"Index"> { @@ -127,18 +128,6 @@ // DECL-SAME: detail::SingleParameterAttrStorage } -// An attribute testing AttributeSelfTypeParameter. 
-def E_AttrWithTypeBuilder : TestAttr<"AttrWithTypeBuilder"> { - let mnemonic = "attr_with_type_builder"; - let parameters = (ins "::mlir::IntegerAttr":$attr); - let typeBuilder = "$_attr.getType()"; - let hasCustomAssemblyFormat = 1; -} - -// DEF-LABEL: struct AttrWithTypeBuilderAttrStorage -// DEF: AttrWithTypeBuilderAttrStorage(::mlir::IntegerAttr attr) -// DEF-SAME: : ::mlir::AttributeStorage(attr.getType()), attr(attr) - def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> { let parameters = (ins AttrParameter<"std::string", "", "StringRef">:$param); } diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -68,7 +68,7 @@ // CHECK-LABEL: OpE definitions // CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) -// CHECK: odsState.addTypes({attr.getValue().getType()}); +// CHECK: odsState.addTypes({attr.getValue().cast<::mlir::TypedAttr>().getType()}); def OpF : NS_Op<"one_variadic_result_op", []> { let results = (outs Variadic:$x); @@ -155,5 +155,5 @@ // CHECK-LABEL: LogicalResult OpL3::inferReturnTypes // CHECK-NOT: } -// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").getType(); +// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").cast<::mlir::TypedAttr>().getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -295,11 +295,7 @@ // class. Otherwise, let the user define the exact accessor definition. 
if (!def.genStorageClass()) continue; - auto scope = m->body().indent().scope("return getImpl()->", ";"); - if (isa(param)) - m->body() << formatv("getType().cast<{0}>()", param.getCppType()); - else - m->body() << param.getName(); + m->body().indent() << "return getImpl()->" << param.getName() << ";"; } } @@ -450,37 +446,8 @@ void DefGen::emitStorageConstructor() { Constructor *ctor = storageCls->addConstructor(getBuilderParams({})); - if (auto *attrDef = dyn_cast(&def)) { - // For attributes, a parameter marked with AttributeSelfTypeParameter is - // the type initializer that must be passed to the parent constructor. - const auto isSelfType = [](const AttrOrTypeParameter ¶m) { - return isa(param); - }; - auto *selfTypeParam = llvm::find_if(params, isSelfType); - if (std::count_if(selfTypeParam, params.end(), isSelfType) > 1) { - PrintFatalError(def.getLoc(), - "Only one attribute parameter can be marked as " - "AttributeSelfTypeParameter"); - } - // Alternatively, if a type builder was specified, use that instead. - std::string attrStorageInit = - selfTypeParam == params.end() ? "" : selfTypeParam->getName().str(); - if (attrDef->getTypeBuilder()) { - FmtContext ctx; - for (auto ¶m : params) - ctx.addSubst(strfmt("_{0}", param.getName()), param.getName()); - attrStorageInit = tgfmt(*attrDef->getTypeBuilder(), &ctx); - } - ctor->addMemberInitializer("::mlir::AttributeStorage", - std::move(attrStorageInit)); - // Initialize members that aren't the attribute's type. 
- for (auto ¶m : params) - if (selfTypeParam == params.end() || *selfTypeParam != param) - ctor->addMemberInitializer(param.getName(), param.getName()); - } else { - for (auto ¶m : params) - ctor->addMemberInitializer(param.getName(), param.getName()); - } + for (auto ¶m : params) + ctor->addMemberInitializer(param.getName(), param.getName()); } void DefGen::emitKeyType() { @@ -498,9 +465,7 @@ auto &body = eq->body().indent(); auto scope = body.scope("return (", ");"); const auto eachFn = [&](auto it) { - FmtContext ctx({{"_lhs", isa(it.value()) - ? "getType()" - : it.value().getName()}, + FmtContext ctx({{"_lhs", it.value().getName()}, {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}}); body << tgfmt(it.value().getComparator(), &ctx); }; @@ -566,8 +531,7 @@ // Emit the storage class members as public, at the very end of the struct. storageCls->finalize(); for (auto ¶m : params) - if (!isa(param)) - storageCls->declare(param.getCppType(), param.getName()); + storageCls->declare(param.getCppType(), param.getName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -246,6 +246,39 @@ // ParserGen //===----------------------------------------------------------------------===// +/// Generate a special-case "parser" for an attribute's self type parameter. The +/// self type parameter has special handling in the assembly format in that it +/// is derived from the optional trailing colon type after the attribute. +static void genAttrSelfTypeParser(MethodBody &os, const FmtContext &ctx, + const AttributeSelfTypeParameter ¶m) { + // "Parser" for an attribute self type parameter that checks the + // optionally-parsed trailing colon type. + // + // $0: The C++ storage class of the type parameter. 
+ // $1: The self type parameter name. + const char *const selfTypeParser = R"( +if ($_type) { + if (auto reqType = $_type.dyn_cast<$0>()) { + _result_$1 = reqType; + } else { + $_parser.emitError($_loc, "invalid kind of type specified"); + return {}; + } +})"; + + // If the attribute self type parameter is required, emit code that emits an + // error if the trailing type was not parsed. + const char *const selfTypeRequired = R"( else { + $_parser.emitError($_loc, "expected a trailing type"); + return {}; +})"; + + os << tgfmt(selfTypeParser, &ctx, param.getCppStorageType(), param.getName()); + if (!param.isOptional()) + os << tgfmt(selfTypeRequired, &ctx); + os << "\n"; +} + void DefFormat::genParser(MethodBody &os) { FmtContext ctx; ctx.addSubst("_parser", "odsParser"); @@ -262,8 +295,6 @@ // a loop (parsers return FailureOr anyways). ArrayRef params = def.getParameters(); for (const AttrOrTypeParameter ¶m : params) { - if (isa(param)) - continue; os << formatv("::mlir::FailureOr<{0}> _result_{1};\n", param.getCppStorageType(), param.getName()); } @@ -281,7 +312,9 @@ // Emit an assert for each mandatory parameter. Triggering an assert means // the generated parser is incorrect (i.e. there is a bug in this code). 
for (const AttrOrTypeParameter ¶m : params) { - if (param.isOptional() || isa(param)) + if (auto *selfTypeParam = dyn_cast(¶m)) + genAttrSelfTypeParser(os, ctx, *selfTypeParam); + if (param.isOptional()) continue; os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName()); } @@ -306,11 +339,10 @@ else selfOs << param.getCppStorageType() << "()"; selfOs << "))"; - } else if (isa(param)) { - selfOs << tgfmt("$_type", &ctx); } else { selfOs << formatv("(*_result_{0})", param.getName()); } + ctx.addSubst(param.getName(), selfOs.str()); os << param.getCppType() << "(" << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str())) << ")"; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -578,7 +578,8 @@ // Populate substitutions for attributes. auto &op = emitHelper.getOp(); for (const auto &namedAttr : op.getAttributes()) - ctx.addSubst(namedAttr.name, emitHelper.getAttr(namedAttr.name).str()); + ctx.addSubst(namedAttr.name, + emitHelper.getOp().getGetterName(namedAttr.name) + "()"); // Populate substitutions for named operands. 
for (int i = 0, e = op.getNumOperands(); i < e; ++i) { @@ -1756,7 +1757,7 @@ if (namedAttr.attr.isTypeAttr()) { resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()"; } else { - resultType = "attr.getValue().getType()"; + resultType = "attr.getValue().cast<::mlir::TypedAttr>().getType()"; } // Operands @@ -2416,7 +2417,8 @@ } else { auto *attr = op.getArg(arg.operandOrAttributeIndex()).get(); - body << "attributes.get(\"" << attr->name << "\").getType()"; + body << "attributes.get(\"" << attr->name + << "\").cast<::mlir::TypedAttr>().getType()"; } body << ";\n"; } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -237,16 +237,19 @@ // Only index (0, 0) contains an element, others are supposed to return // the zero/empty value. - auto zeroIntValue = sparseInt.getValues()[{1, 1}]; - EXPECT_EQ(zeroIntValue.cast().getInt(), 0); + auto zeroIntValue = + sparseInt.getValues()[{1, 1}].cast(); + EXPECT_EQ(zeroIntValue.getInt(), 0); EXPECT_TRUE(zeroIntValue.getType() == intTy); - auto zeroFloatValue = sparseFloat.getValues()[{1, 1}]; - EXPECT_EQ(zeroFloatValue.cast().getValueAsDouble(), 0.0f); + auto zeroFloatValue = + sparseFloat.getValues()[{1, 1}].cast(); + EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); EXPECT_TRUE(zeroFloatValue.getType() == floatTy); - auto zeroStringValue = sparseString.getValues()[{1, 1}]; - EXPECT_TRUE(zeroStringValue.cast().getValue().empty()); + auto zeroStringValue = + sparseString.getValues()[{1, 1}].cast(); + EXPECT_TRUE(zeroStringValue.getValue().empty()); EXPECT_TRUE(zeroStringValue.getType() == stringTy); }