diff --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md --- a/mlir/docs/CAPI.md +++ b/mlir/docs/CAPI.md @@ -185,12 +185,12 @@ ### Extensions for Dialect Attributes and Types -Dialect attributes and types can follow the example of standard attributes and +Dialect attributes and types can follow the example of builtin attributes and types, provided that implementations live in separate directories, i.e. `include/mlir-c/<...>Dialect/` and `lib/CAPI/<...>Dialect/`. The core APIs provide implementation-private headers in `include/mlir/CAPI/IR` that allow one to convert between opaque C structures for core IR components and their C++ counterparts. `wrap` converts a C++ class into a C structure and `unwrap` does -the inverse conversion. Once the C++ object is available, the API -implementation should rely on `isa` to implement `mlirXIsAY` and is expected to -use `cast` inside other API calls. +the inverse conversion. Once the C++ object is available, the API implementation +should rely on `isa` to implement `mlirXIsAY` and is expected to use `cast` +inside other API calls. diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -1337,7 +1337,7 @@ Attribute values are represented by the following forms: ``` -attribute-value ::= attribute-alias | dialect-attribute | standard-attribute +attribute-value ::= attribute-alias | dialect-attribute | builtin-attribute ``` ### Attribute Value Aliases @@ -1401,17 +1401,17 @@ that are not allowed in the lighter syntax, as well as unbalanced `<>` characters. -See [here](Tutorials/DefiningAttributesAndTypes.md) to learn how to define dialect -attribute values. +See [here](Tutorials/DefiningAttributesAndTypes.md) to learn how to define +dialect attribute values. 
-### Standard Attribute Values +### Builtin Attribute Values -Standard attributes are a core set of +Builtin attributes are a core set of [dialect attributes](#dialect-attribute-values) that are defined in a builtin dialect and thus available to all users of MLIR. ``` -standard-attribute ::= affine-map-attribute +builtin-attribute ::= affine-map-attribute | array-attribute | bool-attribute | dictionary-attribute diff --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h rename from mlir/include/mlir-c/StandardAttributes.h rename to mlir/include/mlir-c/BuiltinAttributes.h --- a/mlir/include/mlir-c/StandardAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -1,4 +1,4 @@ -//===-- mlir-c/StandardAttributes.h - C API for Std Attributes-----*- C -*-===// +//===-- mlir-c/BuiltinAttributes.h - C API for Builtin Attributes -*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM // Exceptions. @@ -7,12 +7,12 @@ // //===----------------------------------------------------------------------===// // -// This header declares the C interface to MLIR Standard attributes. +// This header declares the C interface to MLIR Builtin attributes. // //===----------------------------------------------------------------------===// -#ifndef MLIR_C_STANDARDATTRIBUTES_H -#define MLIR_C_STANDARDATTRIBUTES_H +#ifndef MLIR_C_BUILTINATTRIBUTES_H +#define MLIR_C_BUILTINATTRIBUTES_H #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" @@ -45,9 +45,8 @@ /** Creates an array element containing the given list of elements in the given * context. */ -MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGet(MlirContext ctx, - intptr_t numElements, - MlirAttribute const *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGet( + MlirContext ctx, intptr_t numElements, MlirAttribute const *elements); /// Returns the number of elements stored in the given array attribute. 
MLIR_CAPI_EXPORTED intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr); @@ -446,4 +445,4 @@ } #endif -#endif // MLIR_C_STANDARDATTRIBUTES_H +#endif // MLIR_C_BUILTINATTRIBUTES_H diff --git a/mlir/include/mlir/Bindings/Python/Attributes.td b/mlir/include/mlir/Bindings/Python/Attributes.td --- a/mlir/include/mlir/Bindings/Python/Attributes.td +++ b/mlir/include/mlir/Bindings/Python/Attributes.td @@ -15,13 +15,13 @@ #define PYTHON_BINDINGS_ATTRIBUTES // A mapping between the attribute storage type and the corresponding Python -// type. There is not necessarily a 1-1 match for non-standard attributes. +// type. There is not necessarily a 1-1 match for non-builtin attributes. class PythonAttr<string c, string p> { string cppStorageType = c; string pythonType = p; } -// Mappings between supported standard attribtues and Python types. +// Mappings between supported builtin attributes and Python types. def : PythonAttr<"::mlir::Attribute", "_ir.Attribute">; def : PythonAttr<"::mlir::BoolAttr", "_ir.BoolAttr">; def : PythonAttr<"::mlir::IntegerAttr", "_ir.IntegerAttr">; diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -15,7 +15,7 @@ #ifndef MLIR_DIALECT_COMMONFOLDERS_H #define MLIR_DIALECT_COMMONFOLDERS_H -#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h b/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h --- a/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h +++ b/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h @@ -14,7 +14,7 @@ #ifndef MLIR_DIALECT_GPU_PARALLELLOOPMAPPER_H #define MLIR_DIALECT_GPU_PARALLELLOOPMAPPER_H -#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMap.h"
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h @@ -14,7 +14,7 @@ #define MLIR_DIALECT_SPIRV_SPIRVATTRIBUTES_H #include "mlir/Dialect/SPIRV/SPIRVTypes.h" -#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" // Pull in SPIR-V attribute definitions for target and ABI. diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -18,7 +18,7 @@ #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H #include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -10,47 +10,16 @@ #define MLIR_IR_ATTRIBUTES_H #include "mlir/IR/AttributeSupport.h" -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/Sequence.h" #include "llvm/Support/PointerLikeTypeTraits.h" -#include namespace mlir { -class AffineMap; -class Dialect; -class FunctionType; class Identifier; -class IntegerSet; -class Location; -class MLIRContext; -class ShapedType; -class Type; -namespace detail { - -struct AffineMapAttributeStorage; -struct ArrayAttributeStorage; -struct DictionaryAttributeStorage; -struct IntegerAttributeStorage; -struct IntegerSetAttributeStorage; -struct FloatAttributeStorage; -struct OpaqueAttributeStorage; -struct StringAttributeStorage; -struct SymbolRefAttributeStorage; -struct TypeAttributeStorage; - -/// Elements Attributes. 
-struct DenseIntOrFPElementsAttributeStorage; -struct DenseStringElementsAttributeStorage; -struct OpaqueElementsAttributeStorage; -struct SparseElementsAttributeStorage; -} // namespace detail - -/// Attributes are known-constant values of operations and functions. +/// Attributes are known-constant values of operations. /// /// Instances of the Attribute class are references to immortal key-value pairs -/// with immutable, uniqued key owned by MLIRContext. As such, an Attribute is a -/// thin wrapper around an underlying storage pointer. Attributes are usually +/// with immutable, uniqued keys owned by MLIRContext. As such, an Attribute is +/// a thin wrapper around an underlying storage pointer. Attributes are usually /// passed by value. class Attribute { public: @@ -126,1469 +95,6 @@ return os; } -//===----------------------------------------------------------------------===// -// AttributeTraitBase -//===----------------------------------------------------------------------===// - -namespace AttributeTrait { -/// This class represents the base of an attribute trait. -template class TraitType> -using TraitBase = detail::StorageUserTraitBase; -} // namespace AttributeTrait - -//===----------------------------------------------------------------------===// -// AttributeInterface -//===----------------------------------------------------------------------===// - -/// This class represents the base of an attribute interface. See the definition -/// of `detail::Interface` for requirements on the `Traits` type. -template -class AttributeInterface - : public detail::Interface { -public: - using Base = AttributeInterface; - using InterfaceBase = detail::Interface; - using InterfaceBase::InterfaceBase; - -private: - /// Returns the impl interface instance for the given type. - static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) { - return attr.getAbstractAttribute().getInterface(); - } - - /// Allow access to 'getInterfaceFor'. 
- friend InterfaceBase; -}; - -//===----------------------------------------------------------------------===// -// AffineMapAttr -//===----------------------------------------------------------------------===// - -class AffineMapAttr - : public Attribute::AttrBase { -public: - using Base::Base; - using ValueType = AffineMap; - - static AffineMapAttr get(AffineMap value); - - AffineMap getValue() const; -}; - -//===----------------------------------------------------------------------===// -// ArrayAttr -//===----------------------------------------------------------------------===// - -/// Array attributes are lists of other attributes. They are not necessarily -/// type homogenous given that attributes don't, in general, carry types. -class ArrayAttr : public Attribute::AttrBase { -public: - using Base::Base; - using ValueType = ArrayRef; - - static ArrayAttr get(ArrayRef value, MLIRContext *context); - - ArrayRef getValue() const; - Attribute operator[](unsigned idx) const; - - /// Support range iteration. - using iterator = llvm::ArrayRef::iterator; - iterator begin() const { return getValue().begin(); } - iterator end() const { return getValue().end(); } - size_t size() const { return getValue().size(); } - bool empty() const { return size() == 0; } - -private: - /// Class for underlying value iterator support. 
- template - class attr_value_iterator final - : public llvm::mapped_iterator { - public: - explicit attr_value_iterator(ArrayAttr::iterator it) - : llvm::mapped_iterator( - it, [](Attribute attr) { return attr.cast(); }) {} - AttrTy operator*() const { return (*this->I).template cast(); } - }; - -public: - template - iterator_range> getAsRange() { - return llvm::make_range(attr_value_iterator(begin()), - attr_value_iterator(end())); - } - template - auto getAsValueRange() { - return llvm::map_range(getAsRange(), [](AttrTy attr) { - return static_cast(attr.getValue()); - }); - } -}; - -//===----------------------------------------------------------------------===// -// DictionaryAttr -//===----------------------------------------------------------------------===// - -/// NamedAttribute is used for dictionary attributes, it holds an identifier for -/// the name and a value for the attribute. The attribute pointer should always -/// be non-null. -using NamedAttribute = std::pair; - -bool operator<(const NamedAttribute &lhs, const NamedAttribute &rhs); -bool operator<(const NamedAttribute &lhs, StringRef rhs); - -/// Dictionary attribute is an attribute that represents a sorted collection of -/// named attribute values. The elements are sorted by name, and each name must -/// be unique within the collection. -class DictionaryAttr - : public Attribute::AttrBase { -public: - using Base::Base; - using ValueType = ArrayRef; - - /// Construct a dictionary attribute with the provided list of named - /// attributes. This method assumes that the provided list is unordered. If - /// the caller can guarantee that the attributes are ordered by name, - /// getWithSorted should be used instead. - static DictionaryAttr get(ArrayRef value, - MLIRContext *context); - - /// Construct a dictionary with an array of values that is known to already be - /// sorted by name and uniqued. 
- static DictionaryAttr getWithSorted(ArrayRef value, - MLIRContext *context); - - ArrayRef getValue() const; - - /// Return the specified attribute if present, null otherwise. - Attribute get(StringRef name) const; - Attribute get(Identifier name) const; - - /// Return the specified named attribute if present, None otherwise. - Optional getNamed(StringRef name) const; - Optional getNamed(Identifier name) const; - - /// Support range iteration. - using iterator = llvm::ArrayRef::iterator; - iterator begin() const; - iterator end() const; - bool empty() const { return size() == 0; } - size_t size() const; - - /// Sorts the NamedAttributes in the array ordered by name as expected by - /// getWithSorted and returns whether the values were sorted. - /// Requires: uniquely named attributes. - static bool sort(ArrayRef values, - SmallVectorImpl &storage); - - /// Sorts the NamedAttributes in the array ordered by name as expected by - /// getWithSorted in place on an array and returns whether the values needed - /// to be sorted. - /// Requires: uniquely named attributes. - static bool sortInPlace(SmallVectorImpl &array); - - /// Returns an entry with a duplicate name in `array`, if it exists, else - /// returns llvm::None. If `isSorted` is true, the array is assumed to be - /// sorted else it will be sorted in place before finding the duplicate entry. - static Optional - findDuplicate(SmallVectorImpl &array, bool isSorted); - -private: - /// Return empty dictionary. - static DictionaryAttr getEmpty(MLIRContext *context); -}; - -//===----------------------------------------------------------------------===// -// FloatAttr -//===----------------------------------------------------------------------===// - -class FloatAttr : public Attribute::AttrBase { -public: - using Base::Base; - using ValueType = APFloat; - - /// Return a float attribute for the specified value in the specified type. 
- /// These methods should only be used for simple constant values, e.g 1.0/2.0, - /// that are known-valid both as host double and the 'type' format. - static FloatAttr get(Type type, double value); - static FloatAttr getChecked(Type type, double value, Location loc); - - /// Return a float attribute for the specified value in the specified type. - static FloatAttr get(Type type, const APFloat &value); - static FloatAttr getChecked(Type type, const APFloat &value, Location loc); - - APFloat getValue() const; - - /// This function is used to convert the value to a double, even if it loses - /// precision. - double getValueAsDouble() const; - static double getValueAsDouble(APFloat val); - - /// Verify the construction invariants for a double value. - static LogicalResult verifyConstructionInvariants(Location loc, Type type, - double value); - static LogicalResult verifyConstructionInvariants(Location loc, Type type, - const APFloat &value); -}; - -//===----------------------------------------------------------------------===// -// IntegerAttr -//===----------------------------------------------------------------------===// - -class IntegerAttr - : public Attribute::AttrBase { -public: - using Base::Base; - using ValueType = APInt; - - static IntegerAttr get(Type type, int64_t value); - static IntegerAttr get(Type type, const APInt &value); - - APInt getValue() const; - /// Return the integer value as a 64-bit int. The attribute must be a signless - /// integer. - // TODO: Change callers to use getValue instead. - int64_t getInt() const; - /// Return the integer value as a signed 64-bit int. The attribute must be - /// a signed integer. - int64_t getSInt() const; - /// Return the integer value as a unsigned 64-bit int. The attribute must be - /// an unsigned integer. 
- uint64_t getUInt() const; - - static LogicalResult verifyConstructionInvariants(Location loc, Type type, - int64_t value); - static LogicalResult verifyConstructionInvariants(Location loc, Type type, - const APInt &value); -}; - -//===----------------------------------------------------------------------===// -// BoolAttr - -/// Special case of IntegerAttr to represent boolean integers, i.e., signless i1 -/// integers. -class BoolAttr : public Attribute { -public: - using Attribute::Attribute; - using ValueType = bool; - - static BoolAttr get(bool value, MLIRContext *context); - - /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to - /// avoid bringing in all of IntegerAttrs methods. - operator IntegerAttr() const { return IntegerAttr(impl); } - - /// Return the boolean value of this attribute. - bool getValue() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(Attribute attr); -}; - -//===----------------------------------------------------------------------===// -// IntegerSetAttr -//===----------------------------------------------------------------------===// - -class IntegerSetAttr - : public Attribute::AttrBase { -public: - using Base::Base; - using ValueType = IntegerSet; - - static IntegerSetAttr get(IntegerSet value); - - IntegerSet getValue() const; -}; - -//===----------------------------------------------------------------------===// -// OpaqueAttr -//===----------------------------------------------------------------------===// - -/// Opaque attributes represent attributes of non-registered dialects. These are -/// attribute represented in their raw string form, and can only usefully be -/// tested for attribute equality. -class OpaqueAttr : public Attribute::AttrBase { -public: - using Base::Base; - - /// Get or create a new OpaqueAttr with the provided dialect and string data. 
- static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type, - MLIRContext *context); - - /// Get or create a new OpaqueAttr with the provided dialect and string data. - /// If the given identifier is not a valid namespace for a dialect, then a - /// null attribute is returned. - static OpaqueAttr getChecked(Identifier dialect, StringRef attrData, - Type type, Location location); - - /// Returns the dialect namespace of the opaque attribute. - Identifier getDialectNamespace() const; - - /// Returns the raw attribute data of the opaque attribute. - StringRef getAttrData() const; - - /// Verify the construction of an opaque attribute. - static LogicalResult verifyConstructionInvariants(Location loc, - Identifier dialect, - StringRef attrData, - Type type); -}; - -//===----------------------------------------------------------------------===// -// StringAttr -//===----------------------------------------------------------------------===// - -class StringAttr : public Attribute::AttrBase { -public: - using Base::Base; - using ValueType = StringRef; - - /// Get an instance of a StringAttr with the given string. - static StringAttr get(StringRef bytes, MLIRContext *context); - - /// Get an instance of a StringAttr with the given string and Type. - static StringAttr get(StringRef bytes, Type type); - - StringRef getValue() const; -}; - -//===----------------------------------------------------------------------===// -// SymbolRefAttr -//===----------------------------------------------------------------------===// - -class FlatSymbolRefAttr; - -/// A symbol reference attribute represents a symbolic reference to another -/// operation. -class SymbolRefAttr - : public Attribute::AttrBase { -public: - using Base::Base; - - /// Construct a symbol reference for the given value name. 
- static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx); - - /// Construct a symbol reference for the given value name, and a set of nested - /// references that are further resolve to a nested symbol. - static SymbolRefAttr get(StringRef value, - ArrayRef references, - MLIRContext *ctx); - - /// Returns the name of the top level symbol reference, i.e. the root of the - /// reference path. - StringRef getRootReference() const; - - /// Returns the name of the fully resolved symbol, i.e. the leaf of the - /// reference path. - StringRef getLeafReference() const; - - /// Returns the set of nested references representing the path to the symbol - /// nested under the root reference. - ArrayRef getNestedReferences() const; -}; - -/// A symbol reference with a reference path containing a single element. This -/// is used to refer to an operation within the current symbol table. -class FlatSymbolRefAttr : public SymbolRefAttr { -public: - using SymbolRefAttr::SymbolRefAttr; - using ValueType = StringRef; - - /// Construct a symbol reference for the given value name. - static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) { - return SymbolRefAttr::get(value, ctx); - } - - /// Returns the name of the held symbol reference. - StringRef getValue() const { return getRootReference(); } - - /// Methods for support type inquiry through isa, cast, and dyn_cast. 
- static bool classof(Attribute attr) { - SymbolRefAttr refAttr = attr.dyn_cast(); - return refAttr && refAttr.getNestedReferences().empty(); - } - -private: - using SymbolRefAttr::get; - using SymbolRefAttr::getNestedReferences; -}; - -//===----------------------------------------------------------------------===// -// Type -//===----------------------------------------------------------------------===// - -class TypeAttr : public Attribute::AttrBase { -public: - using Base::Base; - using ValueType = Type; - - static TypeAttr get(Type value); - - Type getValue() const; -}; - -//===----------------------------------------------------------------------===// -// UnitAttr -//===----------------------------------------------------------------------===// - -/// Unit attributes are attributes that hold no specific value and are given -/// meaning by their existence. -class UnitAttr - : public Attribute::AttrBase { -public: - using Base::Base; - - static UnitAttr get(MLIRContext *context); -}; - -//===----------------------------------------------------------------------===// -// Elements Attributes -//===----------------------------------------------------------------------===// - -namespace detail { -template class ElementsAttrIterator; -template class ElementsAttrRange; -} // namespace detail - -/// A base attribute that represents a reference to a static shaped tensor or -/// vector constant. -class ElementsAttr : public Attribute { -public: - using Attribute::Attribute; - template using iterator = detail::ElementsAttrIterator; - template using iterator_range = detail::ElementsAttrRange; - - /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor - /// with static shape. - ShapedType getType() const; - - /// Return the value at the given index. The index is expected to refer to a - /// valid element. 
- Attribute getValue(ArrayRef index) const; - - /// Return the value of type 'T' at the given index, where 'T' corresponds to - /// an Attribute type. - template T getValue(ArrayRef index) const { - return getValue(index).template cast(); - } - - /// Return the elements of this attribute as a value of type 'T'. Note: - /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support - /// iteration. - template iterator_range getValues() const; - - /// Return if the given 'index' refers to a valid element in this attribute. - bool isValidIndex(ArrayRef index) const; - - /// Returns the number of elements held by this attribute. - int64_t getNumElements() const; - - /// Returns the number of elements held by this attribute. - int64_t size() const { return getNumElements(); } - - /// Generates a new ElementsAttr by mapping each int value to a new - /// underlying APInt. The new values can represent either an integer or float. - /// This ElementsAttr should contain integers. - ElementsAttr mapValues(Type newElementType, - function_ref mapping) const; - - /// Generates a new ElementsAttr by mapping each float value to a new - /// underlying APInt. The new values can represent either an integer or float. - /// This ElementsAttr should contain floats. - ElementsAttr mapValues(Type newElementType, - function_ref mapping) const; - - /// Method for support type inquiry through isa, cast and dyn_cast. - static bool classof(Attribute attr); - -protected: - /// Returns the 1 dimensional flattened row-major index from the given - /// multi-dimensional index. - uint64_t getFlattenedIndex(ArrayRef index) const; -}; - -namespace detail { -/// DenseElementsAttr data is aligned to uint64_t, so this traits class is -/// necessary to interop with PointerIntPair. 
-class DenseElementDataPointerTypeTraits { -public: - static inline const void *getAsVoidPointer(const char *ptr) { return ptr; } - static inline const char *getFromVoidPointer(const void *ptr) { - return static_cast(ptr); - } - - // Note: We could steal more bits if the need arises. - static constexpr int NumLowBitsAvailable = 1; -}; - -/// Pair of raw pointer and a boolean flag of whether the pointer holds a splat, -using DenseIterPtrAndSplat = - llvm::PointerIntPair; - -/// Impl iterator for indexed DenseElementsAttr iterators that records a data -/// pointer and data index that is adjusted for the case of a splat attribute. -template -class DenseElementIndexedIteratorImpl - : public llvm::indexed_accessor_iterator { -protected: - DenseElementIndexedIteratorImpl(const char *data, bool isSplat, - size_t dataIndex) - : llvm::indexed_accessor_iterator({data, isSplat}, - dataIndex) {} - - /// Return the current index for this iterator, adjusted for the case of a - /// splat. - ptrdiff_t getDataIndex() const { - bool isSplat = this->base.getInt(); - return isSplat ? 0 : this->index; - } - - /// Return the data base pointer. - const char *getData() const { return this->base.getPointer(); } -}; - -/// Type trait detector that checks if a given type T is a complex type. -template struct is_complex_t : public std::false_type {}; -template -struct is_complex_t> : public std::true_type {}; -} // namespace detail - -/// An attribute that represents a reference to a dense vector or tensor object. -/// -class DenseElementsAttr : public ElementsAttr { -public: - using ElementsAttr::ElementsAttr; - - /// Type trait used to check if the given type T is a potentially valid C++ - /// floating point type that can be used to access the underlying element - /// types of a DenseElementsAttr. - // TODO: Use std::disjunction when C++17 is supported. 
- template struct is_valid_cpp_fp_type { - /// The type is a valid floating point type if it is a builtin floating - /// point type, or is a potentially user defined floating point type. The - /// latter allows for supporting users that have custom types defined for - /// bfloat16/half/etc. - static constexpr bool value = llvm::is_one_of::value || - (std::numeric_limits::is_specialized && - !std::numeric_limits::is_integer); - }; - - /// Method for support type inquiry through isa, cast and dyn_cast. - static bool classof(Attribute attr); - - /// Constructs a dense elements attribute from an array of element values. - /// Each element attribute value is expected to be an element of 'type'. - /// 'type' must be a vector or tensor with static shape. If the element of - /// `type` is non-integer/index/float it is assumed to be a string type. - static DenseElementsAttr get(ShapedType type, ArrayRef values); - - /// Constructs a dense integer elements attribute from an array of integer - /// or floating-point values. Each value is expected to be the same bitwidth - /// of the element type of 'type'. 'type' must be a vector or tensor with - /// static shape. - template ::is_integer || - is_valid_cpp_fp_type::value>::type> - static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { - const char *data = reinterpret_cast(values.data()); - return getRawIntOrFloat( - type, ArrayRef(data, values.size() * sizeof(T)), sizeof(T), - std::numeric_limits::is_integer, std::numeric_limits::is_signed); - } - - /// Constructs a dense integer elements attribute from a single element. - template ::is_integer || - is_valid_cpp_fp_type::value || - detail::is_complex_t::value>::type> - static DenseElementsAttr get(const ShapedType &type, T value) { - return get(type, llvm::makeArrayRef(value)); - } - - /// Constructs a dense complex elements attribute from an array of complex - /// values. Each value is expected to be the same bitwidth of the element type - /// of 'type'. 
'type' must be a vector or tensor with static shape. - template ::value && - (std::numeric_limits::is_integer || - is_valid_cpp_fp_type::value)>::type> - static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { - const char *data = reinterpret_cast(values.data()); - return getRawComplex(type, ArrayRef(data, values.size() * sizeof(T)), - sizeof(T), std::numeric_limits::is_integer, - std::numeric_limits::is_signed); - } - - /// Overload of the above 'get' method that is specialized for boolean values. - static DenseElementsAttr get(ShapedType type, ArrayRef values); - - /// Overload of the above 'get' method that is specialized for StringRef - /// values. - static DenseElementsAttr get(ShapedType type, ArrayRef values); - - /// Constructs a dense integer elements attribute from an array of APInt - /// values. Each APInt value is expected to have the same bitwidth as the - /// element type of 'type'. 'type' must be a vector or tensor with static - /// shape. - static DenseElementsAttr get(ShapedType type, ArrayRef values); - - /// Constructs a dense complex elements attribute from an array of APInt - /// values. Each APInt value is expected to have the same bitwidth as the - /// element type of 'type'. 'type' must be a vector or tensor with static - /// shape. - static DenseElementsAttr get(ShapedType type, - ArrayRef> values); - - /// Constructs a dense float elements attribute from an array of APFloat - /// values. Each APFloat value is expected to have the same bitwidth as the - /// element type of 'type'. 'type' must be a vector or tensor with static - /// shape. - static DenseElementsAttr get(ShapedType type, ArrayRef values); - - /// Constructs a dense complex elements attribute from an array of APFloat - /// values. Each APFloat value is expected to have the same bitwidth as the - /// element type of 'type'. 'type' must be a vector or tensor with static - /// shape. 
- static DenseElementsAttr get(ShapedType type, - ArrayRef> values); - - /// Construct a dense elements attribute for an initializer_list of values. - /// Each value is expected to be the same bitwidth of the element type of - /// 'type'. 'type' must be a vector or tensor with static shape. - template - static DenseElementsAttr get(const ShapedType &type, - const std::initializer_list &list) { - return get(type, ArrayRef(list)); - } - - /// Construct a dense elements attribute from a raw buffer representing the - /// data for this attribute. Users should generally not use this methods as - /// the expected buffer format may not be a form the user expects. - static DenseElementsAttr getFromRawBuffer(ShapedType type, - ArrayRef rawBuffer, - bool isSplatBuffer); - - /// Returns true if the given buffer is a valid raw buffer for the given type. - /// `detectedSplat` is set if the buffer is valid and represents a splat - /// buffer. - static bool isValidRawBuffer(ShapedType type, ArrayRef rawBuffer, - bool &detectedSplat); - - //===--------------------------------------------------------------------===// - // Iterators - //===--------------------------------------------------------------------===// - - /// A utility iterator that allows walking over the internal Attribute values - /// of a DenseElementsAttr. - class AttributeElementIterator - : public llvm::indexed_accessor_iterator { - public: - /// Accesses the Attribute value at this iterator position. - Attribute operator*() const; - - private: - friend DenseElementsAttr; - - /// Constructs a new iterator. - AttributeElementIterator(DenseElementsAttr attr, size_t index); - }; - - /// Iterator for walking raw element values of the specified type 'T', which - /// may be any c++ data type matching the stored representation: int32_t, - /// float, etc. - template - class ElementIterator - : public detail::DenseElementIndexedIteratorImpl, - const T> { - public: - /// Accesses the raw value at this iterator position. 
- const T &operator*() const { - return reinterpret_cast(this->getData())[this->getDataIndex()]; - } - - private: - friend DenseElementsAttr; - - /// Constructs a new iterator. - ElementIterator(const char *data, bool isSplat, size_t dataIndex) - : detail::DenseElementIndexedIteratorImpl, const T>( - data, isSplat, dataIndex) {} - }; - - /// A utility iterator that allows walking over the internal bool values. - class BoolElementIterator - : public detail::DenseElementIndexedIteratorImpl { - public: - /// Accesses the bool value at this iterator position. - bool operator*() const; - - private: - friend DenseElementsAttr; - - /// Constructs a new iterator. - BoolElementIterator(DenseElementsAttr attr, size_t dataIndex); - }; - - /// A utility iterator that allows walking over the internal raw APInt values. - class IntElementIterator - : public detail::DenseElementIndexedIteratorImpl { - public: - /// Accesses the raw APInt value at this iterator position. - APInt operator*() const; - - private: - friend DenseElementsAttr; - - /// Constructs a new iterator. - IntElementIterator(DenseElementsAttr attr, size_t dataIndex); - - /// The bitwidth of the element type. - size_t bitWidth; - }; - - /// A utility iterator that allows walking over the internal raw complex APInt - /// values. - class ComplexIntElementIterator - : public detail::DenseElementIndexedIteratorImpl< - ComplexIntElementIterator, std::complex, std::complex, - std::complex> { - public: - /// Accesses the raw std::complex value at this iterator position. - std::complex operator*() const; - - private: - friend DenseElementsAttr; - - /// Constructs a new iterator. - ComplexIntElementIterator(DenseElementsAttr attr, size_t dataIndex); - - /// The bitwidth of the element type. - size_t bitWidth; - }; - - /// Iterator for walking over APFloat values. 
- class FloatElementIterator final - : public llvm::mapped_iterator> { - friend DenseElementsAttr; - - /// Initializes the float element iterator to the specified iterator. - FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it); - - public: - using reference = APFloat; - }; - - /// Iterator for walking over complex APFloat values. - class ComplexFloatElementIterator final - : public llvm::mapped_iterator< - ComplexIntElementIterator, - std::function(const std::complex &)>> { - friend DenseElementsAttr; - - /// Initializes the float element iterator to the specified iterator. - ComplexFloatElementIterator(const llvm::fltSemantics &smt, - ComplexIntElementIterator it); - - public: - using reference = std::complex; - }; - - //===--------------------------------------------------------------------===// - // Value Querying - //===--------------------------------------------------------------------===// - - /// Returns true if this attribute corresponds to a splat, i.e. if all element - /// values are the same. - bool isSplat() const; - - /// Return the splat value for this attribute. This asserts that the attribute - /// corresponds to a splat. - Attribute getSplatValue() const { return getSplatValue(); } - template - typename std::enable_if::value || - std::is_same::value, - T>::type - getSplatValue() const { - assert(isSplat() && "expected the attribute to be a splat"); - return *getValues().begin(); - } - /// Return the splat value for derived attribute element types. - template - typename std::enable_if::value && - !std::is_same::value, - T>::type - getSplatValue() const { - return getSplatValue().template cast(); - } - - /// Return the value at the given index. The 'index' is expected to refer to a - /// valid element. - Attribute getValue(ArrayRef index) const { - return getValue(index); - } - template T getValue(ArrayRef index) const { - // Skip to the element corresponding to the flattened index. 
- return *std::next(getValues().begin(), getFlattenedIndex(index)); - } - - /// Return the held element values as a range of integer or floating-point - /// values. - template ::value && - std::numeric_limits::is_integer) || - is_valid_cpp_fp_type::value>::type> - llvm::iterator_range> getValues() const { - assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, - std::numeric_limits::is_signed)); - const char *rawData = getRawData().data(); - bool splat = isSplat(); - return {ElementIterator(rawData, splat, 0), - ElementIterator(rawData, splat, getNumElements())}; - } - - /// Return the held element values as a range of std::complex. - template ::value && - (std::numeric_limits::is_integer || - is_valid_cpp_fp_type::value)>::type> - llvm::iterator_range> getValues() const { - assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, - std::numeric_limits::is_signed)); - const char *rawData = getRawData().data(); - bool splat = isSplat(); - return {ElementIterator(rawData, splat, 0), - ElementIterator(rawData, splat, getNumElements())}; - } - - /// Return the held element values as a range of StringRef. - template ::value>::type> - llvm::iterator_range> getValues() const { - auto stringRefs = getRawStringData(); - const char *ptr = reinterpret_cast(stringRefs.data()); - bool splat = isSplat(); - return {ElementIterator(ptr, splat, 0), - ElementIterator(ptr, splat, getNumElements())}; - } - - /// Return the held element values as a range of Attributes. - llvm::iterator_range getAttributeValues() const; - template ::value>::type> - llvm::iterator_range getValues() const { - return getAttributeValues(); - } - AttributeElementIterator attr_value_begin() const; - AttributeElementIterator attr_value_end() const; - - /// Return the held element values a range of T, where T is a derived - /// attribute type. 
- template - using DerivedAttributeElementIterator = - llvm::mapped_iterator; - template ::value && - !std::is_same::value>::type> - llvm::iterator_range> getValues() const { - auto castFn = [](Attribute attr) { return attr.template cast(); }; - return llvm::map_range(getAttributeValues(), - static_cast(castFn)); - } - - /// Return the held element values as a range of bool. The element type of - /// this attribute must be of integer type of bitwidth 1. - llvm::iterator_range getBoolValues() const; - template ::value>::type> - llvm::iterator_range getValues() const { - return getBoolValues(); - } - - /// Return the held element values as a range of APInts. The element type of - /// this attribute must be of integer type. - llvm::iterator_range getIntValues() const; - template ::value>::type> - llvm::iterator_range getValues() const { - return getIntValues(); - } - IntElementIterator int_value_begin() const; - IntElementIterator int_value_end() const; - - /// Return the held element values as a range of complex APInts. The element - /// type of this attribute must be a complex of integer type. - llvm::iterator_range getComplexIntValues() const; - template >::value>::type> - llvm::iterator_range getValues() const { - return getComplexIntValues(); - } - - /// Return the held element values as a range of APFloat. The element type of - /// this attribute must be of float type. - llvm::iterator_range getFloatValues() const; - template ::value>::type> - llvm::iterator_range getValues() const { - return getFloatValues(); - } - FloatElementIterator float_value_begin() const; - FloatElementIterator float_value_end() const; - - /// Return the held element values as a range of complex APFloat. The element - /// type of this attribute must be a complex of float type. 
- llvm::iterator_range - getComplexFloatValues() const; - template >::value>::type> - llvm::iterator_range getValues() const { - return getComplexFloatValues(); - } - - /// Return the raw storage data held by this attribute. Users should generally - /// not use this directly, as the internal storage format is not always in the - /// form the user might expect. - ArrayRef getRawData() const; - - /// Return the raw StringRef data held by this attribute. - ArrayRef getRawStringData() const; - - //===--------------------------------------------------------------------===// - // Mutation Utilities - //===--------------------------------------------------------------------===// - - /// Return a new DenseElementsAttr that has the same data as the current - /// attribute, but has been reshaped to 'newType'. The new type must have the - /// same total number of elements as well as element type. - DenseElementsAttr reshape(ShapedType newType); - - /// Generates a new DenseElementsAttr by mapping each int value to a new - /// underlying APInt. The new values can represent either an integer or float. - /// This underlying type must be an DenseIntElementsAttr. - DenseElementsAttr mapValues(Type newElementType, - function_ref mapping) const; - - /// Generates a new DenseElementsAttr by mapping each float value to a new - /// underlying APInt. the new values can represent either an integer or float. - /// This underlying type must be an DenseFPElementsAttr. - DenseElementsAttr - mapValues(Type newElementType, - function_ref mapping) const; - -protected: - /// Get iterators to the raw APInt values for each element in this attribute. - IntElementIterator raw_int_begin() const { - return IntElementIterator(*this, 0); - } - IntElementIterator raw_int_end() const { - return IntElementIterator(*this, getNumElements()); - } - - /// Overload of the raw 'get' method that asserts that the given type is of - /// complex type. 
This method is used to verify type invariants that the - /// templatized 'get' method cannot. - static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef data, - int64_t dataEltSize, bool isInt, - bool isSigned); - - /// Overload of the raw 'get' method that asserts that the given type is of - /// integer or floating-point type. This method is used to verify type - /// invariants that the templatized 'get' method cannot. - static DenseElementsAttr getRawIntOrFloat(ShapedType type, - ArrayRef data, - int64_t dataEltSize, bool isInt, - bool isSigned); - - /// Check the information for a C++ data type, check if this type is valid for - /// the current attribute. This method is used to verify specific type - /// invariants that the templatized 'getValues' method cannot. - bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const; - - /// Check the information for a C++ data type, check if this type is valid for - /// the current attribute. This method is used to verify specific type - /// invariants that the templatized 'getValues' method cannot. - bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const; -}; - -/// An attribute class for representing dense arrays of strings. The structure -/// storing and querying a list of densely packed strings. -class DenseStringElementsAttr - : public Attribute::AttrBase { - -public: - using Base::Base; - - /// Overload of the raw 'get' method that asserts that the given type is of - /// integer or floating-point type. This method is used to verify type - /// invariants that the templatized 'get' method cannot. - static DenseStringElementsAttr get(ShapedType type, ArrayRef data); - -protected: - friend DenseElementsAttr; -}; - -/// An attribute class for specializing behavior of Int and Floating-point -/// densely packed string arrays. 
-class DenseIntOrFPElementsAttr - : public Attribute::AttrBase { - -public: - using Base::Base; - - /// Convert endianess of input ArrayRef for big-endian(BE) machines. All of - /// the elements of `inRawData` has `type`. If `inRawData` is little endian - /// (LE), it is converted to big endian (BE). Conversely, if `inRawData` is - /// BE, converted to LE. - static void - convertEndianOfArrayRefForBEmachine(ArrayRef inRawData, - MutableArrayRef outRawData, - ShapedType type); - - /// Convert endianess of input for big-endian(BE) machines. The number of - /// elements of `inRawData` is `numElements`, and each element has - /// `elementBitWidth` bits. If `inRawData` is little endian (LE), it is - /// converted to big endian (BE) and saved in `outRawData`. Conversely, if - /// `inRawData` is BE, converted to LE. - static void convertEndianOfCharForBEmachine(const char *inRawData, - char *outRawData, - size_t elementBitWidth, - size_t numElements); - -protected: - friend DenseElementsAttr; - - /// Constructs a dense elements attribute from an array of raw APFloat values. - /// Each APFloat value is expected to have the same bitwidth as the element - /// type of 'type'. 'type' must be a vector or tensor with static shape. - static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, - ArrayRef values, bool isSplat); - - /// Constructs a dense elements attribute from an array of raw APInt values. - /// Each APInt value is expected to have the same bitwidth as the element type - /// of 'type'. 'type' must be a vector or tensor with static shape. - static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, - ArrayRef values, bool isSplat); - - /// Get or create a new dense elements attribute instance with the given raw - /// data buffer. 'type' must be a vector or tensor with static shape. 
- static DenseElementsAttr getRaw(ShapedType type, ArrayRef data, - bool isSplat); - - /// Overload of the raw 'get' method that asserts that the given type is of - /// complex type. This method is used to verify type invariants that the - /// templatized 'get' method cannot. - static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef data, - int64_t dataEltSize, bool isInt, - bool isSigned); - - /// Overload of the raw 'get' method that asserts that the given type is of - /// integer or floating-point type. This method is used to verify type - /// invariants that the templatized 'get' method cannot. - static DenseElementsAttr getRawIntOrFloat(ShapedType type, - ArrayRef data, - int64_t dataEltSize, bool isInt, - bool isSigned); -}; - -/// An attribute that represents a reference to a dense float vector or tensor -/// object. Each element is stored as a double. -class DenseFPElementsAttr : public DenseIntOrFPElementsAttr { -public: - using iterator = DenseElementsAttr::FloatElementIterator; - - using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr; - - /// Get an instance of a DenseFPElementsAttr with the given arguments. This - /// simply wraps the DenseElementsAttr::get calls. - template - static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) { - return DenseElementsAttr::get(type, llvm::makeArrayRef(arg)) - .template cast(); - } - template - static DenseFPElementsAttr get(const ShapedType &type, - const std::initializer_list &list) { - return DenseElementsAttr::get(type, list) - .template cast(); - } - - /// Generates a new DenseElementsAttr by mapping each value attribute, and - /// constructing the DenseElementsAttr given the new element type. - DenseElementsAttr - mapValues(Type newElementType, - function_ref mapping) const; - - /// Iterator access to the float element values. 
- iterator begin() const { return float_value_begin(); } - iterator end() const { return float_value_end(); } - - /// Method for supporting type inquiry through isa, cast and dyn_cast. - static bool classof(Attribute attr); -}; - -/// An attribute that represents a reference to a dense integer vector or tensor -/// object. -class DenseIntElementsAttr : public DenseIntOrFPElementsAttr { -public: - /// DenseIntElementsAttr iterates on APInt, so we can use the raw element - /// iterator directly. - using iterator = DenseElementsAttr::IntElementIterator; - - using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr; - - /// Get an instance of a DenseIntElementsAttr with the given arguments. This - /// simply wraps the DenseElementsAttr::get calls. - template - static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) { - return DenseElementsAttr::get(type, llvm::makeArrayRef(arg)) - .template cast(); - } - template - static DenseIntElementsAttr get(const ShapedType &type, - const std::initializer_list &list) { - return DenseElementsAttr::get(type, list) - .template cast(); - } - - /// Generates a new DenseElementsAttr by mapping each value attribute, and - /// constructing the DenseElementsAttr given the new element type. - DenseElementsAttr mapValues(Type newElementType, - function_ref mapping) const; - - /// Iterator access to the integer element values. - iterator begin() const { return raw_int_begin(); } - iterator end() const { return raw_int_end(); } - - /// Method for supporting type inquiry through isa, cast and dyn_cast. - static bool classof(Attribute attr); -}; - -/// An opaque attribute that represents a reference to a vector or tensor -/// constant with opaque content. This representation is for tensor constants -/// which the compiler may not need to interpret. This attribute is always -/// associated with a particular dialect, which provides a method to convert -/// tensor representation to a non-opaque format. 
-class OpaqueElementsAttr - : public Attribute::AttrBase { -public: - using Base::Base; - using ValueType = StringRef; - - static OpaqueElementsAttr get(Dialect *dialect, ShapedType type, - StringRef bytes); - - StringRef getValue() const; - - /// Return the value at the given index. The 'index' is expected to refer to a - /// valid element. - Attribute getValue(ArrayRef index) const; - - /// Decodes the attribute value using dialect-specific decoding hook. - /// Returns false if decoding is successful. If not, returns true and leaves - /// 'result' argument unspecified. - bool decode(ElementsAttr &result); - - /// Returns dialect associated with this opaque constant. - Dialect *getDialect() const; -}; - -/// An attribute that represents a reference to a sparse vector or tensor -/// object. -/// -/// This class uses COO (coordinate list) encoding to represent the sparse -/// elements in an element attribute. Specifically, the sparse vector/tensor -/// stores the indices and values as two separate dense elements attributes of -/// tensor type (even if the sparse attribute is of vector type, in order to -/// support empty lists). The dense elements attribute indices is a 2-D tensor -/// of 64-bit integer elements with shape [N, ndims], which specifies the -/// indices of the elements in the sparse tensor that contains nonzero values. -/// The dense elements attribute values is a 1-D tensor with shape [N], and it -/// supplies the corresponding values for the indices. -/// -/// For example, -/// `sparse, [[0, 0], [1, 2]], [1, 5]>` represents tensor -/// [[1, 0, 0, 0], -/// [0, 0, 5, 0], -/// [0, 0, 0, 0]]. -class SparseElementsAttr - : public Attribute::AttrBase { -public: - using Base::Base; - - template - using iterator = - llvm::mapped_iterator, - std::function>; - - /// 'type' must be a vector or tensor with static shape. 
- static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices, - DenseElementsAttr values); - - DenseIntElementsAttr getIndices() const; - - DenseElementsAttr getValues() const; - - /// Return the values of this attribute in the form of the given type 'T'. 'T' - /// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc. - template llvm::iterator_range> getValues() const { - auto zeroValue = getZeroValue(); - auto valueIt = getValues().getValues().begin(); - const std::vector flatSparseIndices(getFlattenedSparseIndices()); - // TODO: Move-capture flatSparseIndices when c++14 is available. - std::function mapFn = [=](ptrdiff_t index) { - // Try to map the current index to one of the sparse indices. - for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i) - if (flatSparseIndices[i] == index) - return *std::next(valueIt, i); - // Otherwise, return the zero value. - return zeroValue; - }; - return llvm::map_range(llvm::seq(0, getNumElements()), mapFn); - } - - /// Return the value of the element at the given index. The 'index' is - /// expected to refer to a valid element. - Attribute getValue(ArrayRef index) const; - -private: - /// Get a zero APFloat for the given sparse attribute. - APFloat getZeroAPFloat() const; - - /// Get a zero APInt for the given sparse attribute. - APInt getZeroAPInt() const; - - /// Get a zero attribute for the given sparse attribute. - Attribute getZeroAttr() const; - - /// Utility methods to generate a zero value of some type 'T'. This is used by - /// the 'iterator' class. - /// Get a zero for a given attribute type. - template - typename std::enable_if::value, T>::type - getZeroValue() const { - return getZeroAttr().template cast(); - } - /// Get a zero for an APInt. 
- template - typename std::enable_if::value, T>::type - getZeroValue() const { - return getZeroAPInt(); - } - template - typename std::enable_if, T>::value, T>::type - getZeroValue() const { - APInt intZero = getZeroAPInt(); - return {intZero, intZero}; - } - /// Get a zero for an APFloat. - template - typename std::enable_if::value, T>::type - getZeroValue() const { - return getZeroAPFloat(); - } - template - typename std::enable_if, T>::value, - T>::type - getZeroValue() const { - APFloat floatZero = getZeroAPFloat(); - return {floatZero, floatZero}; - } - - /// Get a zero for an C++ integer, float, StringRef, or complex type. - template - typename std::enable_if< - std::numeric_limits::is_integer || - DenseElementsAttr::is_valid_cpp_fp_type::value || - std::is_same::value || - (detail::is_complex_t::value && - !llvm::is_one_of, - std::complex>::value), - T>::type - getZeroValue() const { - return T(); - } - - /// Flatten, and return, all of the sparse indices in this attribute in - /// row-major order. - std::vector getFlattenedSparseIndices() const; -}; - -/// An attribute that represents a reference to a splat vector or tensor -/// constant, meaning all of the elements have the same value. -class SplatElementsAttr : public DenseElementsAttr { -public: - using DenseElementsAttr::DenseElementsAttr; - - /// Method for support type inquiry through isa, cast and dyn_cast. - static bool classof(Attribute attr) { - auto denseAttr = attr.dyn_cast(); - return denseAttr && denseAttr.isSplat(); - } -}; - -namespace detail { -/// This class represents a general iterator over the values of an ElementsAttr. -/// It supports all subclasses aside from OpaqueElementsAttr. -template -class ElementsAttrIterator - : public llvm::iterator_facade_base, - std::random_access_iterator_tag, T, - std::ptrdiff_t, T, T> { - // NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype' - // inside of a conversion operator. 
- using DenseIteratorT = typename std::enable_if< - true, - decltype(std::declval().getValues().begin())>::type; - using SparseIteratorT = SparseElementsAttr::iterator; - - /// A union containing the specific iterators for each derived attribute kind. - union Iterator { - Iterator(DenseIteratorT &&it) : denseIt(std::move(it)) {} - Iterator(SparseIteratorT &&it) : sparseIt(std::move(it)) {} - Iterator() {} - ~Iterator() {} - - operator const DenseIteratorT &() const { return denseIt; } - operator const SparseIteratorT &() const { return sparseIt; } - operator DenseIteratorT &() { return denseIt; } - operator SparseIteratorT &() { return sparseIt; } - - /// An instance of a dense elements iterator. - DenseIteratorT denseIt; - /// An instance of a sparse elements iterator. - SparseIteratorT sparseIt; - }; - - /// Utility method to process a functor on each of the internal iterator - /// types. - template class ProcessFn, - typename... Args> - RetT process(Args &... args) const { - if (attr.isa()) - return ProcessFn()(args...); - if (attr.isa()) - return ProcessFn()(args...); - llvm_unreachable("unexpected attribute kind"); - } - - /// Utility functors used to generically implement the iterators methods. 
- template struct PlusAssign { - void operator()(ItT &it, ptrdiff_t offset) { it += offset; } - }; - template struct Minus { - ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; } - }; - template struct MinusAssign { - void operator()(ItT &it, ptrdiff_t offset) { it -= offset; } - }; - template struct Dereference { - T operator()(ItT &it) { return *it; } - }; - template struct ConstructIter { - void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); } - }; - template struct DestructIter { - void operator()(ItT &it) { it.~ItT(); } - }; - -public: - ElementsAttrIterator(const ElementsAttrIterator &rhs) : attr(rhs.attr) { - process(it, rhs.it); - } - ~ElementsAttrIterator() { process(it); } - - /// Methods necessary to support random access iteration. - ptrdiff_t operator-(const ElementsAttrIterator &rhs) const { - assert(attr == rhs.attr && "incompatible iterators"); - return process(it, rhs.it); - } - bool operator==(const ElementsAttrIterator &rhs) const { - return rhs.attr == attr && process(it, rhs.it); - } - bool operator<(const ElementsAttrIterator &rhs) const { - assert(attr == rhs.attr && "incompatible iterators"); - return process(it, rhs.it); - } - ElementsAttrIterator &operator+=(ptrdiff_t offset) { - process(it, offset); - return *this; - } - ElementsAttrIterator &operator-=(ptrdiff_t offset) { - process(it, offset); - return *this; - } - - /// Dereference the iterator at the current index. - T operator*() { return process(it); } - -private: - template - ElementsAttrIterator(Attribute attr, IteratorT &&it) - : attr(attr), it(std::forward(it)) {} - - /// Allow accessing the constructor. - friend ElementsAttr; - - /// The parent elements attribute. - Attribute attr; - - /// A union containing the specific iterators for each derived kind. 
- Iterator it; -}; - -template -class ElementsAttrRange : public llvm::iterator_range> { - using llvm::iterator_range>::iterator_range; -}; -} // namespace detail - -/// Return the elements of this attribute as a value of type 'T'. -template -auto ElementsAttr::getValues() const -> iterator_range { - if (DenseElementsAttr denseAttr = dyn_cast()) { - auto values = denseAttr.getValues(); - return {iterator(*this, values.begin()), - iterator(*this, values.end())}; - } - if (SparseElementsAttr sparseAttr = dyn_cast()) { - auto values = sparseAttr.getValues(); - return {iterator(*this, values.begin()), - iterator(*this, values.end())}; - } - llvm_unreachable("unexpected attribute kind"); -} - -//===----------------------------------------------------------------------===// -// Attributes Utils -//===----------------------------------------------------------------------===// - template bool Attribute::isa() const { assert(impl && "isa<> used on a null attribute."); return U::classof(*this); @@ -1610,80 +116,58 @@ return U(impl); } -// Make Attribute hashable. inline ::llvm::hash_code hash_value(Attribute arg) { return ::llvm::hash_value(arg.impl); } //===----------------------------------------------------------------------===// -// MutableDictionaryAttr +// NamedAttribute //===----------------------------------------------------------------------===// -/// A MutableDictionaryAttr is a mutable wrapper around a DictionaryAttr. It -/// provides additional interfaces for adding, removing, replacing attributes -/// within a DictionaryAttr. -/// -/// We assume there will be relatively few attributes on a given operation -/// (maybe a dozen or so, but not hundreds or thousands) so we use linear -/// searches for everything. -class MutableDictionaryAttr { -public: - MutableDictionaryAttr(DictionaryAttr attrs = nullptr) - : attrs((attrs && !attrs.empty()) ? 
attrs : nullptr) {} - MutableDictionaryAttr(ArrayRef attributes); - - bool operator!=(const MutableDictionaryAttr &other) const { - return !(*this == other); - } - bool operator==(const MutableDictionaryAttr &other) const { - return attrs == other.attrs; - } - - /// Return the underlying dictionary attribute. - DictionaryAttr getDictionary(MLIRContext *context) const; - - /// Return the underlying dictionary attribute or null if there are no - /// attributes within this dictionary. - DictionaryAttr getDictionaryOrNull() const { return attrs; } - - /// Return all of the attributes on this operation. - ArrayRef getAttrs() const; - - /// Replace the held attributes with ones provided in 'newAttrs'. - void setAttrs(ArrayRef attributes); - - /// Return the specified attribute if present, null otherwise. - Attribute get(StringRef name) const; - Attribute get(Identifier name) const; +/// NamedAttribute is combination of a name, represented by an Identifier, and a +/// value, represented by an Attribute. The attribute pointer should always be +/// non-null. +using NamedAttribute = std::pair; - /// Return the specified named attribute if present, None otherwise. - Optional getNamed(StringRef name) const; - Optional getNamed(Identifier name) const; +bool operator<(const NamedAttribute &lhs, const NamedAttribute &rhs); +bool operator<(const NamedAttribute &lhs, StringRef rhs); - /// If the an attribute exists with the specified name, change it to the new - /// value. Otherwise, add a new attribute with the specified name/value. - void set(Identifier name, Attribute value); +//===----------------------------------------------------------------------===// +// AttributeTraitBase +//===----------------------------------------------------------------------===// - enum class RemoveResult { Removed, NotFound }; +namespace AttributeTrait { +/// This class represents the base of an attribute trait. 
+template class TraitType> +using TraitBase = detail::StorageUserTraitBase; +} // namespace AttributeTrait - /// Remove the attribute with the specified name if it exists. The return - /// value indicates whether the attribute was present or not. - RemoveResult remove(Identifier name); +//===----------------------------------------------------------------------===// +// AttributeInterface +//===----------------------------------------------------------------------===// - bool empty() const { return attrs == nullptr; } +/// This class represents the base of an attribute interface. See the definition +/// of `detail::Interface` for requirements on the `Traits` type. +template +class AttributeInterface + : public detail::Interface { +public: + using Base = AttributeInterface; + using InterfaceBase = detail::Interface; + using InterfaceBase::InterfaceBase; private: - friend ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg); + /// Returns the impl interface instance for the given type. + static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) { + return attr.getAbstractAttribute().getInterface(); + } - DictionaryAttr attrs; + /// Allow access to 'getInterfaceFor'. + friend InterfaceBase; }; -inline ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg) { - if (!arg.attrs) - return ::llvm::hash_value((void *)nullptr); - return hash_value(arg.attrs); -} - } // end namespace mlir. 
namespace llvm { @@ -1718,15 +202,6 @@ mlir::AttributeStorage *>::NumLowBitsAvailable; }; -template <> -struct PointerLikeTypeTraits - : public PointerLikeTypeTraits { - static inline mlir::SymbolRefAttr getFromVoidPointer(void *ptr) { - return PointerLikeTypeTraits::getFromVoidPointer(ptr) - .cast(); - } -}; - } // namespace llvm #endif diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h copy from mlir/include/mlir/IR/Attributes.h copy to mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -1,4 +1,4 @@ -//===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===// +//===- BuiltinAttributes.h - MLIR Builtin Attribute Classes -----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,25 +6,20 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_IR_ATTRIBUTES_H -#define MLIR_IR_ATTRIBUTES_H +#ifndef MLIR_IR_BUILTINATTRIBUTES_H +#define MLIR_IR_BUILTINATTRIBUTES_H -#include "mlir/IR/AttributeSupport.h" +#include "mlir/IR/Attributes.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Sequence.h" -#include "llvm/Support/PointerLikeTypeTraits.h" #include namespace mlir { class AffineMap; -class Dialect; class FunctionType; -class Identifier; class IntegerSet; class Location; -class MLIRContext; class ShapedType; -class Type; namespace detail { @@ -46,122 +41,6 @@ struct SparseElementsAttributeStorage; } // namespace detail -/// Attributes are known-constant values of operations and functions. -/// -/// Instances of the Attribute class are references to immortal key-value pairs -/// with immutable, uniqued key owned by MLIRContext. As such, an Attribute is a -/// thin wrapper around an underlying storage pointer. Attributes are usually -/// passed by value. 
-class Attribute { -public: - /// Utility class for implementing attributes. - template class... Traits> - using AttrBase = detail::StorageUserBase; - - using ImplType = AttributeStorage; - using ValueType = void; - - constexpr Attribute() : impl(nullptr) {} - /* implicit */ Attribute(const ImplType *impl) - : impl(const_cast(impl)) {} - - Attribute(const Attribute &other) = default; - Attribute &operator=(const Attribute &other) = default; - - bool operator==(Attribute other) const { return impl == other.impl; } - bool operator!=(Attribute other) const { return !(*this == other); } - explicit operator bool() const { return impl; } - - bool operator!() const { return impl == nullptr; } - - template bool isa() const; - template - bool isa() const; - template U dyn_cast() const; - template U dyn_cast_or_null() const; - template U cast() const; - - // Support dyn_cast'ing Attribute to itself. - static bool classof(Attribute) { return true; } - - /// Return a unique identifier for the concrete attribute type. This is used - /// to support dynamic type casting. - TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); } - - /// Return the type of this attribute. - Type getType() const; - - /// Return the context this attribute belongs to. - MLIRContext *getContext() const; - - /// Get the dialect this attribute is registered to. - Dialect &getDialect() const; - - /// Print the attribute. - void print(raw_ostream &os) const; - void dump() const; - - /// Get an opaque pointer to the attribute. - const void *getAsOpaquePointer() const { return impl; } - /// Construct an attribute from the opaque pointer representation. - static Attribute getFromOpaquePointer(const void *ptr) { - return Attribute(reinterpret_cast(ptr)); - } - - friend ::llvm::hash_code hash_value(Attribute arg); - - /// Return the abstract descriptor for this attribute. 
- const AbstractAttribute &getAbstractAttribute() const { - return impl->getAbstractAttribute(); - } - -protected: - ImplType *impl; -}; - -inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) { - attr.print(os); - return os; -} - -//===----------------------------------------------------------------------===// -// AttributeTraitBase -//===----------------------------------------------------------------------===// - -namespace AttributeTrait { -/// This class represents the base of an attribute trait. -template class TraitType> -using TraitBase = detail::StorageUserTraitBase; -} // namespace AttributeTrait - -//===----------------------------------------------------------------------===// -// AttributeInterface -//===----------------------------------------------------------------------===// - -/// This class represents the base of an attribute interface. See the definition -/// of `detail::Interface` for requirements on the `Traits` type. -template -class AttributeInterface - : public detail::Interface { -public: - using Base = AttributeInterface; - using InterfaceBase = detail::Interface; - using InterfaceBase::InterfaceBase; - -private: - /// Returns the impl interface instance for the given type. - static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) { - return attr.getAbstractAttribute().getInterface(); - } - - /// Allow access to 'getInterfaceFor'. - friend InterfaceBase; -}; - //===----------------------------------------------------------------------===// // AffineMapAttr //===----------------------------------------------------------------------===// @@ -233,14 +112,6 @@ // DictionaryAttr //===----------------------------------------------------------------------===// -/// NamedAttribute is used for dictionary attributes, it holds an identifier for -/// the name and a value for the attribute. The attribute pointer should always -/// be non-null. 
-using NamedAttribute = std::pair; - -bool operator<(const NamedAttribute &lhs, const NamedAttribute &rhs); -bool operator<(const NamedAttribute &lhs, StringRef rhs); - /// Dictionary attribute is an attribute that represents a sorted collection of /// named attribute values. The elements are sorted by name, and each name must /// be unique within the collection. @@ -558,8 +429,10 @@ //===----------------------------------------------------------------------===// namespace detail { -template class ElementsAttrIterator; -template class ElementsAttrRange; +template +class ElementsAttrIterator; +template +class ElementsAttrRange; } // namespace detail /// A base attribute that represents a reference to a static shaped tensor or @@ -567,8 +440,10 @@ class ElementsAttr : public Attribute { public: using Attribute::Attribute; - template using iterator = detail::ElementsAttrIterator; - template using iterator_range = detail::ElementsAttrRange; + template + using iterator = detail::ElementsAttrIterator; + template + using iterator_range = detail::ElementsAttrRange; /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor /// with static shape. @@ -580,14 +455,16 @@ /// Return the value of type 'T' at the given index, where 'T' corresponds to /// an Attribute type. - template T getValue(ArrayRef index) const { + template + T getValue(ArrayRef index) const { return getValue(index).template cast(); } /// Return the elements of this attribute as a value of type 'T'. Note: /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support /// iteration. - template iterator_range getValues() const; + template + iterator_range getValues() const; /// Return if the given 'index' refers to a valid element in this attribute. bool isValidIndex(ArrayRef index) const; @@ -664,7 +541,8 @@ }; /// Type trait detector that checks if a given type T is a complex type. 
-template struct is_complex_t : public std::false_type {}; +template +struct is_complex_t : public std::false_type {}; template struct is_complex_t> : public std::true_type {}; } // namespace detail @@ -679,7 +557,8 @@ /// floating point type that can be used to access the underlying element /// types of a DenseElementsAttr. // TODO: Use std::disjunction when C++17 is supported. - template struct is_valid_cpp_fp_type { + template + struct is_valid_cpp_fp_type { /// The type is a valid floating point type if it is a builtin floating /// point type, or is a potentially user defined floating point type. The /// latter allows for supporting users that have custom types defined for @@ -948,7 +827,8 @@ Attribute getValue(ArrayRef index) const { return getValue(index); } - template T getValue(ArrayRef index) const { + template + T getValue(ArrayRef index) const { // Skip to the element corresponding to the flattened index. return *std::next(getValues().begin(), getFlattenedIndex(index)); } @@ -1357,7 +1237,8 @@ /// Return the values of this attribute in the form of the given type 'T'. 'T' /// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc. - template llvm::iterator_range> getValues() const { + template + llvm::iterator_range> getValues() const { auto zeroValue = getZeroValue(); auto valueIt = getValues().getValues().begin(); const std::vector flatSparseIndices(getFlattenedSparseIndices()); @@ -1490,7 +1371,7 @@ /// types. template class ProcessFn, typename... Args> - RetT process(Args &... args) const { + RetT process(Args &...args) const { if (attr.isa()) return ProcessFn()(args...); if (attr.isa()) @@ -1499,22 +1380,28 @@ } /// Utility functors used to generically implement the iterators methods. 
- template struct PlusAssign { + template + struct PlusAssign { void operator()(ItT &it, ptrdiff_t offset) { it += offset; } }; - template struct Minus { + template + struct Minus { ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; } }; - template struct MinusAssign { + template + struct MinusAssign { void operator()(ItT &it, ptrdiff_t offset) { it -= offset; } }; - template struct Dereference { + template + struct Dereference { T operator()(ItT &it) { return *it; } }; - template struct ConstructIter { + template + struct ConstructIter { void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); } }; - template struct DestructIter { + template + struct DestructIter { void operator()(ItT &it) { it.~ItT(); } }; @@ -1585,36 +1472,6 @@ llvm_unreachable("unexpected attribute kind"); } -//===----------------------------------------------------------------------===// -// Attributes Utils -//===----------------------------------------------------------------------===// - -template bool Attribute::isa() const { - assert(impl && "isa<> used on a null attribute."); - return U::classof(*this); -} - -template -bool Attribute::isa() const { - return isa() || isa(); -} - -template U Attribute::dyn_cast() const { - return isa() ? U(impl) : U(nullptr); -} -template U Attribute::dyn_cast_or_null() const { - return (impl && isa()) ? U(impl) : U(nullptr); -} -template U Attribute::cast() const { - assert(isa()); - return U(impl); -} - -// Make Attribute hashable. -inline ::llvm::hash_code hash_value(Attribute arg) { - return ::llvm::hash_value(arg.impl); -} - //===----------------------------------------------------------------------===// // MutableDictionaryAttr //===----------------------------------------------------------------------===// @@ -1688,36 +1545,6 @@ namespace llvm { -// Attribute hash just like pointers. 
-template <> struct DenseMapInfo { - static mlir::Attribute getEmptyKey() { - auto pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::Attribute(static_cast(pointer)); - } - static mlir::Attribute getTombstoneKey() { - auto pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::Attribute(static_cast(pointer)); - } - static unsigned getHashValue(mlir::Attribute val) { - return mlir::hash_value(val); - } - static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) { - return LHS == RHS; - } -}; - -/// Allow LLVM to steal the low bits of Attributes. -template <> struct PointerLikeTypeTraits { - static inline void *getAsVoidPointer(mlir::Attribute attr) { - return const_cast(attr.getAsOpaquePointer()); - } - static inline mlir::Attribute getFromVoidPointer(void *ptr) { - return mlir::Attribute::getFromOpaquePointer(ptr); - } - static constexpr int NumLowBitsAvailable = llvm::PointerLikeTypeTraits< - mlir::AttributeStorage *>::NumLowBitsAvailable; -}; - template <> struct PointerLikeTypeTraits : public PointerLikeTypeTraits { @@ -1729,4 +1556,4 @@ } // namespace llvm -#endif +#endif // MLIR_IR_BUILTINATTRIBUTES_H diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -14,6 +14,7 @@ #define MLIR_IR_OPERATION_H #include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Region.h" diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -30,6 +30,9 @@ namespace mlir { class Dialect; +class DictionaryAttr; +class ElementsAttr; +class MutableDictionaryAttr; class Operation; struct OperationState; class OpAsmParser; diff --git a/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h 
b/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h --- a/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h +++ b/mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h @@ -8,7 +8,7 @@ #ifndef MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_ #define MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_ -#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectInterface.h" #include "mlir/Support/LogicalResult.h" diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -21,6 +21,7 @@ namespace mlir { class AffineForOp; +class AffineMap; class FuncOp; class LoopLikeOpInterface; struct MemRefRegion; diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -12,9 +12,9 @@ #include "PybindUtils.h" #include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Registration.h" -#include "mlir-c/StandardAttributes.h" #include "llvm/ADT/SmallVector.h" #include @@ -1390,7 +1390,7 @@ } // end namespace //------------------------------------------------------------------------------ -// Standard attribute subclasses. +// Builtin attribute subclasses. //------------------------------------------------------------------------------ namespace { @@ -3039,7 +3039,7 @@ py::keep_alive<0, 1>(), "The underlying generic attribute of the NamedAttribute binding"); - // Standard attribute bindings. + // Builtin attribute bindings. 
PyFloatAttribute::bind(m); PyIntegerAttribute::bind(m); PyBoolAttribute::bind(m); diff --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp rename from mlir/lib/CAPI/IR/StandardAttributes.cpp rename to mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/StandardAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -1,4 +1,4 @@ -//===- StandardAttributes.cpp - C Interface to MLIR Standard Attributes ---===// +//===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir-c/StandardAttributes.h" +#include "mlir-c/BuiltinAttributes.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt --- a/mlir/lib/CAPI/IR/CMakeLists.txt +++ b/mlir/lib/CAPI/IR/CMakeLists.txt @@ -2,11 +2,11 @@ add_mlir_public_c_api_library(MLIRCAPIIR AffineExpr.cpp AffineMap.cpp + BuiltinAttributes.cpp BuiltinTypes.cpp Diagnostics.cpp IR.cpp Pass.cpp - StandardAttributes.cpp Support.cpp LINK_LIBS PUBLIC diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -8,7 +8,7 @@ #include "mlir/IR/AffineMap.h" #include "AffineMapDetail.h" -#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/MathExtras.h" diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -14,7 +14,7 @@ #define ATTRIBUTEDETAIL_H_ #include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" +#include 
"mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/IntegerSet.h" diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -7,17 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/Attributes.h" -#include "AttributeDetail.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/IntegerSet.h" -#include "mlir/IR/Types.h" -#include "mlir/Interfaces/DecodeAttributesInterfaces.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Support/Endian.h" using namespace mlir; using namespace mlir::detail; @@ -53,1550 +43,9 @@ } //===----------------------------------------------------------------------===// -// AffineMapAttr +// NamedAttribute //===----------------------------------------------------------------------===// -AffineMapAttr AffineMapAttr::get(AffineMap value) { - return Base::get(value.getContext(), value); -} - -AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } - -//===----------------------------------------------------------------------===// -// ArrayAttr -//===----------------------------------------------------------------------===// - -ArrayAttr ArrayAttr::get(ArrayRef value, MLIRContext *context) { - return Base::get(context, value); -} - -ArrayRef ArrayAttr::getValue() const { return getImpl()->value; } - -Attribute ArrayAttr::operator[](unsigned idx) const { - assert(idx < size() && "index out of bounds"); - return getValue()[idx]; -} - -//===----------------------------------------------------------------------===// -// DictionaryAttr -//===----------------------------------------------------------------------===// - -/// Helper function that does either an in place sort or sorts from source array -/// into destination. 
If inPlace then storage is both the source and the -/// destination, else value is the source and storage destination. Returns -/// whether source was sorted. -template -static bool dictionaryAttrSort(ArrayRef value, - SmallVectorImpl &storage) { - // Specialize for the common case. - switch (value.size()) { - case 0: - // Zero already sorted. - break; - case 1: - // One already sorted but may need to be copied. - if (!inPlace) - storage.assign({value[0]}); - break; - case 2: { - bool isSorted = value[0] < value[1]; - if (inPlace) { - if (!isSorted) - std::swap(storage[0], storage[1]); - } else if (isSorted) { - storage.assign({value[0], value[1]}); - } else { - storage.assign({value[1], value[0]}); - } - return !isSorted; - } - default: - if (!inPlace) - storage.assign(value.begin(), value.end()); - // Check to see they are sorted already. - bool isSorted = llvm::is_sorted(value); - if (!isSorted) { - // If not, do a general sort. - llvm::array_pod_sort(storage.begin(), storage.end()); - value = storage; - } - return !isSorted; - } - return false; -} - -/// Returns an entry with a duplicate name from the given sorted array of named -/// attributes. Returns llvm::None if all elements have unique names. -static Optional -findDuplicateElement(ArrayRef value) { - const Optional none{llvm::None}; - if (value.size() < 2) - return none; - - if (value.size() == 2) - return value[0].first == value[1].first ? value[0] : none; - - auto it = std::adjacent_find( - value.begin(), value.end(), - [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); - return it != value.end() ? 
*it : none; -} - -bool DictionaryAttr::sort(ArrayRef value, - SmallVectorImpl &storage) { - bool isSorted = dictionaryAttrSort(value, storage); - assert(!findDuplicateElement(storage) && - "DictionaryAttr element names must be unique"); - return isSorted; -} - -bool DictionaryAttr::sortInPlace(SmallVectorImpl &array) { - bool isSorted = dictionaryAttrSort(array, array); - assert(!findDuplicateElement(array) && - "DictionaryAttr element names must be unique"); - return isSorted; -} - -Optional -DictionaryAttr::findDuplicate(SmallVectorImpl &array, - bool isSorted) { - if (!isSorted) - dictionaryAttrSort(array, array); - return findDuplicateElement(array); -} - -DictionaryAttr DictionaryAttr::get(ArrayRef value, - MLIRContext *context) { - if (value.empty()) - return DictionaryAttr::getEmpty(context); - assert(llvm::all_of(value, - [](const NamedAttribute &attr) { return attr.second; }) && - "value cannot have null entries"); - - // We need to sort the element list to canonicalize it. - SmallVector storage; - if (dictionaryAttrSort(value, storage)) - value = storage; - assert(!findDuplicateElement(value) && - "DictionaryAttr element names must be unique"); - return Base::get(context, value); -} -/// Construct a dictionary with an array of values that is known to already be -/// sorted by name and uniqued. -DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef value, - MLIRContext *context) { - if (value.empty()) - return DictionaryAttr::getEmpty(context); - // Ensure that the attribute elements are unique and sorted. - assert(llvm::is_sorted(value, - [](NamedAttribute l, NamedAttribute r) { - return l.first.strref() < r.first.strref(); - }) && - "expected attribute values to be sorted"); - assert(!findDuplicateElement(value) && - "DictionaryAttr element names must be unique"); - return Base::get(context, value); -} - -ArrayRef DictionaryAttr::getValue() const { - return getImpl()->getElements(); -} - -/// Return the specified attribute if present, null otherwise. 
-Attribute DictionaryAttr::get(StringRef name) const { - Optional attr = getNamed(name); - return attr ? attr->second : nullptr; -} -Attribute DictionaryAttr::get(Identifier name) const { - Optional attr = getNamed(name); - return attr ? attr->second : nullptr; -} - -/// Return the specified named attribute if present, None otherwise. -Optional DictionaryAttr::getNamed(StringRef name) const { - ArrayRef values = getValue(); - const auto *it = llvm::lower_bound(values, name); - return it != values.end() && it->first == name ? *it - : Optional(); -} -Optional DictionaryAttr::getNamed(Identifier name) const { - for (auto elt : getValue()) - if (elt.first == name) - return elt; - return llvm::None; -} - -DictionaryAttr::iterator DictionaryAttr::begin() const { - return getValue().begin(); -} -DictionaryAttr::iterator DictionaryAttr::end() const { - return getValue().end(); -} -size_t DictionaryAttr::size() const { return getValue().size(); } - -//===----------------------------------------------------------------------===// -// FloatAttr -//===----------------------------------------------------------------------===// - -FloatAttr FloatAttr::get(Type type, double value) { - return Base::get(type.getContext(), type, value); -} - -FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { - return Base::getChecked(loc, type, value); -} - -FloatAttr FloatAttr::get(Type type, const APFloat &value) { - return Base::get(type.getContext(), type, value); -} - -FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { - return Base::getChecked(loc, type, value); -} - -APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } - -double FloatAttr::getValueAsDouble() const { - return getValueAsDouble(getValue()); -} -double FloatAttr::getValueAsDouble(APFloat value) { - if (&value.getSemantics() != &APFloat::IEEEdouble()) { - bool losesInfo = false; - value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, - &losesInfo); 
- } - return value.convertToDouble(); -} - -/// Verify construction invariants. -static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { - if (!type.isa()) - return emitError(loc, "expected floating point type"); - return success(); -} - -LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, - double value) { - return verifyFloatTypeInvariants(loc, type); -} - -LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, - const APFloat &value) { - // Verify that the type is correct. - if (failed(verifyFloatTypeInvariants(loc, type))) - return failure(); - - // Verify that the type semantics match that of the value. - if (&type.cast().getFloatSemantics() != &value.getSemantics()) { - return emitError( - loc, "FloatAttr type doesn't match the type implied by its value"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// SymbolRefAttr -//===----------------------------------------------------------------------===// - -FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { - return Base::get(ctx, value, llvm::None).cast(); -} - -SymbolRefAttr SymbolRefAttr::get(StringRef value, - ArrayRef nestedReferences, - MLIRContext *ctx) { - return Base::get(ctx, value, nestedReferences); -} - -StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } - -StringRef SymbolRefAttr::getLeafReference() const { - ArrayRef nestedRefs = getNestedReferences(); - return nestedRefs.empty() ? 
getRootReference() : nestedRefs.back().getValue(); -} - -ArrayRef SymbolRefAttr::getNestedReferences() const { - return getImpl()->getNestedRefs(); -} - -//===----------------------------------------------------------------------===// -// IntegerAttr -//===----------------------------------------------------------------------===// - -IntegerAttr IntegerAttr::get(Type type, const APInt &value) { - if (type.isSignlessInteger(1)) - return BoolAttr::get(value.getBoolValue(), type.getContext()); - return Base::get(type.getContext(), type, value); -} - -IntegerAttr IntegerAttr::get(Type type, int64_t value) { - // This uses 64 bit APInts by default for index type. - if (type.isIndex()) - return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); - - auto intType = type.cast(); - return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); -} - -APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } - -int64_t IntegerAttr::getInt() const { - assert((getImpl()->getType().isIndex() || - getImpl()->getType().isSignlessInteger()) && - "must be signless integer"); - return getValue().getSExtValue(); -} - -int64_t IntegerAttr::getSInt() const { - assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); - return getValue().getSExtValue(); -} - -uint64_t IntegerAttr::getUInt() const { - assert(getImpl()->getType().isUnsignedInteger() && - "must be unsigned integer"); - return getValue().getZExtValue(); -} - -static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { - if (type.isa()) - return success(); - return emitError(loc, "expected integer or index type"); -} - -LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, - int64_t value) { - return verifyIntegerTypeInvariants(loc, type); -} - -LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, - const APInt &value) { - if (failed(verifyIntegerTypeInvariants(loc, type))) - return failure(); - if 
(auto integerType = type.dyn_cast()) - if (integerType.getWidth() != value.getBitWidth()) - return emitError(loc, "integer type bit width (") - << integerType.getWidth() << ") doesn't match value bit width (" - << value.getBitWidth() << ")"; - return success(); -} - -//===----------------------------------------------------------------------===// -// BoolAttr - -bool BoolAttr::getValue() const { - auto *storage = reinterpret_cast(impl); - return storage->getValue().getBoolValue(); -} - -bool BoolAttr::classof(Attribute attr) { - IntegerAttr intAttr = attr.dyn_cast(); - return intAttr && intAttr.getType().isSignlessInteger(1); -} - -//===----------------------------------------------------------------------===// -// IntegerSetAttr -//===----------------------------------------------------------------------===// - -IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { - return Base::get(value.getConstraint(0).getContext(), value); -} - -IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } - -//===----------------------------------------------------------------------===// -// OpaqueAttr -//===----------------------------------------------------------------------===// - -OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, - MLIRContext *context) { - return Base::get(context, dialect, attrData, type); -} - -OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, - Type type, Location location) { - return Base::getChecked(location, dialect, attrData, type); -} - -/// Returns the dialect namespace of the opaque attribute. -Identifier OpaqueAttr::getDialectNamespace() const { - return getImpl()->dialectNamespace; -} - -/// Returns the raw attribute data of the opaque attribute. -StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } - -/// Verify the construction of an opaque attribute. 
-LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, - Identifier dialect, - StringRef attrData, - Type type) { - if (!Dialect::isValidNamespace(dialect.strref())) - return emitError(loc, "invalid dialect namespace '") << dialect << "'"; - return success(); -} - -//===----------------------------------------------------------------------===// -// StringAttr -//===----------------------------------------------------------------------===// - -StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { - return get(bytes, NoneType::get(context)); -} - -/// Get an instance of a StringAttr with the given string and Type. -StringAttr StringAttr::get(StringRef bytes, Type type) { - return Base::get(type.getContext(), bytes, type); -} - -StringRef StringAttr::getValue() const { return getImpl()->value; } - -//===----------------------------------------------------------------------===// -// TypeAttr -//===----------------------------------------------------------------------===// - -TypeAttr TypeAttr::get(Type value) { - return Base::get(value.getContext(), value); -} - -Type TypeAttr::getValue() const { return getImpl()->value; } - -//===----------------------------------------------------------------------===// -// ElementsAttr -//===----------------------------------------------------------------------===// - -ShapedType ElementsAttr::getType() const { - return Attribute::getType().cast(); -} - -/// Returns the number of elements held by this attribute. -int64_t ElementsAttr::getNumElements() const { - return getType().getNumElements(); -} - -/// Return the value at the given index. If index does not refer to a valid -/// element, then a null attribute is returned. 
-Attribute ElementsAttr::getValue(ArrayRef index) const { - if (auto denseAttr = dyn_cast()) - return denseAttr.getValue(index); - if (auto opaqueAttr = dyn_cast()) - return opaqueAttr.getValue(index); - return cast().getValue(index); -} - -/// Return if the given 'index' refers to a valid element in this attribute. -bool ElementsAttr::isValidIndex(ArrayRef index) const { - auto type = getType(); - - // Verify that the rank of the indices matches the held type. - auto rank = type.getRank(); - if (rank != static_cast(index.size())) - return false; - - // Verify that all of the indices are within the shape dimensions. - auto shape = type.getShape(); - return llvm::all_of(llvm::seq(0, rank), [&](int i) { - return static_cast(index[i]) < shape[i]; - }); -} - -ElementsAttr -ElementsAttr::mapValues(Type newElementType, - function_ref mapping) const { - if (auto intOrFpAttr = dyn_cast()) - return intOrFpAttr.mapValues(newElementType, mapping); - llvm_unreachable("unsupported ElementsAttr subtype"); -} - -ElementsAttr -ElementsAttr::mapValues(Type newElementType, - function_ref mapping) const { - if (auto intOrFpAttr = dyn_cast()) - return intOrFpAttr.mapValues(newElementType, mapping); - llvm_unreachable("unsupported ElementsAttr subtype"); -} - -/// Method for support type inquiry through isa, cast and dyn_cast. -bool ElementsAttr::classof(Attribute attr) { - return attr.isa(); -} - -/// Returns the 1 dimensional flattened row-major index from the given -/// multi-dimensional index. -uint64_t ElementsAttr::getFlattenedIndex(ArrayRef index) const { - assert(isValidIndex(index) && "expected valid multi-dimensional index"); - auto type = getType(); - - // Reduce the provided multidimensional index into a flattended 1D row-major - // index. 
- auto rank = type.getRank(); - auto shape = type.getShape(); - uint64_t valueIndex = 0; - uint64_t dimMultiplier = 1; - for (int i = rank - 1; i >= 0; --i) { - valueIndex += index[i] * dimMultiplier; - dimMultiplier *= shape[i]; - } - return valueIndex; -} - -//===----------------------------------------------------------------------===// -// DenseElementsAttr Utilities -//===----------------------------------------------------------------------===// - -/// Get the bitwidth of a dense element type within the buffer. -/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. -static size_t getDenseElementStorageWidth(size_t origWidth) { - return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); -} -static size_t getDenseElementStorageWidth(Type elementType) { - return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); -} - -/// Set a bit to a specific value. -static void setBit(char *rawData, size_t bitPos, bool value) { - if (value) - rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); - else - rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); -} - -/// Return the value of the specified bit. -static bool getBit(const char *rawData, size_t bitPos) { - return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; -} - -/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for -/// BE format. -static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, - char *result) { - assert(llvm::support::endian::system_endianness() == // NOLINT - llvm::support::endianness::big); // NOLINT - assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); - - // Copy the words filled with data. - // For example, when `value` has 2 words, the first word is filled with data. 
- // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| - size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; - std::copy_n(reinterpret_cast(value.getRawData()), - numFilledWords, result); - // Convert last word of APInt to LE format and store it in char - // array(`valueLE`). - // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| - size_t lastWordPos = numFilledWords; - SmallVector valueLE(APInt::APINT_WORD_SIZE); - DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( - reinterpret_cast(value.getRawData()) + lastWordPos, - valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); - // Extract actual APInt data from `valueLE`, convert endianness to BE format, - // and store it in `result`. - // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| - DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( - valueLE.begin(), result + lastWordPos, - (numBytes - lastWordPos) * CHAR_BIT, 1); -} - -/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE -/// format. -static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, - APInt &result) { - assert(llvm::support::endian::system_endianness() == // NOLINT - llvm::support::endianness::big); // NOLINT - assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes); - - // Copy the data that fills the word of `result` from `inArray`. - // For example, when `result` has 2 words, the first word will be filled with - // data. So, the first 8 bytes are copied from `inArray` here. - // `inArray` (10 bytes, BE): |abcdefgh|ij| - // ==> `result` (2 words, BE): |abcdefgh|--------| - size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; - std::copy_n( - inArray, numFilledWords, - const_cast(reinterpret_cast(result.getRawData()))); - - // Convert array data which will be last word of `result` to LE format, and - // store it in char array(`inArrayLE`). - // ex. 
`inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| - size_t lastWordPos = numFilledWords; - SmallVector inArrayLE(APInt::APINT_WORD_SIZE); - DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( - inArray + lastWordPos, inArrayLE.begin(), - (numBytes - lastWordPos) * CHAR_BIT, 1); - - // Convert `inArrayLE` to BE format, and store it in last word of `result`. - // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| - DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( - inArrayLE.begin(), - const_cast(reinterpret_cast(result.getRawData())) + - lastWordPos, - APInt::APINT_BITS_PER_WORD, 1); -} - -/// Writes value to the bit position `bitPos` in array `rawData`. -static void writeBits(char *rawData, size_t bitPos, APInt value) { - size_t bitWidth = value.getBitWidth(); - - // If the bitwidth is 1 we just toggle the specific bit. - if (bitWidth == 1) - return setBit(rawData, bitPos, value.isOneValue()); - - // Otherwise, the bit position is guaranteed to be byte aligned. - assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); - if (llvm::support::endian::system_endianness() == - llvm::support::endianness::big) { - // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. - // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't - // work correctly in BE format. - // ex. `value` (2 words including 10 bytes) - // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| - copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT), - rawData + (bitPos / CHAR_BIT)); - } else { - std::copy_n(reinterpret_cast(value.getRawData()), - llvm::divideCeil(bitWidth, CHAR_BIT), - rawData + (bitPos / CHAR_BIT)); - } -} - -/// Reads the next `bitWidth` bits from the bit position `bitPos` in array -/// `rawData`. -static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { - // Handle a boolean bit position. - if (bitWidth == 1) - return APInt(1, getBit(rawData, bitPos) ? 
1 : 0); - - // Otherwise, the bit position must be 8-bit aligned. - assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); - APInt result(bitWidth, 0); - if (llvm::support::endian::system_endianness() == - llvm::support::endianness::big) { - // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. - // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't - // work correctly in BE format. - // ex. `result` (2 words including 10 bytes) - // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function - copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT), - llvm::divideCeil(bitWidth, CHAR_BIT), result); - } else { - std::copy_n(rawData + (bitPos / CHAR_BIT), - llvm::divideCeil(bitWidth, CHAR_BIT), - const_cast( - reinterpret_cast(result.getRawData()))); - } - return result; -} - -/// Returns true if 'values' corresponds to a splat, i.e. one element, or has -/// the same element count as 'type'. -template -static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { - return (values.size() == 1) || - (type.getNumElements() == static_cast(values.size())); -} - -//===----------------------------------------------------------------------===// -// DenseElementsAttr Iterators -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// AttributeElementIterator - -DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( - DenseElementsAttr attr, size_t index) - : llvm::indexed_accessor_iterator( - attr.getAsOpaquePointer(), index) {} - -Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { - auto owner = getFromOpaquePointer(base).cast(); - Type eltTy = owner.getType().getElementType(); - if (auto intEltTy = eltTy.dyn_cast()) - return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); - if (eltTy.isa()) - return IntegerAttr::get(eltTy, 
*IntElementIterator(owner, index)); - if (auto floatEltTy = eltTy.dyn_cast()) { - IntElementIterator intIt(owner, index); - FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); - return FloatAttr::get(eltTy, *floatIt); - } - if (owner.isa()) { - ArrayRef vals = owner.getRawStringData(); - return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); - } - llvm_unreachable("unexpected element type"); -} - -//===----------------------------------------------------------------------===// -// BoolElementIterator - -DenseElementsAttr::BoolElementIterator::BoolElementIterator( - DenseElementsAttr attr, size_t dataIndex) - : DenseElementIndexedIteratorImpl( - attr.getRawData().data(), attr.isSplat(), dataIndex) {} - -bool DenseElementsAttr::BoolElementIterator::operator*() const { - return getBit(getData(), getDataIndex()); -} - -//===----------------------------------------------------------------------===// -// IntElementIterator - -DenseElementsAttr::IntElementIterator::IntElementIterator( - DenseElementsAttr attr, size_t dataIndex) - : DenseElementIndexedIteratorImpl( - attr.getRawData().data(), attr.isSplat(), dataIndex), - bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} - -APInt DenseElementsAttr::IntElementIterator::operator*() const { - return readBits(getData(), - getDataIndex() * getDenseElementStorageWidth(bitWidth), - bitWidth); -} - -//===----------------------------------------------------------------------===// -// ComplexIntElementIterator - -DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( - DenseElementsAttr attr, size_t dataIndex) - : DenseElementIndexedIteratorImpl, std::complex, - std::complex>( - attr.getRawData().data(), attr.isSplat(), dataIndex) { - auto complexType = attr.getType().getElementType().cast(); - bitWidth = getDenseElementBitWidth(complexType.getElementType()); -} - -std::complex -DenseElementsAttr::ComplexIntElementIterator::operator*() const { - size_t 
storageWidth = getDenseElementStorageWidth(bitWidth); - size_t offset = getDataIndex() * storageWidth * 2; - return {readBits(getData(), offset, bitWidth), - readBits(getData(), offset + storageWidth, bitWidth)}; -} - -//===----------------------------------------------------------------------===// -// FloatElementIterator - -DenseElementsAttr::FloatElementIterator::FloatElementIterator( - const llvm::fltSemantics &smt, IntElementIterator it) - : llvm::mapped_iterator>( - it, [&](const APInt &val) { return APFloat(smt, val); }) {} - -//===----------------------------------------------------------------------===// -// ComplexFloatElementIterator - -DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator( - const llvm::fltSemantics &smt, ComplexIntElementIterator it) - : llvm::mapped_iterator< - ComplexIntElementIterator, - std::function(const std::complex &)>>( - it, [&](const std::complex &val) -> std::complex { - return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; - }) {} - -//===----------------------------------------------------------------------===// -// DenseElementsAttr -//===----------------------------------------------------------------------===// - -/// Method for support type inquiry through isa, cast and dyn_cast. -bool DenseElementsAttr::classof(Attribute attr) { - return attr.isa(); -} - -DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef values) { - assert(hasSameElementsOrSplat(type, values)); - - // If the element type is not based on int/float/index, assume it is a string - // type. 
- auto eltType = type.getElementType(); - if (!type.getElementType().isIntOrIndexOrFloat()) { - SmallVector stringValues; - stringValues.reserve(values.size()); - for (Attribute attr : values) { - assert(attr.isa() && - "expected string value for non integer/index/float element"); - stringValues.push_back(attr.cast().getValue()); - } - return get(type, stringValues); - } - - // Otherwise, get the raw storage width to use for the allocation. - size_t bitWidth = getDenseElementBitWidth(eltType); - size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); - - // Compress the attribute values into a character buffer. - SmallVector data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * - values.size()); - APInt intVal; - for (unsigned i = 0, e = values.size(); i < e; ++i) { - assert(eltType == values[i].getType() && - "expected attribute value to have element type"); - if (eltType.isa()) - intVal = values[i].cast().getValue().bitcastToAPInt(); - else if (eltType.isa()) - intVal = values[i].cast().getValue(); - else - llvm_unreachable("unexpected element type"); - - assert(intVal.getBitWidth() == bitWidth && - "expected value to have same bitwidth as element type"); - writeBits(data.data(), i * storageBitWidth, intVal); - } - return DenseIntOrFPElementsAttr::getRaw(type, data, - /*isSplat=*/(values.size() == 1)); -} - -DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef values) { - assert(hasSameElementsOrSplat(type, values)); - assert(type.getElementType().isInteger(1)); - - std::vector buff(llvm::divideCeil(values.size(), CHAR_BIT)); - for (int i = 0, e = values.size(); i != e; ++i) - setBit(buff.data(), i, values[i]); - return DenseIntOrFPElementsAttr::getRaw(type, buff, - /*isSplat=*/(values.size() == 1)); -} - -DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef values) { - assert(!type.getElementType().isIntOrFloat()); - return DenseStringElementsAttr::get(type, values); -} - -/// Constructs a dense integer elements attribute 
from an array of APInt -/// values. Each APInt value is expected to have the same bitwidth as the -/// element type of 'type'. -DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef values) { - assert(type.getElementType().isIntOrIndex()); - assert(hasSameElementsOrSplat(type, values)); - size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); - return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, - /*isSplat=*/(values.size() == 1)); -} -DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef> values) { - ComplexType complex = type.getElementType().cast(); - assert(complex.getElementType().isa()); - assert(hasSameElementsOrSplat(type, values)); - size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; - ArrayRef intVals(reinterpret_cast(values.data()), - values.size() * 2); - return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, - /*isSplat=*/(values.size() == 1)); -} - -// Constructs a dense float elements attribute from an array of APFloat -// values. Each APFloat value is expected to have the same bitwidth as the -// element type of 'type'. 
-DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef values) { - assert(type.getElementType().isa()); - assert(hasSameElementsOrSplat(type, values)); - size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); - return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, - /*isSplat=*/(values.size() == 1)); -} -DenseElementsAttr -DenseElementsAttr::get(ShapedType type, - ArrayRef> values) { - ComplexType complex = type.getElementType().cast(); - assert(complex.getElementType().isa()); - assert(hasSameElementsOrSplat(type, values)); - ArrayRef apVals(reinterpret_cast(values.data()), - values.size() * 2); - size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; - return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, - /*isSplat=*/(values.size() == 1)); -} - -/// Construct a dense elements attribute from a raw buffer representing the -/// data for this attribute. Users should generally not use this methods as -/// the expected buffer format may not be a form the user expects. -DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, - ArrayRef rawBuffer, - bool isSplatBuffer) { - return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); -} - -/// Returns true if the given buffer is a valid raw buffer for the given type. -bool DenseElementsAttr::isValidRawBuffer(ShapedType type, - ArrayRef rawBuffer, - bool &detectedSplat) { - size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); - size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; - - // Storage width of 1 is special as it is packed by the bit. - if (storageWidth == 1) { - // Check for a splat, or a buffer equal to the number of elements. - if ((detectedSplat = rawBuffer.size() == 1)) - return true; - return rawBufferWidth == llvm::alignTo<8>(type.getNumElements()); - } - // All other types are 8-bit aligned. 
- if ((detectedSplat = rawBufferWidth == storageWidth)) - return true; - return rawBufferWidth == (storageWidth * type.getNumElements()); -} - -/// Check the information for a C++ data type, check if this type is valid for -/// the current attribute. This method is used to verify specific type -/// invariants that the templatized 'getValues' method cannot. -static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, - bool isSigned) { - // Make sure that the data element size is the same as the type element width. - if (getDenseElementBitWidth(type) != - static_cast(dataEltSize * CHAR_BIT)) - return false; - - // Check that the element type is either float or integer or index. - if (!isInt) - return type.isa(); - if (type.isIndex()) - return true; - - auto intType = type.dyn_cast(); - if (!intType) - return false; - - // Make sure signedness semantics is consistent. - if (intType.isSignless()) - return true; - return intType.isSigned() ? isSigned : !isSigned; -} - -/// Defaults down the subclass implementation. -DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, - ArrayRef data, - int64_t dataEltSize, - bool isInt, bool isSigned) { - return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, - isSigned); -} -DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, - ArrayRef data, - int64_t dataEltSize, - bool isInt, - bool isSigned) { - return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, - isInt, isSigned); -} - -/// A method used to verify specific type invariants that the templatized 'get' -/// method cannot. -bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, - bool isSigned) const { - return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, - isSigned); -} - -/// Check the information for a C++ data type, check if this type is valid for -/// the current attribute. 
-bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, - bool isSigned) const { - return ::isValidIntOrFloat( - getType().getElementType().cast().getElementType(), - dataEltSize / 2, isInt, isSigned); -} - -/// Returns true if this attribute corresponds to a splat, i.e. if all element -/// values are the same. -bool DenseElementsAttr::isSplat() const { - return static_cast(impl)->isSplat; -} - -/// Return the held element values as a range of Attributes. -auto DenseElementsAttr::getAttributeValues() const - -> llvm::iterator_range { - return {attr_value_begin(), attr_value_end()}; -} -auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { - return AttributeElementIterator(*this, 0); -} -auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { - return AttributeElementIterator(*this, getNumElements()); -} - -/// Return the held element values as a range of bool. The element type of -/// this attribute must be of integer type of bitwidth 1. -auto DenseElementsAttr::getBoolValues() const - -> llvm::iterator_range { - auto eltType = getType().getElementType().dyn_cast(); - assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); - (void)eltType; - return {BoolElementIterator(*this, 0), - BoolElementIterator(*this, getNumElements())}; -} - -/// Return the held element values as a range of APInts. The element type of -/// this attribute must be of integer type. 
-auto DenseElementsAttr::getIntValues() const - -> llvm::iterator_range { - assert(getType().getElementType().isIntOrIndex() && "expected integral type"); - return {raw_int_begin(), raw_int_end()}; -} -auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { - assert(getType().getElementType().isIntOrIndex() && "expected integral type"); - return raw_int_begin(); -} -auto DenseElementsAttr::int_value_end() const -> IntElementIterator { - assert(getType().getElementType().isIntOrIndex() && "expected integral type"); - return raw_int_end(); -} -auto DenseElementsAttr::getComplexIntValues() const - -> llvm::iterator_range { - Type eltTy = getType().getElementType().cast().getElementType(); - (void)eltTy; - assert(eltTy.isa() && "expected complex integral type"); - return {ComplexIntElementIterator(*this, 0), - ComplexIntElementIterator(*this, getNumElements())}; -} - -/// Return the held element values as a range of APFloat. The element type of -/// this attribute must be of float type. -auto DenseElementsAttr::getFloatValues() const - -> llvm::iterator_range { - auto elementType = getType().getElementType().cast(); - const auto &elementSemantics = elementType.getFloatSemantics(); - return {FloatElementIterator(elementSemantics, raw_int_begin()), - FloatElementIterator(elementSemantics, raw_int_end())}; -} -auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { - return getFloatValues().begin(); -} -auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { - return getFloatValues().end(); -} -auto DenseElementsAttr::getComplexFloatValues() const - -> llvm::iterator_range { - Type eltTy = getType().getElementType().cast().getElementType(); - assert(eltTy.isa() && "expected complex float type"); - const auto &semantics = eltTy.cast().getFloatSemantics(); - return {{semantics, {*this, 0}}, - {semantics, {*this, static_cast(getNumElements())}}}; -} - -/// Return the raw storage data held by this attribute. 
-ArrayRef DenseElementsAttr::getRawData() const { - return static_cast(impl)->data; -} - -ArrayRef DenseElementsAttr::getRawStringData() const { - return static_cast(impl)->data; -} - -/// Return a new DenseElementsAttr that has the same data as the current -/// attribute, but has been reshaped to 'newType'. The new type must have the -/// same total number of elements as well as element type. -DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { - ShapedType curType = getType(); - if (curType == newType) - return *this; - - (void)curType; - assert(newType.getElementType() == curType.getElementType() && - "expected the same element type"); - assert(newType.getNumElements() == curType.getNumElements() && - "expected the same number of elements"); - return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); -} - -DenseElementsAttr -DenseElementsAttr::mapValues(Type newElementType, - function_ref mapping) const { - return cast().mapValues(newElementType, mapping); -} - -DenseElementsAttr DenseElementsAttr::mapValues( - Type newElementType, function_ref mapping) const { - return cast().mapValues(newElementType, mapping); -} - -//===----------------------------------------------------------------------===// -// DenseStringElementsAttr -//===----------------------------------------------------------------------===// - -DenseStringElementsAttr -DenseStringElementsAttr::get(ShapedType type, ArrayRef values) { - return Base::get(type.getContext(), type, values, (values.size() == 1)); -} - -//===----------------------------------------------------------------------===// -// DenseIntOrFPElementsAttr -//===----------------------------------------------------------------------===// - -/// Utility method to write a range of APInt values to a buffer. 
-template -static void writeAPIntsToBuffer(size_t storageWidth, std::vector &data, - APRangeT &&values) { - data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); - size_t offset = 0; - for (auto it = values.begin(), e = values.end(); it != e; - ++it, offset += storageWidth) { - assert((*it).getBitWidth() <= storageWidth); - writeBits(data.data(), offset, *it); - } -} - -/// Constructs a dense elements attribute from an array of raw APFloat values. -/// Each APFloat value is expected to have the same bitwidth as the element -/// type of 'type'. 'type' must be a vector or tensor with static shape. -DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, - size_t storageWidth, - ArrayRef values, - bool isSplat) { - std::vector data; - auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; - writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); - return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); -} - -/// Constructs a dense elements attribute from an array of raw APInt values. -/// Each APInt value is expected to have the same bitwidth as the element type -/// of 'type'. -DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, - size_t storageWidth, - ArrayRef values, - bool isSplat) { - std::vector data; - writeAPIntsToBuffer(storageWidth, data, values); - return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); -} - -DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, - ArrayRef data, - bool isSplat) { - assert((type.isa()) && - "type must be ranked tensor or vector"); - assert(type.hasStaticShape() && "type must have static shape"); - return Base::get(type.getContext(), type, data, isSplat); -} - -/// Overload of the raw 'get' method that asserts that the given type is of -/// complex type. This method is used to verify type invariants that the -/// templatized 'get' method cannot. 
-DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, - ArrayRef data, - int64_t dataEltSize, - bool isInt, - bool isSigned) { - assert(::isValidIntOrFloat( - type.getElementType().cast().getElementType(), - dataEltSize / 2, isInt, isSigned)); - - int64_t numElements = data.size() / dataEltSize; - assert(numElements == 1 || numElements == type.getNumElements()); - return getRaw(type, data, /*isSplat=*/numElements == 1); -} - -/// Overload of the 'getRaw' method that asserts that the given type is of -/// integer type. This method is used to verify type invariants that the -/// templatized 'get' method cannot. -DenseElementsAttr -DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, - int64_t dataEltSize, bool isInt, - bool isSigned) { - assert( - ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); - - int64_t numElements = data.size() / dataEltSize; - assert(numElements == 1 || numElements == type.getNumElements()); - return getRaw(type, data, /*isSplat=*/numElements == 1); -} - -void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( - const char *inRawData, char *outRawData, size_t elementBitWidth, - size_t numElements) { - using llvm::support::ulittle16_t; - using llvm::support::ulittle32_t; - using llvm::support::ulittle64_t; - - assert(llvm::support::endian::system_endianness() == // NOLINT - llvm::support::endianness::big); // NOLINT - // NOLINT to avoid warning message about replacing by static_assert() - - // Following std::copy_n always converts endianness on BE machine. 
- switch (elementBitWidth) { - case 16: { - const ulittle16_t *inRawDataPos = - reinterpret_cast(inRawData); - uint16_t *outDataPos = reinterpret_cast(outRawData); - std::copy_n(inRawDataPos, numElements, outDataPos); - break; - } - case 32: { - const ulittle32_t *inRawDataPos = - reinterpret_cast(inRawData); - uint32_t *outDataPos = reinterpret_cast(outRawData); - std::copy_n(inRawDataPos, numElements, outDataPos); - break; - } - case 64: { - const ulittle64_t *inRawDataPos = - reinterpret_cast(inRawData); - uint64_t *outDataPos = reinterpret_cast(outRawData); - std::copy_n(inRawDataPos, numElements, outDataPos); - break; - } - default: { - size_t nBytes = elementBitWidth / CHAR_BIT; - for (size_t i = 0; i < nBytes; i++) - std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); - break; - } - } -} - -void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( - ArrayRef inRawData, MutableArrayRef outRawData, - ShapedType type) { - size_t numElements = type.getNumElements(); - Type elementType = type.getElementType(); - if (ComplexType complexTy = elementType.dyn_cast()) { - elementType = complexTy.getElementType(); - numElements = numElements * 2; - } - size_t elementBitWidth = getDenseElementStorageWidth(elementType); - assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT && - inRawData.size() <= outRawData.size()); - convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), - elementBitWidth, numElements); -} - -//===----------------------------------------------------------------------===// -// DenseFPElementsAttr -//===----------------------------------------------------------------------===// - -template -static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, - Type newElementType, - llvm::SmallVectorImpl &data) { - size_t bitWidth = getDenseElementBitWidth(newElementType); - size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); - - ShapedType newArrayType; - if (inType.isa()) - 
newArrayType = RankedTensorType::get(inType.getShape(), newElementType); - else if (inType.isa()) - newArrayType = RankedTensorType::get(inType.getShape(), newElementType); - else if (inType.isa()) - newArrayType = VectorType::get(inType.getShape(), newElementType); - else - assert(newArrayType && "Unhandled tensor type"); - - size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); - data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); - - // Functor used to process a single element value of the attribute. - auto processElt = [&](decltype(*attr.begin()) value, size_t index) { - auto newInt = mapping(value); - assert(newInt.getBitWidth() == bitWidth); - writeBits(data.data(), index * storageBitWidth, newInt); - }; - - // Check for the splat case. - if (attr.isSplat()) { - processElt(*attr.begin(), /*index=*/0); - return newArrayType; - } - - // Otherwise, process all of the element values. - uint64_t elementIdx = 0; - for (auto value : attr) - processElt(value, elementIdx++); - return newArrayType; -} - -DenseElementsAttr DenseFPElementsAttr::mapValues( - Type newElementType, function_ref mapping) const { - llvm::SmallVector elementData; - auto newArrayType = - mappingHelper(mapping, *this, getType(), newElementType, elementData); - - return getRaw(newArrayType, elementData, isSplat()); -} - -/// Method for supporting type inquiry through isa, cast and dyn_cast. 
-bool DenseFPElementsAttr::classof(Attribute attr) { - return attr.isa() && - attr.getType().cast().getElementType().isa(); -} - -//===----------------------------------------------------------------------===// -// DenseIntElementsAttr -//===----------------------------------------------------------------------===// - -DenseElementsAttr DenseIntElementsAttr::mapValues( - Type newElementType, function_ref mapping) const { - llvm::SmallVector elementData; - auto newArrayType = - mappingHelper(mapping, *this, getType(), newElementType, elementData); - - return getRaw(newArrayType, elementData, isSplat()); -} - -/// Method for supporting type inquiry through isa, cast and dyn_cast. -bool DenseIntElementsAttr::classof(Attribute attr) { - return attr.isa() && - attr.getType().cast().getElementType().isIntOrIndex(); -} - -//===----------------------------------------------------------------------===// -// OpaqueElementsAttr -//===----------------------------------------------------------------------===// - -OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, - StringRef bytes) { - assert(TensorType::isValidElementType(type.getElementType()) && - "Input element type should be a valid tensor element type"); - return Base::get(type.getContext(), type, dialect, bytes); -} - -StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } - -/// Return the value at the given index. If index does not refer to a valid -/// element, then a null attribute is returned. 
-Attribute OpaqueElementsAttr::getValue(ArrayRef index) const { - assert(isValidIndex(index) && "expected valid multi-dimensional index"); - return Attribute(); -} - -Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } - -bool OpaqueElementsAttr::decode(ElementsAttr &result) { - auto *d = getDialect(); - if (!d) - return true; - auto *interface = - d->getRegisteredInterface(); - if (!interface) - return true; - return failed(interface->decode(*this, result)); -} - -//===----------------------------------------------------------------------===// -// SparseElementsAttr -//===----------------------------------------------------------------------===// - -SparseElementsAttr SparseElementsAttr::get(ShapedType type, - DenseElementsAttr indices, - DenseElementsAttr values) { - assert(indices.getType().getElementType().isInteger(64) && - "expected sparse indices to be 64-bit integer values"); - assert((type.isa()) && - "type must be ranked tensor or vector"); - assert(type.hasStaticShape() && "type must have static shape"); - return Base::get(type.getContext(), type, - indices.cast(), values); -} - -DenseIntElementsAttr SparseElementsAttr::getIndices() const { - return getImpl()->indices; -} - -DenseElementsAttr SparseElementsAttr::getValues() const { - return getImpl()->values; -} - -/// Return the value of the element at the given index. -Attribute SparseElementsAttr::getValue(ArrayRef index) const { - assert(isValidIndex(index) && "expected valid multi-dimensional index"); - auto type = getType(); - - // The sparse indices are 64-bit integers, so we can reinterpret the raw data - // as a 1-D index array. - auto sparseIndices = getIndices(); - auto sparseIndexValues = sparseIndices.getValues(); - - // Check to see if the indices are a splat. - if (sparseIndices.isSplat()) { - // If the index is also not a splat of the index value, we know that the - // value is zero. 
- auto splatIndex = *sparseIndexValues.begin(); - if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) - return getZeroAttr(); - - // If the indices are a splat, we also expect the values to be a splat. - assert(getValues().isSplat() && "expected splat values"); - return getValues().getSplatValue(); - } - - // Build a mapping between known indices and the offset of the stored element. - llvm::SmallDenseMap, size_t> mappedIndices; - auto numSparseIndices = sparseIndices.getType().getDimSize(0); - size_t rank = type.getRank(); - for (size_t i = 0, e = numSparseIndices; i != e; ++i) - mappedIndices.try_emplace( - {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); - - // Look for the provided index key within the mapped indices. If the provided - // index is not found, then return a zero attribute. - auto it = mappedIndices.find(index); - if (it == mappedIndices.end()) - return getZeroAttr(); - - // Otherwise, return the held sparse value element. - return getValues().getValue(it->second); -} - -/// Get a zero APFloat for the given sparse attribute. -APFloat SparseElementsAttr::getZeroAPFloat() const { - auto eltType = getType().getElementType().cast(); - return APFloat(eltType.getFloatSemantics()); -} - -/// Get a zero APInt for the given sparse attribute. -APInt SparseElementsAttr::getZeroAPInt() const { - auto eltType = getType().getElementType().cast(); - return APInt::getNullValue(eltType.getWidth()); -} - -/// Get a zero attribute for the given attribute type. -Attribute SparseElementsAttr::getZeroAttr() const { - auto eltType = getType().getElementType(); - - // Handle floating point elements. - if (eltType.isa()) - return FloatAttr::get(eltType, 0); - - // Otherwise, this is an integer. - // TODO: Handle StringAttr here. - return IntegerAttr::get(eltType, 0); -} - -/// Flatten, and return, all of the sparse indices in this attribute in -/// row-major order. 
-std::vector SparseElementsAttr::getFlattenedSparseIndices() const { - std::vector flatSparseIndices; - - // The sparse indices are 64-bit integers, so we can reinterpret the raw data - // as a 1-D index array. - auto sparseIndices = getIndices(); - auto sparseIndexValues = sparseIndices.getValues(); - if (sparseIndices.isSplat()) { - SmallVector indices(getType().getRank(), - *sparseIndexValues.begin()); - flatSparseIndices.push_back(getFlattenedIndex(indices)); - return flatSparseIndices; - } - - // Otherwise, reinterpret each index as an ArrayRef when flattening. - auto numSparseIndices = sparseIndices.getType().getDimSize(0); - size_t rank = getType().getRank(); - for (size_t i = 0, e = numSparseIndices; i != e; ++i) - flatSparseIndices.push_back(getFlattenedIndex( - {&*std::next(sparseIndexValues.begin(), i * rank), rank})); - return flatSparseIndices; -} - -//===----------------------------------------------------------------------===// -// MutableDictionaryAttr -//===----------------------------------------------------------------------===// - -MutableDictionaryAttr::MutableDictionaryAttr( - ArrayRef attributes) { - setAttrs(attributes); -} - -/// Return the underlying dictionary attribute. -DictionaryAttr -MutableDictionaryAttr::getDictionary(MLIRContext *context) const { - // Construct empty DictionaryAttr if needed. - if (!attrs) - return DictionaryAttr::get({}, context); - return attrs; -} - -ArrayRef MutableDictionaryAttr::getAttrs() const { - return attrs ? attrs.getValue() : llvm::None; -} - -/// Replace the held attributes with ones provided in 'newAttrs'. -void MutableDictionaryAttr::setAttrs(ArrayRef attributes) { - // Don't create an attribute list if there are no attributes. - if (attributes.empty()) - attrs = nullptr; - else - attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); -} - -/// Return the specified attribute if present, null otherwise. 
-Attribute MutableDictionaryAttr::get(StringRef name) const { - return attrs ? attrs.get(name) : nullptr; -} - -/// Return the specified attribute if present, null otherwise. -Attribute MutableDictionaryAttr::get(Identifier name) const { - return attrs ? attrs.get(name) : nullptr; -} - -/// Return the specified named attribute if present, None otherwise. -Optional MutableDictionaryAttr::getNamed(StringRef name) const { - return attrs ? attrs.getNamed(name) : Optional(); -} -Optional -MutableDictionaryAttr::getNamed(Identifier name) const { - return attrs ? attrs.getNamed(name) : Optional(); -} - -/// If the an attribute exists with the specified name, change it to the new -/// value. Otherwise, add a new attribute with the specified name/value. -void MutableDictionaryAttr::set(Identifier name, Attribute value) { - assert(value && "attributes may never be null"); - - // Look for an existing value for the given name, and set it in-place. - ArrayRef values = getAttrs(); - const auto *it = llvm::find_if( - values, [name](NamedAttribute attr) { return attr.first == name; }); - if (it != values.end()) { - // Bail out early if the value is the same as what we already have. - if (it->second == value) - return; - - SmallVector newAttrs(values.begin(), values.end()); - newAttrs[it - values.begin()].second = value; - attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); - return; - } - - // Otherwise, insert the new attribute into its sorted position. - it = llvm::lower_bound(values, name); - SmallVector newAttrs; - newAttrs.reserve(values.size() + 1); - newAttrs.append(values.begin(), it); - newAttrs.push_back({name, value}); - newAttrs.append(it, values.end()); - attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext()); -} - -/// Remove the attribute with the specified name if it exists. The return -/// value indicates whether the attribute was present or not. 
-auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult { - auto origAttrs = getAttrs(); - for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { - if (origAttrs[i].first == name) { - // Handle the simple case of removing the only attribute in the list. - if (e == 1) { - attrs = nullptr; - return RemoveResult::Removed; - } - - SmallVector newAttrs; - newAttrs.reserve(origAttrs.size() - 1); - newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); - newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); - attrs = DictionaryAttr::getWithSorted(newAttrs, - newAttrs[0].second.getContext()); - return RemoveResult::Removed; - } - } - return RemoveResult::NotFound; -} - bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { return strcmp(lhs.first.data(), rhs.first.data()) < 0; } diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp copy from mlir/lib/IR/Attributes.cpp copy to mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1,4 +1,4 @@ -//===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// +//===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "AttributeDetail.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" @@ -22,36 +22,6 @@ using namespace mlir; using namespace mlir::detail; -//===----------------------------------------------------------------------===// -// AttributeStorage -//===----------------------------------------------------------------------===// - -AttributeStorage::AttributeStorage(Type type) - : type(type.getAsOpaquePointer()) {} -AttributeStorage::AttributeStorage() : type(nullptr) {} - -Type AttributeStorage::getType() const { - return Type::getFromOpaquePointer(type); -} -void AttributeStorage::setType(Type newType) { - type = newType.getAsOpaquePointer(); -} - -//===----------------------------------------------------------------------===// -// Attribute -//===----------------------------------------------------------------------===// - -/// Return the type of this attribute. -Type Attribute::getType() const { return impl->getType(); } - -/// Return the context this attribute belongs to. -MLIRContext *Attribute::getContext() const { return getType().getContext(); } - -/// Get the dialect this attribute is registered to. 
-Dialect &Attribute::getDialect() const { - return impl->getAbstractAttribute().getDialect(); -} - //===----------------------------------------------------------------------===// // AffineMapAttr //===----------------------------------------------------------------------===// @@ -1596,15 +1566,3 @@ } return RemoveResult::NotFound; } - -bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { - return strcmp(lhs.first.data(), rhs.first.data()) < 0; -} -bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) { - // This is correct even when attr.first.data()[name.size()] is not a zero - // string terminator, because we only care about a less than comparison. - // This can't use memcmp, because it doesn't guarantee that it will stop - // reading both buffers if one is shorter than the other, even if there is - // a difference. - return strncmp(lhs.first.data(), rhs.data(), rhs.size()) < 0; -} diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/Dialect.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/BitVector.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/Twine.h" using namespace mlir; diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -13,10 +13,10 @@ #include "mlir-c/IR.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" +#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/Registration.h" -#include "mlir-c/StandardAttributes.h" #include "mlir-c/StandardDialect.h" #include @@ -732,7 +732,7 @@ strncpy(userData, data, len); } -int printStandardAttributes(MlirContext ctx) { +int printBuiltinAttributes(MlirContext ctx) { MlirAttribute floating = mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0); if (!mlirAttributeIsAFloat(floating) || @@ -1323,7 +1323,7 @@ if 
(printBuiltinTypes(ctx)) return 2; - if (printStandardAttributes(ctx)) + if (printBuiltinAttributes(ctx)) return 3; if (printAffineMap(ctx)) return 4; diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Identifier.h" #include "gtest/gtest.h"