diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -285,6 +285,8 @@ ItTy getCurrent() { return this->I; } + const FuncTy &getFunction() const { return F; } + FuncReturnTy operator*() const { return F(*this->I); } private: diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h @@ -0,0 +1,264 @@ +//===- BuiltinAttributeInterfaces.h - Builtin Attr Interfaces ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_H +#define MLIR_IR_BUILTINATTRIBUTEINTERFACES_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/Any.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace mlir { +class ShapedType; + +//===----------------------------------------------------------------------===// +// ElementsAttr +//===----------------------------------------------------------------------===// +namespace detail { +/// This class provides support for indexing into the element range of an +/// ElementsAttr. It is used to opaquely wrap either a contiguous range, via +/// `ElementsAttrIndexer::contiguous`, or a non-contiguous range, via +/// `ElementsAttrIndexer::nonContiguous`. A contiguous range is an array-like +/// range, where all of the elements are laid out sequentially in memory. A +/// non-contiguous range implies no contiguity, and elements may even be +/// materialized when indexing, such as the case for a mapped_range.
+struct ElementsAttrIndexer { +public: + ElementsAttrIndexer() + : ElementsAttrIndexer(/*isContiguous=*/true, /*isSplat=*/true) {} + ElementsAttrIndexer(ElementsAttrIndexer &&rhs) + : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) { + if (isContiguous) + conState = std::move(rhs.conState); + else + new (&nonConState) NonContiguousState(std::move(rhs.nonConState)); + } + ElementsAttrIndexer(const ElementsAttrIndexer &rhs) + : isContiguous(rhs.isContiguous), isSplat(rhs.isSplat) { + if (isContiguous) + conState = rhs.conState; + else + new (&nonConState) NonContiguousState(rhs.nonConState); + } + ~ElementsAttrIndexer() { + if (!isContiguous) + nonConState.~NonContiguousState(); + } + + /// Construct an indexer for a non-contiguous range starting at the given + /// iterator. A non-contiguous range implies no contiguity, and elements may + /// even be materialized when indexing, such as the case for a mapped_range. + template + static ElementsAttrIndexer nonContiguous(bool isSplat, IteratorT &&iterator) { + ElementsAttrIndexer indexer(/*isContiguous=*/false, isSplat); + new (&indexer.nonConState) + NonContiguousState(std::forward(iterator)); + return indexer; + } + + /// Construct an indexer for a contiguous range starting at the given element + /// pointer. A contiguous range is an array-like range, where all of the + /// elements are laid out sequentially in memory. + template + static ElementsAttrIndexer contiguous(bool isSplat, const T *firstEltPtr) { + ElementsAttrIndexer indexer(/*isContiguous=*/true, isSplat); + new (&indexer.conState) ContiguousState(firstEltPtr); + return indexer; + } + + /// Access the element at the given index. + template T at(uint64_t index) const { + if (isSplat) + index = 0; + return isContiguous ?
conState.at(index) : nonConState.at(index); + } + +private: + ElementsAttrIndexer(bool isContiguous, bool isSplat) + : isContiguous(isContiguous), isSplat(isSplat), conState(nullptr) {} + + /// This class contains all of the state necessary to index a contiguous + /// range. + class ContiguousState { + public: + ContiguousState(const void *firstEltPtr) : firstEltPtr(firstEltPtr) {} + + /// Access the element at the given index. + template const T &at(uint64_t index) const { + return *(reinterpret_cast(firstEltPtr) + index); + } + + private: + const void *firstEltPtr; + }; + + /// This class contains all of the state necessary to index a non-contiguous + /// range. + class NonContiguousState { + private: + /// This class is used to represent the abstract base of an opaque iterator. + /// This allows for all iterator and element types to be completely + /// type-erased. + struct OpaqueIteratorBase { + virtual ~OpaqueIteratorBase() {} + virtual std::unique_ptr clone() const = 0; + }; + /// This class is used to represent the abstract base of an opaque iterator + /// that iterates over elements of type `T`. This allows for all iterator + /// types to be completely type-erased. + template + struct OpaqueIteratorValueBase : public OpaqueIteratorBase { + virtual T at(uint64_t index) = 0; + }; + /// This class is used to represent an opaque handle to an iterator of type + /// `IteratorT` that iterates over elements of type `T`. + template + struct OpaqueIterator : public OpaqueIteratorValueBase { + template + static void isMappedIteratorTestFn( + llvm::mapped_iterator) {} + template + using is_mapped_iterator = + decltype(isMappedIteratorTestFn(std::declval())); + template + using detect_is_mapped_iterator = + llvm::is_detected; + + /// Access the element within the iterator at the given index. 
+ template + static std::enable_if_t::value, T> + atImpl(ItT &&it, uint64_t index) { + return *std::next(it, index); + } + template + static std::enable_if_t::value, T> + atImpl(ItT &&it, uint64_t index) { + // Special case mapped_iterator to avoid copying the function. + return it.getFunction()(*std::next(it.getCurrent(), index)); + } + + public: + template + OpaqueIterator(U &&iterator) : iterator(std::forward(iterator)) {} + std::unique_ptr clone() const final { + return std::make_unique>(iterator); + } + + /// Access the element at the given index. + T at(uint64_t index) final { return atImpl(iterator, index); } + + private: + IteratorT iterator; + }; + + public: + /// Construct the state with the given iterator type. + template ())>> + NonContiguousState(IteratorT iterator) + : iterator(std::make_unique>(iterator)) {} + NonContiguousState(const NonContiguousState &other) + : iterator(other.iterator->clone()) {} + NonContiguousState(NonContiguousState &&other) = default; + + /// Access the element at the given index. + template T at(uint64_t index) const { + auto *valueIt = static_cast *>(iterator.get()); + return valueIt->at(index); + } + + /// The opaque iterator state. + std::unique_ptr iterator; + }; + + /// A boolean indicating if this range is contiguous or not. + bool isContiguous; + /// A boolean indicating if this range is a splat. + bool isSplat; + /// The underlying range state. + union { + ContiguousState conState; + NonContiguousState nonConState; + }; +}; + +/// This class implements a generic iterator for ElementsAttr. +template +class ElementsAttrIterator + : public llvm::iterator_facade_base, + std::random_access_iterator_tag, T, + std::ptrdiff_t, T, T> { +public: + ElementsAttrIterator(ElementsAttrIndexer indexer, size_t dataIndex) + : indexer(std::move(indexer)), index(dataIndex) {} + + // Boilerplate iterator methods. 
+ ptrdiff_t operator-(const ElementsAttrIterator &rhs) const { + return index - rhs.index; + } + bool operator==(const ElementsAttrIterator &rhs) const { + return index == rhs.index; + } + bool operator<(const ElementsAttrIterator &rhs) const { + return index < rhs.index; + } + ElementsAttrIterator &operator+=(ptrdiff_t offset) { + index += offset; + return *this; + } + ElementsAttrIterator &operator-=(ptrdiff_t offset) { + index -= offset; + return *this; + } + + /// Return the value at the current iterator position. + T operator*() const { return indexer.at(index); } + +private: + ElementsAttrIndexer indexer; + ptrdiff_t index; +}; +} // namespace detail +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Tablegen Interface Declarations +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributeInterfaces.h.inc" + +//===----------------------------------------------------------------------===// +// ElementsAttr +//===----------------------------------------------------------------------===// + +namespace mlir { +/// Return an iterator to the first element of this attribute as type 'T'. +template +auto ElementsAttr::value_begin() const -> DefaultValueCheckT> { + if (Optional> iterator = try_value_begin()) + return std::move(*iterator); + llvm::errs() + << "ElementsAttr does not provide iteration facilities for type `" + << llvm::getTypeName() << "`, see attribute: " << *this << "\n"; + llvm_unreachable("invalid `T` for ElementsAttr::getValues"); +} +template +auto ElementsAttr::try_value_begin() const + -> DefaultValueCheckT>> { + FailureOr indexer = + getValuesImpl(TypeID::get()); + if (failed(indexer)) + return llvm::None; + return iterator(std::move(*indexer), 0); +} +} // end namespace mlir.
+ +#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_H diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -0,0 +1,430 @@ +//===- BuiltinAttributeInterfaces.td - Attr interfaces -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definition of the ElementsAttr interface. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_ +#define MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_ + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// ElementsAttrInterface +//===----------------------------------------------------------------------===// + +def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { + let cppNamespace = "::mlir"; + let description = [{ + This interface is used for attributes that contain the constant elements of + a tensor or vector type. It allows for opaquely interacting with the + elements of the underlying attribute, and most importantly allows for + accessing the element values (including iteration) in any of the C++ data + types supported by the underlying attribute. + + An attribute implementing this interface can expose the supported data types + in two steps: + + * Define the set of iterable C++ data types: + + An attribute may define the set of iterable types by providing a definition + of tuples `ContiguousIterableTypesT` and/or `NonContiguousIterableTypesT`. + + - `ContiguousIterableTypesT` should contain types which can be iterated + contiguously. 
A contiguous range is an array-like range, such as + ArrayRef, where all of the elements are laid out sequentially in memory. + + - `NonContiguousIterableTypesT` should contain types which cannot be + iterated contiguously. A non-contiguous range implies no contiguity, + whose elements may even be materialized when indexing, such as the case + for a mapped_range. + + As an example, consider an attribute that only contains i64 elements, with + the elements being stored within an ArrayRef. This attribute could + potentially define the iterable types as follows: + + ```c++ + using ContiguousIterableTypesT = std::tuple; + using NonContiguousIterableTypesT = std::tuple; + ``` + + * Provide an `iterator value_begin_impl(OverloadToken) const` overload for + each iterable type + + These overloads should return an iterator to the start of the range for the + respective iterable type. Consider the example i64 elements attribute + described in the previous section. This attribute may define the + value_begin_impl overloads like so: + + ```c++ + /// Provide begin iterators for the various iterable types. + /// * uint64_t + auto value_begin_impl(OverloadToken) const { + return getElements().begin(); + } + /// * APInt + auto value_begin_impl(OverloadToken) const { + return llvm::map_range(getElements(), [=](uint64_t value) { + return llvm::APInt(/*numBits=*/64, value); + }).begin(); + } + /// * Attribute + auto value_begin_impl(OverloadToken) const { + mlir::Type elementType = getType().getElementType(); + return llvm::map_range(getElements(), [=](uint64_t value) { + return mlir::IntegerAttr::get(elementType, + llvm::APInt(/*numBits=*/64, value)); + }).begin(); + } + ``` + + After the above, ElementsAttr will now be able to iterate over elements + using each of the registered iterable data types: + + ```c++ + ElementsAttr attr = myI64ElementsAttr; + + // We can access value ranges for the data types via `getValues`.
+ for (uint64_t value : attr.getValues()) + ...; + for (llvm::APInt value : attr.getValues()) + ...; + for (mlir::IntegerAttr value : attr.getValues()) + ...; + + // We can also access the value iterators directly. + auto it = attr.value_begin(), e = attr.value_end(); + for (; it != e; ++it) { + uint64_t value = *it; + ... + } + ``` + + ElementsAttr also supports failable access to iterators and ranges. This + allows for safely checking if the attribute supports the data type, and can + also allow for code to have fast paths for native data types. + + ```c++ + // Using `tryGetValues`, we can also safely handle when the attribute + // doesn't support the data type. + if (auto range = attr.tryGetValues()) { + for (uint64_t value : *range) + ...; + return; + } + + // We can also access the begin iterator safely, by using `try_value_begin`. + if (auto safeIt = attr.try_value_begin()) { + auto it = *safeIt, e = attr.value_end(); + for (; it != e; ++it) { + uint64_t value = *it; + ... + } + return; + } + ``` + }]; + let methods = [ + InterfaceMethod<[{ + This method returns an opaque range indexer for the given elementID, which + corresponds to a desired C++ element data type. Returns the indexer if the + attribute supports the given data type, failure otherwise. + }], + "::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer>", "getValuesImpl", + (ins "::mlir::TypeID":$elementID), [{}], /*defaultImplementation=*/[{ + auto result = getValueImpl( + (typename ConcreteAttr::ContiguousIterableTypesT *)nullptr, elementID, + /*isContiguous=*/std::true_type()); + if (succeeded(result)) + return std::move(result); + + return getValueImpl( + (typename ConcreteAttr::NonContiguousIterableTypesT *)nullptr, + elementID, /*isContiguous=*/std::false_type()); + }]>, + InterfaceMethod<[{ + Returns true if the attribute elements correspond to a splat, i.e. that + all elements of the attribute are the same value. 
+ }], "bool", "isSplat", (ins), /*defaultImplementation=*/[{}], [{ + // By default, only check for a single element splat. + return $_attr.getNumElements() == 1; + }]> + ]; + + string ElementsAttrInterfaceAccessors = [{ + /// Return the attribute value at the given index. The index is expected to + /// refer to a valid element. + Attribute getValue(ArrayRef index) const { + return getValue(index); + } + + /// Return the value of type 'T' at the given index, where 'T' corresponds + /// to an Attribute type. + template + std::enable_if_t::value && + std::is_base_of::value> + getValue(ArrayRef index) const { + return getValue(index).template dyn_cast_or_null(); + } + + /// Return the value of type 'T' at the given index. + template + T getValue(ArrayRef index) const { + return getFlatValue(getFlattenedIndex(index)); + } + + /// Return the number of elements held by this attribute. + int64_t size() const { return getNumElements(); } + + /// Return if the attribute holds no elements. + bool empty() const { return size() == 0; } + }]; + + let extraTraitClassDeclaration = [{ + // By default, no types are iterable. + using ContiguousIterableTypesT = std::tuple<>; + using NonContiguousIterableTypesT = std::tuple<>; + + //===------------------------------------------------------------------===// + // Accessors + //===------------------------------------------------------------------===// + + /// Return the element type of this ElementsAttr. + Type getElementType() const { + return ::mlir::ElementsAttr::getElementType($_attr); + } + + /// Returns the number of elements held by this attribute. + int64_t getNumElements() const { + return ::mlir::ElementsAttr::getNumElements($_attr); + } + + /// Return if the given 'index' refers to a valid element in this attribute. 
+ bool isValidIndex(ArrayRef index) const { + return ::mlir::ElementsAttr::isValidIndex($_attr, index); + } + + protected: + /// Returns the 1-dimensional flattened row-major index from the given + /// multi-dimensional index. + uint64_t getFlattenedIndex(ArrayRef index) const { + return ::mlir::ElementsAttr::getFlattenedIndex($_attr, index); + } + + //===------------------------------------------------------------------===// + // Value Iteration Internals + //===------------------------------------------------------------------===// + protected: + /// This class is used to allow specifying function overloads for different + /// types, without actually taking the types as parameters. This avoids the + /// need to build complicated SFINAE to select specific overloads. + template + struct OverloadToken {}; + + private: + /// This function unpacks the types within a given tuple and then forwards + /// on to the unwrapped variant. + template + auto getValueImpl(std::tuple *, ::mlir::TypeID elementID, + IsContiguousT isContiguous) const { + return getValueImpl(elementID, isContiguous); + } + /// Check to see if the given `elementID` matches the current type `T`. If + /// it does, build a value result using the current type. If it doesn't, + /// keep looking for the desired type. + template + auto getValueImpl(::mlir::TypeID elementID, + IsContiguousT isContiguous) const { + if (::mlir::TypeID::get() == elementID) + return buildValueResult(isContiguous); + return getValueImpl(elementID, isContiguous); + } + /// Bottom out case for no matching type. + template + ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> + getValueImpl(::mlir::TypeID, IsContiguousT) const { + return failure(); + } + + /// Build an indexer for the given type `T`, which is represented via a + /// contiguous range. 
+ template + ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult( + /*isContiguous*/std::true_type) const { + auto valueIt = $_attr.value_begin_impl(OverloadToken()); + return ::mlir::detail::ElementsAttrIndexer::contiguous( + $_attr.isSplat(), &*valueIt); + } + /// Build an indexer for the given type `T`, which is represented via a + /// non-contiguous range. + template + ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult( + /*isContiguous*/std::false_type) const { + auto valueIt = $_attr.value_begin_impl(OverloadToken()); + return ::mlir::detail::ElementsAttrIndexer::nonContiguous( + $_attr.isSplat(), valueIt); + } + + public: + //===------------------------------------------------------------------===// + // Value Iteration + //===------------------------------------------------------------------===// + + /// Return an iterator to the first element of this attribute as a value of + /// type `T`. + template + auto value_begin() const { + return $_attr.value_begin_impl(OverloadToken()); + } + + /// Return the elements of this attribute as a value of type 'T'. + template + auto getValues() const { + auto beginIt = $_attr.template value_begin(); + return llvm::make_range(beginIt, std::next(beginIt, size())); + } + /// Return the value at the given flattened index. + template T getFlatValue(uint64_t index) const { + return *std::next($_attr.template value_begin(), index); + } + }] # ElementsAttrInterfaceAccessors; + + let extraClassDeclaration = [{ + template + using iterator = detail::ElementsAttrIterator; + template + using iterator_range = llvm::iterator_range>; + + //===------------------------------------------------------------------===// + // Accessors + //===------------------------------------------------------------------===// + + /// Return the type of this attribute. + ShapedType getType() const; + + /// Return the element type of this ElementsAttr. 
+ Type getElementType() const { return getElementType(*this); } + static Type getElementType(Attribute elementsAttr); + + /// Return if the given 'index' refers to a valid element in this attribute. + bool isValidIndex(ArrayRef index) const { + return isValidIndex(*this, index); + } + static bool isValidIndex(ShapedType type, ArrayRef index); + static bool isValidIndex(Attribute elementsAttr, ArrayRef index); + + /// Return the 1 dimensional flattened row-major index from the given + /// multi-dimensional index. + uint64_t getFlattenedIndex(ArrayRef index) const { + return getFlattenedIndex(*this, index); + } + static uint64_t getFlattenedIndex(Attribute elementsAttr, + ArrayRef index); + + /// Returns the number of elements held by this attribute. + int64_t getNumElements() const { return getNumElements(*this); } + static int64_t getNumElements(Attribute elementsAttr); + + //===------------------------------------------------------------------===// + // Value Iteration + //===------------------------------------------------------------------===// + + template + using DerivedAttrValueCheckT = + typename std::enable_if_t::value && + !std::is_same::value>; + template + using DefaultValueCheckT = + typename std::enable_if_t::value || + !std::is_base_of::value, + ResultT>; + + /// Return the element of this attribute at the given index as a value of + /// type 'T'. + template + T getFlatValue(uint64_t index) const { + return *std::next(value_begin(), index); + } + + /// Return the splat value for this attribute. This asserts that the + /// attribute corresponds to a splat. + template + T getSplatValue() const { + assert(isSplat() && "expected splat attribute"); + return *value_begin(); + } + + /// Return the elements of this attribute as a value of type 'T'. 
+ template + DefaultValueCheckT> getValues() const { + return iterator_range(value_begin(), value_end()); + } + template + DefaultValueCheckT> value_begin() const; + template + DefaultValueCheckT> value_end() const { + return iterator({}, size()); + } + + /// Return the held element values as a range of T, where T is a derived + /// attribute type. + template + using DerivedAttrValueIterator = + llvm::mapped_iterator, T (*)(Attribute)>; + template + using DerivedAttrValueIteratorRange = + llvm::iterator_range>; + template > + DerivedAttrValueIteratorRange getValues() const { + auto castFn = [](Attribute attr) { return attr.template cast(); }; + return llvm::map_range(getValues(), + static_cast(castFn)); + } + template > + DerivedAttrValueIterator value_begin() const { + return getValues().begin(); + } + template > + DerivedAttrValueIterator value_end() const { + return {value_end(), nullptr}; + } + + //===------------------------------------------------------------------===// + // Failable Value Iteration + + /// If this attribute supports iterating over element values of type `T`, + /// return the iterable range. Otherwise, return llvm::None. + template + DefaultValueCheckT>> tryGetValues() const { + if (Optional> beginIt = try_value_begin()) + return iterator_range(*beginIt, value_end()); + return llvm::None; + } + template + DefaultValueCheckT>> try_value_begin() const; + + /// If this attribute supports iterating over element values of type `T`, + /// return the iterable range. Otherwise, return llvm::None.
+ template > + Optional> tryGetValues() const { + auto castFn = [](Attribute attr) { return attr.template cast(); }; + if (auto values = tryGetValues()) + return llvm::map_range(*values, static_cast(castFn)); + return llvm::None; + } + template > + Optional> try_value_begin() const { + if (auto values = tryGetValues()) + return values->begin(); + return llvm::None; + } + }] # ElementsAttrInterfaceAccessors; +} + +#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_ diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -9,7 +9,8 @@ #ifndef MLIR_IR_BUILTINATTRIBUTES_H #define MLIR_IR_BUILTINATTRIBUTES_H -#include "SubElementInterfaces.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/SubElementInterfaces.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Sequence.h" #include @@ -31,99 +32,8 @@ //===----------------------------------------------------------------------===// namespace detail { -template -class ElementsAttrIterator; -template -class ElementsAttrRange; -} // namespace detail - -/// A base attribute that represents a reference to a static shaped tensor or -/// vector constant. -class ElementsAttr : public Attribute { -public: - using Attribute::Attribute; - template - using iterator = detail::ElementsAttrIterator; - template - using iterator_range = detail::ElementsAttrRange; - - /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor - /// with static shape. - ShapedType getType() const; - - /// Return the element type of this ElementsAttr. - Type getElementType() const; - - /// Return the value at the given index. The index is expected to refer to a - /// valid element. - Attribute getValue(ArrayRef index) const; - - /// Return the value of type 'T' at the given index, where 'T' corresponds to - /// an Attribute type. 
- template - T getValue(ArrayRef index) const { - return getValue(index).template cast(); - } - - /// Return the elements of this attribute as a value of type 'T'. Note: - /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support - /// iteration. - template iterator_range getValues() const; - template iterator value_begin() const; - template iterator value_end() const; - - /// Return if the given 'index' refers to a valid element in this attribute. - bool isValidIndex(ArrayRef index) const; - static bool isValidIndex(ShapedType type, ArrayRef index); - - /// Returns the 1-dimensional flattened row-major index from the given - /// multi-dimensional index. - uint64_t getFlattenedIndex(ArrayRef index) const; - static uint64_t getFlattenedIndex(ShapedType type, ArrayRef index); - - /// Returns the number of elements held by this attribute. - int64_t getNumElements() const; - - /// Returns the number of elements held by this attribute. - int64_t size() const { return getNumElements(); } - - /// Returns if the number of elements held by this attribute is 0. - bool empty() const { return size() == 0; } - - /// Generates a new ElementsAttr by mapping each int value to a new - /// underlying APInt. The new values can represent either an integer or float. - /// This ElementsAttr should contain integers. - ElementsAttr mapValues(Type newElementType, - function_ref mapping) const; - - /// Generates a new ElementsAttr by mapping each float value to a new - /// underlying APInt. The new values can represent either an integer or float. - /// This ElementsAttr should contain floats. - ElementsAttr mapValues(Type newElementType, - function_ref mapping) const; - - /// Method for support type inquiry through isa, cast and dyn_cast. - static bool classof(Attribute attr); -}; - -namespace detail { -/// DenseElementsAttr data is aligned to uint64_t, so this traits class is -/// necessary to interop with PointerIntPair. 
-class DenseElementDataPointerTypeTraits { -public: - static inline const void *getAsVoidPointer(const char *ptr) { return ptr; } - static inline const char *getFromVoidPointer(const void *ptr) { - return static_cast(ptr); - } - - // Note: We could steal more bits if the need arises. - static constexpr int NumLowBitsAvailable = 1; -}; - /// Pair of raw pointer and a boolean flag of whether the pointer holds a splat, -using DenseIterPtrAndSplat = - llvm::PointerIntPair; +using DenseIterPtrAndSplat = std::pair; /// Impl iterator for indexed DenseElementsAttr iterators that records a data /// pointer and data index that is adjusted for the case of a splat attribute. @@ -142,12 +52,12 @@ /// Return the current index for this iterator, adjusted for the case of a /// splat. ptrdiff_t getDataIndex() const { - bool isSplat = this->base.getInt(); + bool isSplat = this->base.second; return isSplat ? 0 : this->index; } /// Return the data base pointer. - const char *getData() const { return this->base.getPointer(); } + const char *getData() const { return this->base.first; } }; /// Type trait detector that checks if a given type T is a complex type. @@ -159,9 +69,14 @@ /// An attribute that represents a reference to a dense vector or tensor object. /// -class DenseElementsAttr : public ElementsAttr { +class DenseElementsAttr : public Attribute { public: - using ElementsAttr::ElementsAttr; + using Attribute::Attribute; + + /// Allow implicit conversion to ElementsAttr. + operator ElementsAttr() const { + return *this ? cast() : nullptr; + } /// Type trait used to check if the given type T is a potentially valid C++ /// floating point type that can be used to access the underlying element @@ -440,7 +355,7 @@ template T getValue(ArrayRef index) const { // Skip to the element corresponding to the flattened index. 
- return getFlatValue(getFlattenedIndex(index)); + return getFlatValue(ElementsAttr::getFlattenedIndex(*this, index)); } /// Return the value at the given flattened index. template T getFlatValue(uint64_t index) const { @@ -678,6 +593,22 @@ /// Return the raw StringRef data held by this attribute. ArrayRef getRawStringData() const; + /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor + /// with static shape. + ShapedType getType() const; + + /// Return the element type of this DenseElementsAttr. + Type getElementType() const; + + /// Returns the number of elements held by this attribute. + int64_t getNumElements() const; + + /// Returns the number of elements held by this attribute. + int64_t size() const { return getNumElements(); } + + /// Returns if the number of elements held by this attribute is 0. + bool empty() const { return size() == 0; } + //===--------------------------------------------------------------------===// // Mutation Utilities //===--------------------------------------------------------------------===// @@ -761,7 +692,6 @@ return denseAttr && denseAttr.isSplat(); } }; - } // namespace mlir //===----------------------------------------------------------------------===// @@ -954,159 +884,6 @@ auto SparseElementsAttr::value_end() const -> iterator { return getValues().end(); } - -namespace detail { -/// This class represents a general iterator over the values of an ElementsAttr. -/// It supports all subclasses aside from OpaqueElementsAttr. -template -class ElementsAttrIterator - : public llvm::iterator_facade_base, - std::random_access_iterator_tag, T, - std::ptrdiff_t, T, T> { - // NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype' - // inside of a conversion operator. 
- using DenseIteratorT = typename std::enable_if< - true, decltype(std::declval().value_begin())>::type; - using SparseIteratorT = SparseElementsAttr::iterator; - - /// A union containing the specific iterators for each derived attribute kind. - union Iterator { - Iterator(DenseIteratorT &&it) : denseIt(std::move(it)) {} - Iterator(SparseIteratorT &&it) : sparseIt(std::move(it)) {} - Iterator() {} - ~Iterator() {} - - operator const DenseIteratorT &() const { return denseIt; } - operator const SparseIteratorT &() const { return sparseIt; } - operator DenseIteratorT &() { return denseIt; } - operator SparseIteratorT &() { return sparseIt; } - - /// An instance of a dense elements iterator. - DenseIteratorT denseIt; - /// An instance of a sparse elements iterator. - SparseIteratorT sparseIt; - }; - - /// Utility method to process a functor on each of the internal iterator - /// types. - template class ProcessFn, - typename... Args> - RetT process(Args &...args) const { - if (attr.isa()) - return ProcessFn()(args...); - if (attr.isa()) - return ProcessFn()(args...); - llvm_unreachable("unexpected attribute kind"); - } - - /// Utility functors used to generically implement the iterators methods. 
- template - struct PlusAssign { - void operator()(ItT &it, ptrdiff_t offset) { it += offset; } - }; - template - struct Minus { - ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; } - }; - template - struct MinusAssign { - void operator()(ItT &it, ptrdiff_t offset) { it -= offset; } - }; - template - struct Dereference { - T operator()(ItT &it) { return *it; } - }; - template - struct ConstructIter { - void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); } - }; - template - struct DestructIter { - void operator()(ItT &it) { it.~ItT(); } - }; - -public: - ElementsAttrIterator(const ElementsAttrIterator &rhs) : attr(rhs.attr) { - process(it, rhs.it); - } - ~ElementsAttrIterator() { process(it); } - - /// Methods necessary to support random access iteration. - ptrdiff_t operator-(const ElementsAttrIterator &rhs) const { - assert(attr == rhs.attr && "incompatible iterators"); - return process(it, rhs.it); - } - bool operator==(const ElementsAttrIterator &rhs) const { - return rhs.attr == attr && process(it, rhs.it); - } - bool operator<(const ElementsAttrIterator &rhs) const { - assert(attr == rhs.attr && "incompatible iterators"); - return process(it, rhs.it); - } - ElementsAttrIterator &operator+=(ptrdiff_t offset) { - process(it, offset); - return *this; - } - ElementsAttrIterator &operator-=(ptrdiff_t offset) { - process(it, offset); - return *this; - } - - /// Dereference the iterator at the current index. - T operator*() { return process(it); } - -private: - template - ElementsAttrIterator(Attribute attr, IteratorT &&it) - : attr(attr), it(std::forward(it)) {} - - /// Allow accessing the constructor. - friend ElementsAttr; - - /// The parent elements attribute. - Attribute attr; - - /// A union containing the specific iterators for each derived kind. 
- Iterator it; -}; - -template -class ElementsAttrRange : public llvm::iterator_range> { - using llvm::iterator_range>::iterator_range; -}; -} // namespace detail - -/// Return the elements of this attribute as a value of type 'T'. -template -auto ElementsAttr::getValues() const -> iterator_range { - if (DenseElementsAttr denseAttr = dyn_cast()) { - auto values = denseAttr.getValues(); - return {iterator(*this, values.begin()), - iterator(*this, values.end())}; - } - if (SparseElementsAttr sparseAttr = dyn_cast()) { - auto values = sparseAttr.getValues(); - return {iterator(*this, values.begin()), - iterator(*this, values.end())}; - } - llvm_unreachable("unexpected attribute kind"); -} - -template auto ElementsAttr::value_begin() const -> iterator { - if (DenseElementsAttr denseAttr = dyn_cast()) - return iterator(*this, denseAttr.value_begin()); - if (SparseElementsAttr sparseAttr = dyn_cast()) - return iterator(*this, sparseAttr.value_begin()); - llvm_unreachable("unexpected attribute kind"); -} -template auto ElementsAttr::value_end() const -> iterator { - if (DenseElementsAttr denseAttr = dyn_cast()) - return iterator(*this, denseAttr.value_end()); - if (SparseElementsAttr sparseAttr = dyn_cast()) - return iterator(*this, sparseAttr.value_end()); - llvm_unreachable("unexpected attribute kind"); -} - } // end namespace mlir. 
//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -15,6 +15,7 @@ #define BUILTIN_ATTRIBUTES include "mlir/IR/BuiltinDialect.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/SubElementInterfaces.td" // TODO: Currently the attributes defined in this file are prefixed with @@ -136,8 +137,9 @@ // DenseIntOrFPElementsAttr //===----------------------------------------------------------------------===// -def Builtin_DenseIntOrFPElementsAttr - : Builtin_Attr<"DenseIntOrFPElements", /*traits=*/[], "DenseElementsAttr"> { +def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< + "DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr" + > { let summary = "An Attribute containing a dense multi-dimensional array of " "integer or floating-point values"; let description = [{ @@ -165,6 +167,42 @@ let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, "ArrayRef":$rawData); let extraClassDeclaration = [{ + using DenseElementsAttr::empty; + using DenseElementsAttr::getFlatValue; + using DenseElementsAttr::getNumElements; + using DenseElementsAttr::getValue; + using DenseElementsAttr::getValues; + using DenseElementsAttr::isSplat; + using DenseElementsAttr::size; + using DenseElementsAttr::value_begin; + + /// The set of data types that can be iterated by this attribute. + using ContiguousIterableTypesT = std::tuple< + // Integer types. + uint8_t, uint16_t, uint32_t, uint64_t, + int8_t, int16_t, int32_t, int64_t, + short, unsigned short, int, unsigned, long, unsigned long, + std::complex, std::complex, std::complex, + std::complex, + std::complex, std::complex, std::complex, + std::complex, + // Float types. 
+ float, double, std::complex, std::complex + >; + using NonContiguousIterableTypesT = std::tuple< + Attribute, + // Integer types. + APInt, bool, std::complex, + // Float types. + APFloat, std::complex + >; + + /// Provide a `value_begin_impl` to enable iteration within ElementsAttr. + template + auto value_begin_impl(OverloadToken) const { + return value_begin(); + } + /// Convert endianess of input ArrayRef for big-endian(BE) machines. All of /// the elements of `inRawData` has `type`. If `inRawData` is little endian /// (LE), it is converted to big endian (BE). Conversely, if `inRawData` is @@ -231,8 +269,9 @@ // DenseStringElementsAttr //===----------------------------------------------------------------------===// -def Builtin_DenseStringElementsAttr - : Builtin_Attr<"DenseStringElements", /*traits=*/[], "DenseElementsAttr"> { +def Builtin_DenseStringElementsAttr : Builtin_Attr< + "DenseStringElements", [ElementsAttrInterface], "DenseElementsAttr" + > { let summary = "An Attribute containing a dense multi-dimensional array of " "strings"; let description = [{ @@ -267,6 +306,25 @@ }]>, ]; let extraClassDeclaration = [{ + using DenseElementsAttr::empty; + using DenseElementsAttr::getFlatValue; + using DenseElementsAttr::getNumElements; + using DenseElementsAttr::getValue; + using DenseElementsAttr::getValues; + using DenseElementsAttr::isSplat; + using DenseElementsAttr::size; + using DenseElementsAttr::value_begin; + + /// The set of data types that can be iterated by this attribute. + using ContiguousIterableTypesT = std::tuple; + using NonContiguousIterableTypesT = std::tuple; + + /// Provide a `value_begin_impl` to enable iteration within ElementsAttr. 
+ template + auto value_begin_impl(OverloadToken) const { + return value_begin(); + } + protected: friend DenseElementsAttr; @@ -594,8 +652,9 @@ // OpaqueElementsAttr //===----------------------------------------------------------------------===// -def Builtin_OpaqueElementsAttr - : Builtin_Attr<"OpaqueElements", /*traits=*/[], "ElementsAttr"> { +def Builtin_OpaqueElementsAttr : Builtin_Attr< + "OpaqueElements", [ElementsAttrInterface] + > { let summary = "An opaque representation of a multi-dimensional array"; let description = [{ Syntax: @@ -650,7 +709,6 @@ /// Returns false if decoding is successful. If not, returns true and leaves /// 'result' argument unspecified. bool decode(ElementsAttr &result); - }]; let genVerifyDecl = 1; let skipDefaultBuilders = 1; @@ -660,8 +718,9 @@ // SparseElementsAttr //===----------------------------------------------------------------------===// -def Builtin_SparseElementsAttr - : Builtin_Attr<"SparseElements", /*traits=*/[], "ElementsAttr"> { +def Builtin_SparseElementsAttr : Builtin_Attr< + "SparseElements", [ElementsAttrInterface] + > { let summary = "An opaque representation of a multi-dimensional array"; let description = [{ Syntax: @@ -712,6 +771,33 @@ }]>, ]; let extraClassDeclaration = [{ + /// The set of data types that can be iterated by this attribute. + // FIXME: Realistically, SparseElementsAttr could use ElementsAttr for the + // value storage. This would mean dispatching to `values` when accessing + // values. For now, we just add the types that can be iterated by + // DenseElementsAttr. + using NonContiguousIterableTypesT = std::tuple< + Attribute, + // Integer types. + APInt, bool, uint8_t, uint16_t, uint32_t, uint64_t, + int8_t, int16_t, int32_t, int64_t, + short, unsigned short, int, unsigned, long, unsigned long, + std::complex, std::complex, std::complex, + std::complex, std::complex, std::complex, + std::complex, std::complex, std::complex, + // Float types. 
+ APFloat, float, double, + std::complex, std::complex, std::complex, + // String types. + StringRef + >; + + /// Provide a `value_begin_impl` to enable iteration within ElementsAttr. + template + auto value_begin_impl(OverloadToken) const { + return value_begin(); + } + template using iterator = llvm::mapped_iterator(0, 0))::iterator, diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -7,6 +7,11 @@ mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs) add_public_tablegen_target(MLIRBuiltinAttributesIncGen) +set(LLVM_TARGET_DEFINITIONS BuiltinAttributeInterfaces.td) +mlir_tablegen(BuiltinAttributeInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(BuiltinAttributeInterfaces.cpp.inc -gen-attr-interface-defs) +add_public_tablegen_target(MLIRBuiltinAttributeInterfacesIncGen) + set(LLVM_TARGET_DEFINITIONS BuiltinDialect.td) mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls) mlir_tablegen(BuiltinDialect.cpp.inc -gen-dialect-defs) diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -93,6 +93,7 @@ : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { assert((!t || impl) && "expected value to provide interface instance"); } + Interface(std::nullptr_t) : BaseType(ValueT()), impl(nullptr) {} /// Construct an interface instance from a type that implements this /// interface's trait. diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -0,0 +1,74 @@ +//===- BuiltinAttributeInterfaces.cpp -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/Sequence.h" + +using namespace mlir; +using namespace mlir::detail; + +//===----------------------------------------------------------------------===// +/// Tablegen Interface Definitions +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc" + +//===----------------------------------------------------------------------===// +// ElementsAttr +//===----------------------------------------------------------------------===// + +ShapedType ElementsAttr::getType() const { + return Attribute::getType().cast(); +} + +Type ElementsAttr::getElementType(Attribute elementsAttr) { + return elementsAttr.getType().cast().getElementType(); +} + +int64_t ElementsAttr::getNumElements(Attribute elementsAttr) { + return elementsAttr.getType().cast().getNumElements(); +} + +bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef index) { + // Verify that the rank of the indices matches the held type. + int64_t rank = type.getRank(); + if (rank == 0 && index.size() == 1 && index[0] == 0) + return true; + if (rank != static_cast(index.size())) + return false; + + // Verify that all of the indices are within the shape dimensions. 
+ ArrayRef shape = type.getShape(); + return llvm::all_of(llvm::seq(0, rank), [&](int i) { + int64_t dim = static_cast(index[i]); + return 0 <= dim && dim < shape[i]; + }); +} +bool ElementsAttr::isValidIndex(Attribute elementsAttr, + ArrayRef index) { + return isValidIndex(elementsAttr.getType().cast(), index); +} + +uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr, + ArrayRef index) { + ShapedType type = elementsAttr.getType().cast(); + assert(isValidIndex(type, index) && "expected valid multi-dimensional index"); + + // Reduce the provided multidimensional index into a flattended 1D row-major + // index. + auto rank = type.getRank(); + auto shape = type.getShape(); + uint64_t valueIndex = 0; + uint64_t dimMultiplier = 1; + for (int i = rank - 1; i >= 0; --i) { + valueIndex += index[i] * dimMultiplier; + dimMultiplier *= shape[i]; + } + return valueIndex; +} diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -382,92 +382,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// ElementsAttr -//===----------------------------------------------------------------------===// - -ShapedType ElementsAttr::getType() const { - return Attribute::getType().cast(); -} - -Type ElementsAttr::getElementType() const { return getType().getElementType(); } - -/// Returns the number of elements held by this attribute. -int64_t ElementsAttr::getNumElements() const { - return getType().getNumElements(); -} - -/// Return the value at the given index. If index does not refer to a valid -/// element, then a null attribute is returned. 
-Attribute ElementsAttr::getValue(ArrayRef index) const { - if (auto denseAttr = dyn_cast()) - return denseAttr.getValue(index); - if (auto opaqueAttr = dyn_cast()) - return opaqueAttr.getValue(index); - return cast().getValue(index); -} - -bool ElementsAttr::isValidIndex(ArrayRef index) const { - return isValidIndex(getType(), index); -} -bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef index) { - // Verify that the rank of the indices matches the held type. - int64_t rank = type.getRank(); - if (rank == 0 && index.size() == 1 && index[0] == 0) - return true; - if (rank != static_cast(index.size())) - return false; - - // Verify that all of the indices are within the shape dimensions. - ArrayRef shape = type.getShape(); - return llvm::all_of(llvm::seq(0, rank), [&](int i) { - int64_t dim = static_cast(index[i]); - return 0 <= dim && dim < shape[i]; - }); -} - -uint64_t ElementsAttr::getFlattenedIndex(ArrayRef index) const { - return getFlattenedIndex(getType(), index); -} -uint64_t ElementsAttr::getFlattenedIndex(ShapedType type, - ArrayRef index) { - assert(isValidIndex(type, index) && "expected valid multi-dimensional index"); - - // Reduce the provided multidimensional index into a flattended 1D row-major - // index. 
- auto rank = type.getRank(); - auto shape = type.getShape(); - uint64_t valueIndex = 0; - uint64_t dimMultiplier = 1; - for (int i = rank - 1; i >= 0; --i) { - valueIndex += index[i] * dimMultiplier; - dimMultiplier *= shape[i]; - } - return valueIndex; -} - -ElementsAttr -ElementsAttr::mapValues(Type newElementType, - function_ref mapping) const { - if (auto intOrFpAttr = dyn_cast()) - return intOrFpAttr.mapValues(newElementType, mapping); - llvm_unreachable("unsupported ElementsAttr subtype"); -} - -ElementsAttr -ElementsAttr::mapValues(Type newElementType, - function_ref mapping) const { - if (auto intOrFpAttr = dyn_cast()) - return intOrFpAttr.mapValues(newElementType, mapping); - llvm_unreachable("unsupported ElementsAttr subtype"); -} - -/// Method for support type inquiry through isa, cast and dyn_cast. -bool ElementsAttr::classof(Attribute attr) { - return attr.isa(); -} - //===----------------------------------------------------------------------===// // DenseElementsAttr Utilities //===----------------------------------------------------------------------===// @@ -1065,6 +979,18 @@ return cast().mapValues(newElementType, mapping); } +ShapedType DenseElementsAttr::getType() const { + return Attribute::getType().cast(); +} + +Type DenseElementsAttr::getElementType() const { + return getType().getElementType(); +} + +int64_t DenseElementsAttr::getNumElements() const { + return getType().getNumElements(); +} + //===----------------------------------------------------------------------===// // DenseIntOrFPElementsAttr //===----------------------------------------------------------------------===// @@ -1431,7 +1357,7 @@ // Verify indices shape. 
size_t rank = type.getRank(), indicesRank = indicesType.getRank(); if (indicesRank == 2) { - if (indicesType.getDimSize(1) != rank) + if (indicesType.getDimSize(1) != static_cast(rank)) return emitShapeError(); } else if (indicesRank != 1 || rank != 1) { return emitShapeError(); diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -5,6 +5,7 @@ Attributes.cpp Block.cpp Builders.cpp + BuiltinAttributeInterfaces.cpp BuiltinAttributes.cpp BuiltinDialect.cpp BuiltinTypes.cpp @@ -36,6 +37,7 @@ DEPENDS MLIRBuiltinAttributesIncGen + MLIRBuiltinAttributeInterfacesIncGen MLIRBuiltinDialectIncGen MLIRBuiltinLocationAttributesIncGen MLIRBuiltinOpsIncGen diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt %s -test-elements-attr-interface -verify-diagnostics + +// This test contains various `ElementsAttr` attributes, and tests the support +// for iterating the values of these attributes using various native C++ types. +// This tests that the abstract iteration of ElementsAttr works properly, and +// is properly failable when necessary. 
+ +// expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}} +// expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}} +// expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}} +std.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64> + +// expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}} +// expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}} +// expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}} +std.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64> + +// expected-error@below {{Test iterating `uint64_t`: unable to iterate type}} +// expected-error@below {{Test iterating `APInt`: unable to iterate type}} +// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}} +std.constant opaque<"_", "0xDEADBEEF"> : tensor<5xi64> diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -15,6 +15,7 @@ // To get the test dialect definition. include "TestOps.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" // All of the attributes will extend this class. class Test_Attr traits = []> @@ -63,4 +64,41 @@ let parameters = (ins ); } +// Test support for ElementsAttrInterface. +def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ + ElementsAttrInterface + ]> { + let mnemonic = "i64_elements"; + let parameters = (ins + AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type, + ArrayRefParameter<"uint64_t">:$elements + ); + let extraClassDeclaration = [{ + /// The set of data types that can be iterated by this attribute. + using ContiguousIterableTypesT = std::tuple; + using NonContiguousIterableTypesT = std::tuple; + + /// Provide begin iterators for the various iterable types. 
+ // * uint64_t + auto value_begin_impl(OverloadToken) const { + return getElements().begin(); + } + // * Attribute + auto value_begin_impl(OverloadToken) const { + mlir::Type elementType = getType().getElementType(); + return llvm::map_range(getElements(), [=](uint64_t value) { + return mlir::IntegerAttr::get(elementType, + llvm::APInt(/*numBits=*/64, value)); + }).begin(); + } + // * APInt + auto value_begin_impl(OverloadToken) const { + return llvm::map_range(getElements(), [=](uint64_t value) { + return llvm::APInt(/*numBits=*/64, value); + }).begin(); + } + }]; + let genVerifyDecl = 1; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -89,6 +89,48 @@ printer << "]>"; } +//===----------------------------------------------------------------------===// +// CompoundAAttr +//===----------------------------------------------------------------------===// + +Attribute TestI64ElementsAttr::parse(MLIRContext *context, + DialectAsmParser &parser, Type type) { + SmallVector elements; + if (parser.parseLess() || parser.parseLSquare()) + return Attribute(); + uint64_t intVal; + while (succeeded(*parser.parseOptionalInteger(intVal))) { + elements.push_back(intVal); + if (parser.parseOptionalComma()) + break; + } + + if (parser.parseRSquare() || parser.parseGreater()) + return Attribute(); + return parser.getChecked( + context, type.cast(), elements); +} + +void TestI64ElementsAttr::print(DialectAsmPrinter &printer) const { + printer << "i64_elements<["; + llvm::interleaveComma(getElements(), printer); + printer << "] : " << getType() << ">"; +} + +LogicalResult +TestI64ElementsAttr::verify(function_ref emitError, + ShapedType type, ArrayRef elements) { + if (type.getNumElements() != static_cast(elements.size())) { + return emitError() + << "number of elements does not match the 
provided shape type, got: " + << elements.size() << ", but expected: " << type.getNumElements(); + } + if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) + return emitError() << "expected single rank 64-bit shape type, but got: " + << type; + return success(); +} + //===----------------------------------------------------------------------===// // Tablegen Generated Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestIR + TestBuiltinAttributeInterfaces.cpp TestDiagnostics.cpp TestDominance.cpp TestFunc.cpp diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp @@ -0,0 +1,61 @@ +//===- TestBuiltinAttributeInterfaces.cpp ---------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestAttributes.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; +using namespace test; + +namespace { +struct TestElementsAttrInterface + : public PassWrapper> { + StringRef getArgument() const final { return "test-elements-attr-interface"; } + StringRef getDescription() const final { + return "Test ElementsAttr interface support."; + } + void runOnOperation() override { + getOperation().walk([&](Operation *op) { + for (NamedAttribute attr : op->getAttrs()) { + auto elementsAttr = attr.second.dyn_cast(); + if (!elementsAttr) + continue; + testElementsAttrIteration(op, elementsAttr, "uint64_t"); + testElementsAttrIteration(op, elementsAttr, "APInt"); + testElementsAttrIteration(op, elementsAttr, "IntegerAttr"); + } + }); + } + + template + void testElementsAttrIteration(Operation *op, ElementsAttr attr, + StringRef type) { + InFlightDiagnostic diag = op->emitError() + << "Test iterating `" << type << "`: "; + + auto values = attr.tryGetValues(); + if (!values) { + diag << "unable to iterate type"; + return; + } + + llvm::interleaveComma(*values, diag, [&](T value) { + diag << llvm::formatv("{0}", value).str(); + }); + } +}; +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestBuiltinAttributeInterfaces() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -62,6 +62,7 @@ void registerSimpleParametricTilingPass(); void registerTestAffineLoopParametricTilingPass(); void registerTestAliasAnalysisPass(); +void registerTestBuiltinAttributeInterfaces(); void registerTestCallGraphPass(); void registerTestConstantFold(); void registerTestConvVectorization(); @@ 
-146,6 +147,7 @@ mlir::test::registerSimpleParametricTilingPass(); mlir::test::registerTestAffineLoopParametricTilingPass(); mlir::test::registerTestAliasAnalysisPass(); + mlir::test::registerTestBuiltinAttributeInterfaces(); mlir::test::registerTestCallGraphPass(); mlir::test::registerTestConstantFold(); mlir::test::registerTestDiagnosticsPass(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -106,6 +106,8 @@ td_library( name = "BuiltinDialectTdFiles", srcs = [ + "include/mlir/IR/BuiltinAttributeInterfaces.td", + "include/mlir/IR/BuiltinAttributes.td", "include/mlir/IR/BuiltinDialect.td", "include/mlir/IR/BuiltinLocationAttributes.td", "include/mlir/IR/BuiltinOps.td", @@ -159,6 +161,24 @@ deps = [":BuiltinDialectTdFiles"], ) +gentbl_cc_library( + name = "BuiltinAttributeInterfacesIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["--gen-attr-interface-decls"], + "include/mlir/IR/BuiltinAttributeInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/mlir/IR/BuiltinAttributeInterfaces.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/IR/BuiltinAttributeInterfaces.td", + deps = [":BuiltinDialectTdFiles"], +) + gentbl_cc_library( name = "BuiltinLocationAttributesIncGen", strip_include_prefix = "include", @@ -249,6 +269,7 @@ ], includes = ["include"], deps = [ + ":BuiltinAttributeInterfacesIncGen", ":BuiltinAttributesIncGen", ":BuiltinDialectIncGen", ":BuiltinLocationAttributesIncGen", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -35,6 +35,7 @@ name = "TestOpTdFiles", srcs = glob(["lib/Dialect/Test/*.td"]), deps = [ + 
"//mlir:BuiltinDialectTdFiles", "//mlir:CallInterfacesTdFiles", "//mlir:ControlFlowInterfacesTdFiles", "//mlir:CopyOpInterfaceTdFiles",