diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -66,8 +66,8 @@
 struct is_complex_t<std::complex<T>> : public std::true_type {};
 } // namespace detail
 
-/// An attribute that represents a reference to a dense vector or tensor object.
-///
+/// An attribute that represents a reference to a dense vector or tensor
+/// object.
 class DenseElementsAttr : public Attribute {
 public:
   using Attribute::Attribute;
@@ -743,6 +743,55 @@
 //===----------------------------------------------------------------------===//
 
 namespace mlir {
+namespace detail {
+/// Base class for DenseArrayAttr that is instantiated and specialized for each
+/// supported element type below.
+template <typename T>
+class DenseArrayAttr : public DenseArrayBaseAttr {
+public:
+  using DenseArrayBaseAttr::DenseArrayBaseAttr;
+
+  /// Implicit conversion to ArrayRef<T>; asArrayRef() is the explicit form.
+  operator ArrayRef<T>() const;
+  ArrayRef<T> asArrayRef() const { return ArrayRef<T>{*this}; }
+
+  /// Builder from ArrayRef<T>.
+  static DenseArrayAttr get(MLIRContext *context, ArrayRef<T> content);
+
+  /// Print the short form `[42, 100, -1]` without any type prefix.
+  void print(AsmPrinter &printer) const;
+  void print(raw_ostream &os) const;
+  /// Print the short form `42, 100, -1` without any braces or type prefix.
+  void printWithoutBraces(raw_ostream &os) const;
+
+  /// Parse the short form `[42, 100, -1]` without any type prefix.
+  static Attribute parse(AsmParser &parser, Type odsType);
+
+  /// Parse the short form `42, 100, -1` without any type prefix or braces.
+  static Attribute parseWithoutBraces(AsmParser &parser, Type odsType);
+
+  /// Support for isa<>/cast<>: matches a DenseArrayBaseAttr with element T.
+  static bool classof(Attribute attr);
+};
+template <>
+void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const;
+
+extern template class DenseArrayAttr<int8_t>;
+extern template class DenseArrayAttr<int16_t>;
+extern template class DenseArrayAttr<int32_t>;
+extern template class DenseArrayAttr<int64_t>;
+extern template class DenseArrayAttr<float>;
+extern template class DenseArrayAttr<double>;
+} // namespace detail
+
+// Public names for all the supported DenseArrayAttr instantiations.
+using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
+using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
+using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
+using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
+using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
+using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
+
 //===----------------------------------------------------------------------===//
 // BoolAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -144,6 +144,79 @@
 // DenseIntOrFPElementsAttr
 //===----------------------------------------------------------------------===//
 
+def Builtin_DenseArrayBase : Builtin_Attr<
+    "DenseArrayBase", [ElementsAttrInterface]> {
+  let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
+  let description = [{
+    A dense array attribute is an attribute that represents a dense array of
+    primitive element types. Contrary to DenseIntOrFPElementsAttr this is a
+    flat unidimensional array which does not have a storage optimization for
+    splat. This allows exposing the raw array through a C++ API as
+    `ArrayRef<T>`. This is the base class attribute; the actual access is
+    intended to be managed through the subclasses `DenseI8ArrayAttr`,
+    `DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`,
+    `DenseF32ArrayAttr`, and `DenseF64ArrayAttr`.
+
+    Syntax:
+
+    ```
+    dense-array-attribute ::= `[` `:` (integer-type | float-type) tensor-literal `]`
+    ```
+    Examples:
+
+    ```mlir
+    [:i8]
+    [:i32 10, 42]
+    [:f64 42., 12.]
+    ```
+
+    When a specific subclass is used as argument of an operation, the
+    declarative assembly will omit the type and print directly:
+    ```
+    [1, 2, 3]
+    ```
+  }];
+  let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
+                        "DenseArrayBaseAttr::EltType":$eltType,
+                        ArrayRefParameter<"char">:$elements);
+  let extraClassDeclaration = [{
+    /// All supported element types.
+    enum class EltType { I8, I16, I32, I64, F32, F64 };
+
+    /// Allow implicit conversion to ElementsAttr.
+    operator ElementsAttr() const {
+      return *this ? cast<ElementsAttr>() : nullptr;
+    }
+
+    /// ElementsAttr implementation.
+    /// NOTE(review): all six types are advertised as iterable, but each
+    /// value_begin_impl casts to the DenseArrayAttr subclass for its type, so
+    /// only the C++ type matching getElementType() can actually be iterated.
+    using ContiguousIterableTypesT =
+        std::tuple<int8_t, int16_t, int32_t, int64_t, float, double>;
+    const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
+    const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
+    const int32_t *value_begin_impl(OverloadToken<int32_t>) const;
+    const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
+    const float *value_begin_impl(OverloadToken<float>) const;
+    const double *value_begin_impl(OverloadToken<double>) const;
+
+    /// Return the element type of this array.
+    EltType getElementType() const;
+    /// Printer for the short form: will dispatch to the appropriate subclass.
+    void print(AsmPrinter &printer) const;
+    void print(raw_ostream &os) const;
+    /// Print the short form `42, 100, -1` without any braces or prefix.
+    void printWithoutBraces(raw_ostream &os) const;
+  }];
+  let genAccessors = 0;
+  let skipDefaultBuilders = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// DenseIntOrFPElementsAttr
+//===----------------------------------------------------------------------===//
+
 def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
     "DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr"
   > {
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1258,6 +1258,19 @@
   let convertFromStorage = "$_self";
 }
 
+class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryName> :
+    ElementsAttrBase<CPred<"$_self.isa<::mlir::" # denseAttrName # ">()">,
+    summaryName # " dense array attribute"> {
+  let storageType = "::mlir::" # denseAttrName;
+  let returnType = "::llvm::ArrayRef<" # cppType # ">";
+}
+def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">;
+def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">;
+def DenseI64ArrayAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
+def DenseF32ArrayAttr : DenseArrayAttrBase<"DenseF32ArrayAttr", "float", "f32">;
+def DenseF64ArrayAttr : DenseArrayAttrBase<"DenseF64ArrayAttr", "double", "f64">;
+
 def IndexElementsAttr
     : IntElementsAttrBase<CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>()
                                       .getType()
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1878,9 +1878,34 @@
       }
       os << '>';
     }
