diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -26,6 +26,8 @@ class Location; class Operation; class ShapedType; +class AsmParser; +class AsmPrinter; //===----------------------------------------------------------------------===// // Elements Attributes @@ -64,9 +66,89 @@ template struct is_complex_t : public std::false_type {}; template struct is_complex_t> : public std::true_type {}; + +class DenseArrayAttributeStorage; +/// All possible element types for DenseArrayAttr. +enum class DenseArrayAttributeElementType { I8, I16, I32, I64, F32, F64 }; } // namespace detail -/// An attribute that represents a reference to a dense vector or tensor object. +/// Base class for DenseArrayAttr. This attribute wraps a simple ArrayRef of raw +/// element data; the typed DenseArrayAttr subclasses are used to access it. +class DenseArrayAttrBase + : public Attribute::AttrBase { +public: + using Base::Base; + + /// Allow implicit conversion to ElementsAttr. + operator ElementsAttr() const { + return *this ? cast() : nullptr; + } + + // ElementsAttr implementation. + using ContiguousIterableTypesT = + std::tuple; + const int8_t *value_begin_impl(OverloadToken) const; + const int16_t *value_begin_impl(OverloadToken) const; + const int32_t *value_begin_impl(OverloadToken) const; + const int64_t *value_begin_impl(OverloadToken) const; + const float *value_begin_impl(OverloadToken) const; + const double *value_begin_impl(OverloadToken) const; + + /// Return the element type of this array. The typed subclasses compare it in + /// their classof to support isa, cast, and dyn_cast. + detail::DenseArrayAttributeElementType getElementType() const; + /// Printer for the short form: will dispatch to the appropriate subclass. 
+ void print(AsmPrinter &printer) const; + void print(raw_ostream &os) const; +}; + +namespace detail { +/// Base class for DenseArrayAttr that is instantiated and specialized for +/// each supported element type below. +template +class DenseArrayAttr : public DenseArrayAttrBase { +public: + using DenseArrayAttrBase::DenseArrayAttrBase; + + /// Implicit conversion to ArrayRef. + operator ArrayRef() const; + /// Same view as the implicit conversion, spelled explicitly; usable on + /// const attributes. + ArrayRef asArrayRef() const { return ArrayRef{*this}; } + + /// Builder from ArrayRef. + static DenseArrayAttr get(MLIRContext *context, ArrayRef content); + + /// Print the short form `[42, 100, -1]` without any prefix. + void print(AsmPrinter &printer) const; + void print(raw_ostream &os) const; + + /// Parse the short form `[42, 100, -1]` without any prefix. + static Attribute parse(AsmParser &parser, Type odsType); + + /// Support for isa<>/cast<>. + static bool classof(Attribute attr); +}; +/// int8_t elements are printed as numbers rather than as characters. +template <> +void DenseArrayAttr::print(raw_ostream &os) const; + +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +} // namespace detail + +// Public names for all the supported DenseArrayAttr instantiations. +using DenseI8ArrayAttr = detail::DenseArrayAttr; +using DenseI16ArrayAttr = detail::DenseArrayAttr; +using DenseI32ArrayAttr = detail::DenseArrayAttr; +using DenseI64ArrayAttr = detail::DenseArrayAttr; +using DenseF32ArrayAttr = detail::DenseArrayAttr; +using DenseF64ArrayAttr = detail::DenseArrayAttr; + +/// An attribute that represents a reference to a dense vector or tensor +/// object. 
/// class DenseElementsAttr : public Attribute { public: diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1517,6 +1517,60 @@ let convertFromStorage = "$_self"; } +// Constraints for the builtin DenseArrayAttr flavors. Each one is stored as the +// matching C++ DenseArrayAttr class and returned as an `::llvm::ArrayRef`. +def DenseI8ArrayAttr : + ElementsAttrBase()">, + "i8 dense array attribute"> { + let storageType = [{ ::mlir::DenseI8ArrayAttr }]; + let returnType = [{ ::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + +def DenseI16ArrayAttr : + ElementsAttrBase()">, + "i16 dense array attribute"> { + let storageType = [{ ::mlir::DenseI16ArrayAttr }]; + let returnType = [{ ::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + +def DenseI32ArrayAttr : + ElementsAttrBase()">, + "i32 dense array attribute"> { + let storageType = [{ ::mlir::DenseI32ArrayAttr }]; + let returnType = [{ ::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + +def DenseI64ArrayAttr : + ElementsAttrBase()">, + "i64 dense array attribute"> { + let storageType = [{ ::mlir::DenseI64ArrayAttr }]; + let returnType = [{ ::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + +def DenseF32ArrayAttr : + ElementsAttrBase()">, + "float dense array attribute"> { + let storageType = [{ ::mlir::DenseF32ArrayAttr }]; + let returnType = [{ ::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + +def DenseF64ArrayAttr : + ElementsAttrBase()">, + "double dense array attribute"> { + let storageType = [{ ::mlir::DenseF64ArrayAttr }]; + let returnType = [{ ::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + def IndexElementsAttr : IntElementsAttrBase() .getType() diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ 
b/mlir/lib/IR/AsmPrinter.cpp @@ -1822,9 +1822,33 @@ } os << '>'; } - + } else if (auto denseArrayAttr = attr.dyn_cast()) { + // Generic form: `arrayref <element-type> [e0, e1, ...]`. The parser + // (Parser::parseDenseArrayAttr) consumes no trailing `: type`, so the + // attribute type must never be printed or the output would not parse back. + typeElision = AttrTypeElision::Must; + os << "arrayref "; + switch (denseArrayAttr.getElementType()) { + case 
detail::DenseArrayAttributeElementType::I8: + os << "i8 "; + break; + case detail::DenseArrayAttributeElementType::I16: + os << "i16 "; + break; + case detail::DenseArrayAttributeElementType::I32: + os << "i32 "; + break; + case detail::DenseArrayAttributeElementType::I64: + os << "i64 "; + break; + case detail::DenseArrayAttributeElementType::F32: + os << "f32 "; + break; + case detail::DenseArrayAttributeElementType::F64: + os << "f64 "; + break; + } + denseArrayAttr.print(os); } else if (auto locAttr = attr.dyn_cast()) { printLocation(locAttr); + } else { + llvm::report_fatal_error("Unknown builtin attribute"); } // Don't print the type if we must elide it, or if it is a None type. if (typeElision != AttrTypeElision::Must && !attrType.isa()) { diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -40,6 +40,39 @@ return eltType.getIntOrFloatBitWidth(); } +/// Storage for DenseArrayAttribute.
+struct DenseArrayAttributeStorage : public AttributeStorage { + using KeyTy = std::tuple>; + DenseArrayAttributeStorage(ShapedType type, + DenseArrayAttributeElementType eltType, + ArrayRef elements) + : AttributeStorage(type), eltType(eltType), elements(elements) {} + + bool operator==(const KeyTy &tblgenKey) const { + return (getType() == std::get<0>(tblgenKey)) && + (eltType == std::get<1>(tblgenKey)) && + (elements == std::get<2>(tblgenKey)); + } + + static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { + // Hash the element bytes too: hashing only the shaped type and the element + // type would put every array of the same length and element type into the + // same bucket. + return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey), + std::get<2>(tblgenKey)); + } + + static DenseArrayAttributeStorage * + construct(::mlir::AttributeStorageAllocator &allocator, + const KeyTy &tblgenKey) { + auto eltType = std::get<1>(tblgenKey); + auto elements = std::get<2>(tblgenKey); + auto type = std::get<0>(tblgenKey); + // Copy the raw element bytes into memory owned by the storage allocator. + elements = allocator.copyInto(elements); + return new (allocator.allocate()) + DenseArrayAttributeStorage(type, eltType, elements); + } + /// Element type tag of the array. + DenseArrayAttributeElementType eltType; + /// Raw element bytes, copied into the allocator by construct(). + ::llvm::ArrayRef elements; +}; + /// An attribute representing a reference to a dense vector or tensor object.
struct DenseElementsAttributeStorage : public AttributeStorage { public: diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Types.h" @@ -35,11 +36,11 @@ //===----------------------------------------------------------------------===// void BuiltinDialect::registerAttributes() { - addAttributes(); + addAttributes(); } //===----------------------------------------------------------------------===// @@ -664,6 +665,221 @@ readBits(getData(), offset + storageWidth, bitWidth)}; } +//===----------------------------------------------------------------------===// +// DenseArrayAttr +//===----------------------------------------------------------------------===// + +detail::DenseArrayAttributeElementType +DenseArrayAttrBase::getElementType() const { + return getImpl()->eltType; +} + +// ElementsAttr iteration: each overload casts to the subclass for that element +// type, so iteration is only valid with the array's stored element type. +const int8_t * +DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const int16_t * +DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const int32_t * +DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const int64_t * +DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const float *DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const double * +DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} + +void 
DenseArrayAttrBase::print(AsmPrinter &printer) const { + print(printer.getStream()); +} + +/// Print the short form by dispatching on the stored element type to the +/// matching DenseArrayAttr subclass. +void DenseArrayAttrBase::print(raw_ostream &os) const { + 
switch (getElementType()) { + case detail::DenseArrayAttributeElementType::I8: + this->cast().print(os); + return; + case detail::DenseArrayAttributeElementType::I16: + this->cast().print(os); + return; + case detail::DenseArrayAttributeElementType::I32: + this->cast().print(os); + return; + case detail::DenseArrayAttributeElementType::I64: + this->cast().print(os); + return; + case detail::DenseArrayAttributeElementType::F32: + this->cast().print(os); + return; + case detail::DenseArrayAttributeElementType::F64: + this->cast().print(os); + return; + } + // Every element type is handled above. There is deliberately no `default:` + // so that -Wswitch flags any element type added later without a printer. + llvm_unreachable("unknown DenseArrayAttr element type"); +} + +template +void DenseArrayAttr::print(AsmPrinter &printer) const { + print(printer.getStream()); +} + +template +void DenseArrayAttr::print(raw_ostream &os) const { + ArrayRef values{*this}; + os << "["; + llvm::interleaveComma(values, os); + os << "]"; +} + +/// Specialization for int8_t for forcing printing as number instead of chars. +template <> +void DenseArrayAttr::print(raw_ostream &os) const { + ArrayRef values{*this}; + os << "["; + llvm::interleaveComma(values, os, [&](int64_t v) { os << v; }); + os << "]"; +} + +/// Parse a single element: generic template for int types, specialized for +/// floating points below.
+template +static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) { + return parser.parseInteger(value); +} + +template <> +ParseResult parseDenseArrayAttrElt(AsmParser &parser, float &value) { + double doubleVal; + if (parser.parseFloat(doubleVal)) + return failure(); + value = doubleVal; + return success(); +} + +template <> +ParseResult parseDenseArrayAttrElt(AsmParser &parser, double &value) { + return parser.parseFloat(value); +} + +/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`, or `[]` for an empty array (which is +/// what the printer emits for one). +template +Attribute DenseArrayAttr::parse(AsmParser &parser, Type odsType) { + if (parser.parseLSquare()) + return {}; + SmallVector data; + // An immediate `]` means an empty array; otherwise parse at least one element. + if (failed(parser.parseOptionalRSquare())) { + do { + T value; + if (parseDenseArrayAttrElt(parser, value)) + return {}; + data.push_back(value); + } while (succeeded(parser.parseOptionalComma())); + if (parser.parseRSquare()) + return {}; + } + return get(parser.getContext(), data); +} + +/// Conversion from DenseArrayAttr to ArrayRef. +template +DenseArrayAttr::operator ArrayRef() const { + ArrayRef raw = getImpl()->elements; + assert((raw.size() % sizeof(T)) == 0); + return ArrayRef(reinterpret_cast(raw.data()), + raw.size() / sizeof(T)); +} + +namespace { +/// Mapping from C++ element type to MLIR DenseArrayAttr internals. The shaped +/// type is a 1-D ranked tensor rather than a vector: vector dimensions must be +/// strictly positive, so a vector type could not describe an empty array.
+template +struct denseArrayAttrEltTypeBuilder; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::I8; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, IntegerType::get(context, 8)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::I16; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, IntegerType::get(context, 16)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::I32; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, IntegerType::get(context, 32)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::I64; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, IntegerType::get(context, 64)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::F32; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, Float32Type::get(context)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::F64; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, Float64Type::get(context)); + } +}; +} // namespace + +/// Builds a DenseArrayAttr from an ArrayRef.
+template +DenseArrayAttr DenseArrayAttr::get(MLIRContext *context, + ArrayRef content) { + // The storage keeps the elements as raw bytes; operator ArrayRef + // reinterprets them back as T. + auto shapedType = + denseArrayAttrEltTypeBuilder::getShapedType(context, content.size()); + auto eltType = denseArrayAttrEltTypeBuilder::eltType; + auto rawArray = ArrayRef(reinterpret_cast(content.data()), + content.size() * sizeof(T)); + return Base::get(context, shapedType, eltType, rawArray) + .template cast>(); +} + +/// Support for isa<>/cast<>: true for a DenseArrayAttrBase whose stored +/// element type is the one mapped to T by denseArrayAttrEltTypeBuilder. +template +bool DenseArrayAttr::classof(Attribute attr) { + return attr.isa() && + attr.cast().getElementType() == + denseArrayAttrEltTypeBuilder::eltType; +} + +namespace mlir { +namespace detail { +// Explicit instantiation for all the supported DenseArrayAttr. +template class DenseArrayAttr; +template class DenseArrayAttr; +template class DenseArrayAttr; +template class DenseArrayAttr; +template class DenseArrayAttr; +template class DenseArrayAttr; +} // namespace detail +} // namespace mlir + //===----------------------------------------------------------------------===// // DenseElementsAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -11,9 +11,12 @@ //===----------------------------------------------------------------------===// #include "Parser.h" + +#include "AsmParserImpl.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Parser/AsmParserState.h" #include "llvm/ADT/StringExtras.h" @@ -35,6 +38,7 @@ /// | symbol-ref-id (`::` symbol-ref-id)* /// | `dense` `<` tensor-literal `>` `:` /// 
(tensor-type | vector-type) +/// | `arrayref` (integer-type | float-type) `[` element-list `]` /// | `sparse` `<` attribute-value `,` attribute-value `>` /// `:` (tensor-type | vector-type) /// | `opaque` `<` dialect-namespace `,` hex-string-literal
@@ -90,6 +94,10 @@ case Token::kw_dense: return parseDenseElementsAttr(type); + // Parse a DenseArrayAttr + case Token::kw_arrayref: + return parseDenseArrayAttr(); + // Parse a dictionary attribute. case Token::l_brace: { NamedAttrList elements; @@ -806,6 +814,54 @@ // ElementsAttr Parser //===----------------------------------------------------------------------===// +namespace { +/// This class provides an implementation of AsmParser, allowing the parser to +/// call back into attribute parsing code defined in libMLIRIR (such as +/// DenseArrayAttr::parse). +class CustomAsmParser : public AsmParserImpl { +public: + CustomAsmParser(Parser &parser) + : AsmParserImpl(parser.getToken().getLoc(), parser) {} +}; +} // namespace + +/// Parse a dense array attribute: `arrayref`, an integer or float element +/// type, then a bracketed element list parsed by the matching DenseArrayAttr. +Attribute Parser::parseDenseArrayAttr() { + consumeToken(Token::kw_arrayref); + + auto typeLoc = getToken().getLoc(); + auto type = parseType(); + // parseType() has already reported the error; never query a null type. + if (!type) + return {}; + CustomAsmParser parser(*this); + if (auto intType = type.dyn_cast()) { + switch (intType.getWidth()) { + case 8: + return DenseI8ArrayAttr::parse(parser, Type{}); + case 16: + return DenseI16ArrayAttr::parse(parser, Type{}); + case 32: + return DenseI32ArrayAttr::parse(parser, Type{}); + case 64: + return DenseI64ArrayAttr::parse(parser, Type{}); + default: + emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type; + return {}; + } + } + if (auto floatType = type.dyn_cast()) { + switch (floatType.getWidth()) { + case 32: + return DenseF32ArrayAttr::parse(parser, Type{}); + case 64: + return DenseF64ArrayAttr::parse(parser, Type{}); + default: + emitError(typeLoc, "expected f32 or f64 but got: ") << type; + return {}; + } + } + emitError(typeLoc, "expected integer or float type, got: ") << type; + return {}; +} + /// Parse a dense elements attribute.
Attribute Parser::parseDenseElementsAttr(Type attrType) { auto attribLoc = getToken().getLoc(); diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -265,6 +265,9 @@ Attribute parseDenseElementsAttr(Type attrType); ShapedType parseElementsLiteralType(Type type); + /// Parse a DenseArrayAttr in its generic form: the `arrayref` keyword, an + /// integer or float element type, then a bracketed element list. + Attribute parseDenseArrayAttr(); + /// Parse a sparse elements attribute. Attribute parseSparseElementsAttr(Type attrType); diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -78,6 +78,7 @@ TOK_KEYWORD(affine_map) TOK_KEYWORD(affine_set) TOK_KEYWORD(attributes) +TOK_KEYWORD(arrayref) TOK_KEYWORD(bf16) TOK_KEYWORD(ceildiv) TOK_KEYWORD(complex) diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -500,6 +500,34 @@ return } + +// ----- + +//===----------------------------------------------------------------------===// +// Test DenseArrayAttr +//===----------------------------------------------------------------------===// + +// The function attribute uses the generic `arrayref <element-type> [...]` form; +// test.dense_array_attr uses the short `[...]` form of each DenseArrayAttr +// flavor through its custom assembly format. + +// CHECK-LABEL: func @dense_array_attr +// CHECK: i16attr = arrayref i16 [1, -2, 3] +func @dense_array_attr() attributes{ i16attr = arrayref i16 [1, -2, 3] } { +// CHECK: test.dense_array_attr +// CHECK-SAME: i8attr = [1, -2, 3] +// CHECK-SAME: i16attr = [3, 5, -4, 10] +// CHECK-SAME: i32attr = [1024, 453, -6435] +// CHECK-SAME: i64attr = [-142] +// CHECK-SAME: f32attr = [1.024000e+03, 4.530000e+02, -6.435000e+03] +// CHECK-SAME: f64attr = [-1.420000e+02] + 
test.dense_array_attr + i8attr = [1, -2, 3] + i16attr = [3, 5, -4, 10] + i32attr = [1024, 453, -6435] + i64attr = [-142] + f32attr = [1024., 453., -6435.] + f64attr = [-142.]
+ + return +} + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir --- a/mlir/test/IR/elements-attr-interface.mlir +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -5,23 +5,40 @@ // This tests that the abstract iteration of ElementsAttr works properly, and // is properly failable when necessary. +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}} arith.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64> +// expected-error@below {{Test iterating `int64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}} arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64> +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} // expected-error@below {{Test iterating `uint64_t`: unable to iterate type}} // expected-error@below {{Test iterating `APInt`: unable to iterate type}} // expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}} arith.constant opaque<"_", "0xDEADBEEF"> : tensor<5xi64> // Check that we don't crash on empty element attributes. 
+// expected-error@below {{Test iterating `int64_t`: }} // expected-error@below {{Test iterating `uint64_t`: }} // expected-error@below {{Test iterating `APInt`: }} // expected-error@below {{Test iterating `IntegerAttr`: }} arith.constant dense<> : tensor<0xi64> + +// DenseArrayAttr values: the test pass iterates each one only as its own C++ +// element type, so a single diagnostic is expected per constant. +// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}} +arith.constant arrayref i8 [10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}} +arith.constant arrayref i16 [10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}} +arith.constant arrayref i32 [10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}} +arith.constant arrayref i64 [10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}} +arith.constant arrayref f32 [10., 11., -12., 13., 14.] +// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}} +arith.constant arrayref f64 [10., 11., -12., 13., 14.]
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -289,6 +289,23 @@ ); } +// Test op carrying one attribute of each DenseArrayAttr element type; the +// assembly format prints each attribute in its short `[e0, e1, ...]` form. +def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> { + let arguments = (ins + DenseI8ArrayAttr:$i8attr, + DenseI16ArrayAttr:$i16attr, + DenseI32ArrayAttr:$i32attr, + DenseI64ArrayAttr:$i64attr, + DenseF32ArrayAttr:$f32attr, + DenseF64ArrayAttr:$f64attr + ); + let assemblyFormat = [{ + `i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr + `i64attr` `=` $i64attr `f32attr` `=` $f32attr `f64attr` `=` $f64attr + attr-dict + }]; +} + + //===----------------------------------------------------------------------===// // Test Enum Attributes //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp --- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp +++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp @@ -14,6 +14,17 @@ using namespace mlir; using namespace test; +// Helper to print one scalar value into a diagnostic; int8_t is forced to print +// as an integer instead of as a char. 
+template +static void printOneElement(InFlightDiagnostic &os, T value) { + os << llvm::formatv("{0}", value).str(); +} +// int8_t specialization: cast before formatting so the value prints as a +// number rather than as a character. +template <> +void printOneElement(InFlightDiagnostic &os, int8_t value) { + os << llvm::formatv("{0}", static_cast(value)).str(); +} + namespace { struct TestElementsAttrInterface : public PassWrapper> { @@ -29,6 +40,33 @@ auto elementsAttr = attr.getValue().dyn_cast(); if (!elementsAttr) continue; + // A DenseArrayAttr can only be iterated as its stored element type (its + // value_begin_impl overloads cast to the matching subclass), so exercise + // exactly that type and skip the generic iterations below. + if (auto concreteAttr = + attr.getValue().dyn_cast()) { + switch (concreteAttr.getElementType()) { + case mlir::detail::DenseArrayAttributeElementType::I8: + testElementsAttrIteration(op, elementsAttr, "int8_t"); + break; + case mlir::detail::DenseArrayAttributeElementType::I16: + testElementsAttrIteration(op, elementsAttr, "int16_t"); + break; + case mlir::detail::DenseArrayAttributeElementType::I32: + testElementsAttrIteration(op, elementsAttr, "int32_t"); + break; + case mlir::detail::DenseArrayAttributeElementType::I64: + testElementsAttrIteration(op, elementsAttr, "int64_t"); + break; + case mlir::detail::DenseArrayAttributeElementType::F32: + testElementsAttrIteration(op, elementsAttr, "float"); + break; + case mlir::detail::DenseArrayAttributeElementType::F64: + testElementsAttrIteration(op, elementsAttr, "double"); + break; + default: + break; + } + continue; + } + testElementsAttrIteration(op, elementsAttr, "int64_t"); testElementsAttrIteration(op, elementsAttr, "uint64_t"); testElementsAttrIteration(op, elementsAttr, "APInt"); testElementsAttrIteration(op, elementsAttr, "IntegerAttr"); @@ -48,9 +86,8 @@ return; } - llvm::interleaveComma(*values, diag, [&](T value) { - diag << llvm::formatv("{0}", value).str(); - }); + 
llvm::interleaveComma(*values, diag, + [&](T value) { printOneElement(diag, value); }); } }; } // namespace