-
+  } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
+    typeElision = AttrTypeElision::Must;
+    switch (denseArrayAttr.getElementType()) {
+    case DenseArrayBaseAttr::EltType::I8:
+      os << "[:i8 ";
+      break;
+    case DenseArrayBaseAttr::EltType::I16:
+      os << "[:i16 ";
+      break;
+    case DenseArrayBaseAttr::EltType::I32:
+      os << "[:i32 ";
+      break;
+    case DenseArrayBaseAttr::EltType::I64:
+      os << "[:i64 ";
+      break;
+    case DenseArrayBaseAttr::EltType::F32:
+      os << "[:f32 ";
+      break;
+    case DenseArrayBaseAttr::EltType::F64:
+      os << "[:f64 ";
+      break;
+    }
+    denseArrayAttr.printWithoutBraces(os);
+    os << "]";
   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
     printLocation(locAttr);
+  } else {
+    llvm::report_fatal_error("Unknown builtin attribute");
   }
   // Don't print the type if we must elide it, or if it is a None type.
   if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Types.h"
@@ -35,11 +36,11 @@
 //===----------------------------------------------------------------------===//
 
 void BuiltinDialect::registerAttributes() {
-  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
-                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
-                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
-                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
-                UnitAttr>();
+  addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
+                DenseIntOrFPElementsAttr, DenseStringElementsAttr,
+                DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
+                IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
+                SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
 }
 
 //===----------------------------------------------------------------------===//
@@ -664,6 +665,237 @@
           readBits(getData(), offset + storageWidth, bitWidth)};
 }
 
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
+  return getImpl()->eltType;
+}
+
+const int8_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
+  return cast<DenseI8ArrayAttr>().asArrayRef().begin();
+}
+const int16_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
+  return cast<DenseI16ArrayAttr>().asArrayRef().begin();
+}
+const int32_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
+  return cast<DenseI32ArrayAttr>().asArrayRef().begin();
+}
+const int64_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
+  return cast<DenseI64ArrayAttr>().asArrayRef().begin();
+}
+const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
+  return cast<DenseF32ArrayAttr>().asArrayRef().begin();
+}
+const double *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
+  return cast<DenseF64ArrayAttr>().asArrayRef().begin();
+}
+
+void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
+  print(printer.getStream());
+}
+
+void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
+  switch (getElementType()) {
+  case DenseArrayBaseAttr::EltType::I8:
+    this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::I16:
+    this->cast<DenseI16ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::I32:
+    this->cast<DenseI32ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::I64:
+    this->cast<DenseI64ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::F32:
+    this->cast<DenseF32ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::F64:
+    this->cast<DenseF64ArrayAttr>().printWithoutBraces(os);
+    return;
+  }
+  llvm_unreachable("unknown DenseArrayBaseAttr element type");
+}
+
+void DenseArrayBaseAttr::print(raw_ostream &os) const {
+  os << "[";
+  printWithoutBraces(os);
+  os << "]";
+}
+
+template <typename T>
+void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
+  print(printer.getStream());
+}
+
+template <typename T>
+void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
+  // NOTE(review): f32/f64 elements go through raw_ostream's default floating
+  // point format (six fractional digits, e.g. `1.024000e+03`), so values that
+  // need more precision than that do not round-trip through the printed form.
+  ArrayRef<T> values{*this};
+  llvm::interleaveComma(values, os);
+}
+
+/// Specialization for int8_t for forcing printing as number instead of chars.
+template <>
+void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
+  ArrayRef<int8_t> values{*this};
+  llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
+}
+
+template <typename T>
+void DenseArrayAttr<T>::print(raw_ostream &os) const {
+  os << "[";
+  printWithoutBraces(os);
+  os << "]";
+}
+
+/// Parse a single element: generic template for int types, specialized for
+/// floating points below.
+template <typename T>
+static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
+  return parser.parseInteger(value);
+}
+
+template <>
+ParseResult parseDenseArrayAttrElt(AsmParser &parser, float &value) {
+  double doubleVal;
+  if (parser.parseFloat(doubleVal))
+    return failure();
+  value = static_cast<float>(doubleVal);
+  return success();
+}
+
+template <>
+ParseResult parseDenseArrayAttrElt(AsmParser &parser, double &value) {
+  return parser.parseFloat(value);
+}
+
+/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
+template <typename T>
+Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
+                                                Type odsType) {
+  SmallVector<T> data;
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        T value;
+        if (parseDenseArrayAttrElt(parser, value))
+          return failure();
+        data.push_back(value);
+        return success();
+      })))
+    return {};
+  return get(parser.getContext(), data);
+}
+
+/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`, or `[]` for an empty array.
+template <typename T>
+Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
+  if (parser.parseLSquare())
+    return {};
+  if (succeeded(parser.parseOptionalRSquare()))
+    return get(parser.getContext(), {});
+  Attribute result = parseWithoutBraces(parser, odsType);
+  return failed(parser.parseRSquare()) ? Attribute() : result;
+}
+
+/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
+template <typename T>
+DenseArrayAttr<T>::operator ArrayRef<T>() const {
+  ArrayRef<char> raw = getImpl()->elements;
+  assert((raw.size() % sizeof(T)) == 0);
+  return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
+                     raw.size() / sizeof(T));
+}
+
+namespace {
+/// Mapping from C++ element type to MLIR DenseArrayAttr internals.
+template <typename T>
+struct denseArrayAttrEltTypeBuilder;
+template <>
+struct denseArrayAttrEltTypeBuilder<int8_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 8));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int16_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 16));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int32_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 32));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int64_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 64));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<float> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, Float32Type::get(context));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<double> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, Float64Type::get(context));
+  }
+};
+} // namespace
+
+/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
+template <typename T>
+DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
+                                         ArrayRef<T> content) {
+  auto shapedType =
+      denseArrayAttrEltTypeBuilder<T>::getShapedType(context, content.size());
+  auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
+  auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
+                                 content.size() * sizeof(T));
+  return Base::get(context, shapedType, eltType, rawArray)
+      .template cast<DenseArrayAttr<T>>();
+}
+
+template <typename T>
+bool DenseArrayAttr<T>::classof(Attribute attr) {
+  auto baseAttr = attr.dyn_cast<DenseArrayBaseAttr>();
+  return baseAttr &&
+         baseAttr.getElementType() == denseArrayAttrEltTypeBuilder<T>::eltType;
+}
+
+namespace mlir {
+namespace detail {
+// Explicit instantiation for all the supported DenseArrayAttr.
+template class DenseArrayAttr<int8_t>;
+template class DenseArrayAttr<int16_t>;
+template class DenseArrayAttr<int32_t>;
+template class DenseArrayAttr<int64_t>;
+template class DenseArrayAttr<float>;
+template class DenseArrayAttr<double>;
+} // namespace detail
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // DenseElementsAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -11,9 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "Parser.h"
+
+#include "AsmParserImpl.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Parser/AsmParserState.h"
 #include "llvm/ADT/StringExtras.h"
@@ -30,6 +33,7 @@
 ///                   | float-literal (`:` float-type)?
 ///                   | string-literal (`:` type)?
 ///                   | type
+///                   | `[` `:` (integer-type | float-type) tensor-literal `]`
 ///                   | `[` (attribute-value (`,` attribute-value)*)? `]`
 ///                   | `{` (attribute-entry (`,` attribute-entry)*)? `}`
 ///                   | symbol-ref-id (`::` symbol-ref-id)*
@@ -67,13 +71,16 @@
 
   // Parse an array attribute.
   case Token::l_square: {
+    consumeToken(Token::l_square);
+    if (consumeIf(Token::colon))
+      return parseDenseArrayAttr();
     SmallVector<Attribute, 4> elements;
     auto parseElt = [&]() -> ParseResult {
       elements.push_back(parseAttribute());
       return elements.back() ? success() : failure();
     };
 
-    if (parseCommaSeparatedList(Delimiter::Square, parseElt))
+    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
       return nullptr;
     return builder.getArrayAttr(elements);
   }
@@ -812,6 +819,81 @@
 // ElementsAttr Parser
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// This class provides an implementation of AsmParser, allowing to call back
+/// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
+/// in libMLIRIR.
+class CustomAsmParser : public AsmParserImpl<AsmParser> {
+public:
+  explicit CustomAsmParser(Parser &parser)
+      : AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
+};
+} // namespace
+
+/// Parse a dense array attribute: `[:type e0, e1, ...]`, or `[:type]` for an
+/// empty array. The leading `[` and `:` have already been consumed by the
+/// caller; this parses the element type, the elements and the closing `]`.
+Attribute Parser::parseDenseArrayAttr() {
+  auto typeLoc = getToken().getLoc();
+  auto type = parseType();
+  if (!type)
+    return {};
+  CustomAsmParser parser(*this);
+  // An empty array such as `[:i32]` has no element for parseWithoutBraces to
+  // consume, so it is built directly instead of being parsed.
+  bool isEmptyList = getToken().is(Token::r_square);
+  auto parseElements = [&](auto attrTag) -> Attribute {
+    using AttrT = decltype(attrTag);
+    if (isEmptyList)
+      return AttrT::get(getContext(), {});
+    return AttrT::parseWithoutBraces(parser, Type{});
+  };
+  Attribute result;
+  if (type.isa<IntegerType>()) {
+    switch (type.getIntOrFloatBitWidth()) {
+    case 8:
+      result = parseElements(DenseI8ArrayAttr());
+      break;
+    case 16:
+      result = parseElements(DenseI16ArrayAttr());
+      break;
+    case 32:
+      result = parseElements(DenseI32ArrayAttr());
+      break;
+    case 64:
+      result = parseElements(DenseI64ArrayAttr());
+      break;
+    default:
+      emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
+      return {};
+    }
+  } else if (type.isa<FloatType>()) {
+    switch (type.getIntOrFloatBitWidth()) {
+    case 32:
+      result = parseElements(DenseF32ArrayAttr());
+      break;
+    case 64:
+      result = parseElements(DenseF64ArrayAttr());
+      break;
+    default:
+      emitError(typeLoc, "expected f32 or f64 but got: ") << type;
+      return {};
+    }
+  } else {
+    emitError(typeLoc, "expected integer or float type, got: ") << type;
+    return {};
+  }
+  // The element parser has already reported why it failed; returning here
+  // avoids stacking an unrelated "expected ']'" error on top of it.
+  if (!result)
+    return {};
+  if (!consumeIf(Token::r_square)) {
+    emitError("expected ']' to close an array attribute");
+    return {};
+  }
+  return result;
+}
+
 /// Parse a dense elements attribute.
 Attribute Parser::parseDenseElementsAttr(Type attrType) {
   auto attribLoc = getToken().getLoc();
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -264,6 +264,9 @@
   Attribute parseDenseElementsAttr(Type attrType);
   ShapedType parseElementsLiteralType(Type type);
 
+  /// Parse a DenseArrayAttr.
+  Attribute parseDenseArrayAttr();
+
   /// Parse a sparse elements attribute.
Attribute parseSparseElementsAttr(Type attrType); diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -513,6 +513,45 @@ return } + +// ----- + +//===----------------------------------------------------------------------===// +// Test DenseArrayAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @dense_array_attr +func.func @dense_array_attr() attributes{ +// CHECK-SAME: f32attr = [:f32 1.024000e+03, 4.530000e+02, -6.435000e+03], + f32attr = [:f32 1024., 453., -6435.], +// CHECK-SAME: f64attr = [:f64 -1.420000e+02], + f64attr = [:f64 -142.], +// CHECK-SAME: i16attr = [:i16 3, 5, -4, 10], + i16attr = [:i16 3, 5, -4, 10], +// CHECK-SAME: i32attr = [:i32 1024, 453, -6435], + i32attr = [:i32 1024, 453, -6435], +// CHECK-SAME: i64attr = [:i64 -142], + i64attr = [:i64 -142], +// CHECK-SAME: i8attr = [:i8 1, -2, 3] + i8attr = [:i8 1, -2, 3] + } { +// CHECK: test.dense_array_attr + test.dense_array_attr +// CHECK-SAME: i8attr = [1, -2, 3] + i8attr = [1, -2, 3] +// CHECK-SAME: i16attr = [3, 5, -4, 10] + i16attr = [3, 5, -4, 10] +// CHECK-SAME: i32attr = [1024, 453, -6435] + i32attr = [1024, 453, -6435] +// CHECK-SAME: i64attr = [-142] + i64attr = [-142] +// CHECK-SAME: f32attr = [1.024000e+03, 4.530000e+02, -6.435000e+03] + f32attr = [1024., 453., -6435.] +// CHECK-SAME: f64attr = [-1.420000e+02] + f64attr = [-142.] + return +} + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir --- a/mlir/test/IR/elements-attr-interface.mlir +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -5,23 +5,40 @@ // This tests that the abstract iteration of ElementsAttr works properly, and // is properly failable when necessary. 
+// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}} arith.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64> +// expected-error@below {{Test iterating `int64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}} arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64> +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} // expected-error@below {{Test iterating `uint64_t`: unable to iterate type}} // expected-error@below {{Test iterating `APInt`: unable to iterate type}} // expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}} arith.constant opaque<"_", "0xDEADBEEF"> : tensor<5xi64> // Check that we don't crash on empty element attributes. 
+// expected-error@below {{Test iterating `int64_t`: }} // expected-error@below {{Test iterating `uint64_t`: }} // expected-error@below {{Test iterating `APInt`: }} // expected-error@below {{Test iterating `IntegerAttr`: }} arith.constant dense<> : tensor<0xi64> + +// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}} +arith.constant [:i8 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}} +arith.constant [:i16 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}} +arith.constant [:i32 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}} +arith.constant [:i64 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}} +arith.constant [:f32 10., 11., -12., 13., 14.] +// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}} +arith.constant [:f64 10., 11., -12., 13., 14.] 
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -1654,7 +1654,7 @@
 
 // -----
 
-// expected-error@+1 {{expected ']'}}
+// expected-error@+1 {{expected ',' or ']'}}
 "f"() { b = [@m:
 
 // -----
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -270,6 +270,22 @@
   );
 }
 
+def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
+  let arguments = (ins
+    DenseI8ArrayAttr:$i8attr,
+    DenseI16ArrayAttr:$i16attr,
+    DenseI32ArrayAttr:$i32attr,
+    DenseI64ArrayAttr:$i64attr,
+    DenseF32ArrayAttr:$f32attr,
+    DenseF64ArrayAttr:$f64attr
+  );
+  let assemblyFormat = [{
+   `i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
+   `i64attr` `=` $i64attr `f32attr` `=` $f32attr `f64attr` `=` $f64attr
+   attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test Enum Attributes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -14,6 +14,17 @@
 using namespace mlir;
 using namespace test;
 
+// Helper to print one scalar value, force int8_t to print as integer instead of
+// char.
+template <typename T>
+static void printOneElement(InFlightDiagnostic &os, T value) {
+  os << llvm::formatv("{0}", value).str();
+}
+template <>
+void printOneElement(InFlightDiagnostic &os, int8_t value) {
+  os << llvm::formatv("{0}", static_cast<int64_t>(value)).str();
+}
+
 namespace {
 struct TestElementsAttrInterface
     : public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> {
@@ -29,6 +40,34 @@
       auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
       if (!elementsAttr)
         continue;
+      // Dense arrays are only iterated with the C++ type matching their
+      // element type: DenseArrayBaseAttr::value_begin_impl casts to the
+      // subclass for the requested type, which fails for any other type.
+      if (auto concreteAttr =
+              attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
+        switch (concreteAttr.getElementType()) {
+        case DenseArrayBaseAttr::EltType::I8:
+          testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
+          break;
+        case DenseArrayBaseAttr::EltType::I16:
+          testElementsAttrIteration<int16_t>(op, elementsAttr, "int16_t");
+          break;
+        case DenseArrayBaseAttr::EltType::I32:
+          testElementsAttrIteration<int32_t>(op, elementsAttr, "int32_t");
+          break;
+        case DenseArrayBaseAttr::EltType::I64:
+          testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
+          break;
+        case DenseArrayBaseAttr::EltType::F32:
+          testElementsAttrIteration<float>(op, elementsAttr, "float");
+          break;
+        case DenseArrayBaseAttr::EltType::F64:
+          testElementsAttrIteration<double>(op, elementsAttr, "double");
+          break;
+        }
+        continue;
+      }
+      testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
       testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
       testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");
       testElementsAttrIteration<IntegerAttr>(op, elementsAttr, "IntegerAttr");
@@ -48,9 +87,8 @@
       return;
     }
 
-    llvm::interleaveComma(*values, diag, [&](T value) {
-      diag << llvm::formatv("{0}", value).str();
-    });
+    llvm::interleaveComma(*values, diag,
+                          [&](T value) { printOneElement(diag, value); });
   }
 };
 } // namespace