diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -26,6 +26,8 @@ class Location; class Operation; class ShapedType; +class AsmParser; +class AsmPrinter; //===----------------------------------------------------------------------===// // Elements Attributes @@ -64,9 +66,96 @@ template struct is_complex_t : public std::false_type {}; template struct is_complex_t> : public std::true_type {}; + +class DenseArrayAttributeStorage; +/// All possible element types for DenseArrayAttr. +enum class DenseArrayAttributeElementType { I8, I16, I32, I64, F32, F64 }; +} // namespace detail + +/// Base class for DenseArrayAttr: holds the elements as raw bytes together with +/// an element type tag. The typed DenseArrayAttr subclasses give access to them. +class DenseArrayAttrBase + : public Attribute::AttrBase { +public: + using Base::Base; + + /// Allow implicit conversion to ElementsAttr. + operator ElementsAttr() const { + return *this ? cast() : nullptr; + } + + // ElementsAttr implementation. + using ContiguousIterableTypesT = + std::tuple; + const int8_t *value_begin_impl(OverloadToken) const; + const int16_t *value_begin_impl(OverloadToken) const; + const int32_t *value_begin_impl(OverloadToken) const; + const int64_t *value_begin_impl(OverloadToken) const; + const float *value_begin_impl(OverloadToken) const; + const double *value_begin_impl(OverloadToken) const; + + /// Return the element type tag; DenseArrayAttr::classof dispatches on it. + detail::DenseArrayAttributeElementType getElementType() const; + /// Printer for the short form: will dispatch to the appropriate subclass. + void print(AsmPrinter &printer) const; + void print(raw_ostream &os) const; + /// Print the short form `42, 100, -1` without any braces or prefix.
+ void printWithoutBraces(raw_ostream &os) const; +}; + +namespace detail { +/// Typed view of a DenseArrayAttrBase, explicitly instantiated for each +/// supported element type below. +template +class DenseArrayAttr : public DenseArrayAttrBase { +public: + using DenseArrayAttrBase::DenseArrayAttrBase; + + /// Implicit conversion to ArrayRef, and a named const accessor for the same view. + operator ArrayRef() const; + ArrayRef asArrayRef() const { return ArrayRef{*this}; } + + /// Builder from ArrayRef. + static DenseArrayAttr get(MLIRContext *context, ArrayRef content); + + /// Print the short form `[42, 100, -1]` without any prefix. + void print(AsmPrinter &printer) const; + void print(raw_ostream &os) const; + /// Print the short form `42, 100, -1` without any braces or prefix. + void printWithoutBraces(raw_ostream &os) const; + + /// Parse the short form `[42, 100, -1]` without any prefix. + static Attribute parse(AsmParser &parser, Type odsType); + + /// Parse the short form `42, 100, -1` without any prefix or braces. + static Attribute parseWithoutBraces(AsmParser &parser, Type odsType); + + // Support for isa<>/cast<>. + static bool classof(Attribute attr); +}; +template <> +void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const; + +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; } // namespace detail -/// An attribute that represents a reference to a dense vector or tensor object.
+// Public names for all the supported DenseArrayAttr instantiations. +using DenseI8ArrayAttr = detail::DenseArrayAttr; +using DenseI16ArrayAttr = detail::DenseArrayAttr; +using DenseI32ArrayAttr = detail::DenseArrayAttr; +using DenseI64ArrayAttr = detail::DenseArrayAttr; +using DenseF32ArrayAttr = detail::DenseArrayAttr; +using DenseF64ArrayAttr = detail::DenseArrayAttr; + +/// An attribute that represents a reference to a dense vector or tensor +/// object. /// class DenseElementsAttr : public Attribute { public: diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1258,6 +1258,60 @@ let convertFromStorage = "$_self"; } +def DenseI8ArrayAttr : + ElementsAttrBase()">, + "i8 dense array attribute"> { + let storageType = [{ ::mlir::DenseI8ArrayAttr }]; + let returnType = [{ ::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + +def DenseI16ArrayAttr : + ElementsAttrBase()">, + "i16 dense array attribute"> { + let storageType = [{ ::mlir::DenseI16ArrayAttr }]; + let returnType = [{ ::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + +def DenseI32ArrayAttr : + ElementsAttrBase()">, + "i32 dense array attribute"> { + let storageType = [{ ::mlir::DenseI32ArrayAttr }]; + let returnType = [{ ::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + +def DenseI64ArrayAttr : + ElementsAttrBase()">, + "i64 dense array attribute"> { + let storageType = [{ ::mlir::DenseI64ArrayAttr }]; + let returnType = [{ ::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + +def DenseF32ArrayAttr : + ElementsAttrBase()">, + "float dense array attribute"> { + let storageType = [{ ::mlir::DenseF32ArrayAttr }]; + let returnType = [{ ::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + +def DenseF64ArrayAttr : + ElementsAttrBase()">, + "double dense array attribute"> { + let storageType = [{ ::mlir::DenseF64ArrayAttr }]; + let returnType = [{
::llvm::ArrayRef }]; + + let convertFromStorage = "$_self"; +} + def IndexElementsAttr : IntElementsAttrBase() .getType() diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1878,9 +1878,34 @@ } os << '>'; } - + } else if (auto denseArrayAttr = attr.dyn_cast()) { + typeElision = AttrTypeElision::Must; + switch (denseArrayAttr.getElementType()) { + case detail::DenseArrayAttributeElementType::I8: + os << "[:i8 "; + break; + case detail::DenseArrayAttributeElementType::I16: + os << "[:i16 "; + break; + case detail::DenseArrayAttributeElementType::I32: + os << "[:i32 "; + break; + case detail::DenseArrayAttributeElementType::I64: + os << "[:i64 "; + break; + case detail::DenseArrayAttributeElementType::F32: + os << "[:f32 "; + break; + case detail::DenseArrayAttributeElementType::F64: + os << "[:f64 "; + break; + } + denseArrayAttr.printWithoutBraces(os); + os << "]"; } else if (auto locAttr = attr.dyn_cast()) { printLocation(locAttr); + } else { + llvm::report_fatal_error("Unknown builtin attribute"); } // Don't print the type if we must elide it, or if it is a None type. if (typeElision != AttrTypeElision::Must && !attrType.isa()) { diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -40,6 +40,51 @@ return eltType.getIntOrFloatBitWidth(); } +/// Storage for DenseArrayAttr: shaped type, element type tag, raw element bytes.
+class DenseArrayAttributeStorage : public AttributeStorage { +public: + using KeyTy = std::tuple>; + DenseArrayAttributeStorage(ShapedType type, + DenseArrayAttributeElementType eltType, + ArrayRef elements) + : AttributeStorage(type), eltType(eltType), elements(elements) {} + + bool operator==(const KeyTy &tblgenKey) const { + return (getType() == std::get<0>(tblgenKey)) && + (eltType == std::get<1>(tblgenKey)) && + (elements == std::get<2>(tblgenKey)); + } + + static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { + // Hash the element bytes too: without them, every array with the same + // shaped type and element kind would hash to the same value. + return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey), + std::get<2>(tblgenKey)); + } + + static DenseArrayAttributeStorage * + construct(::mlir::AttributeStorageAllocator &allocator, + const KeyTy &tblgenKey) { + auto eltType = std::get<1>(tblgenKey); + auto elements = std::get<2>(tblgenKey); + auto type = std::get<0>(tblgenKey); + // The raw bytes are later reinterpreted as int8_t..int64_t/float/double + // by DenseArrayAttr's ArrayRef conversion, so copy them into storage that + // is aligned for any scalar type; `copyInto` would only align for `char`. + if (!elements.empty()) { + auto *data = static_cast<char *>( + allocator.allocate(elements.size(), alignof(std::max_align_t))); + std::copy(elements.begin(), elements.end(), data); + elements = ::llvm::ArrayRef<char>(data, elements.size()); + } + return new (allocator.allocate()) + DenseArrayAttributeStorage(type, eltType, elements); + } + DenseArrayAttributeElementType eltType; + ::llvm::ArrayRef elements; +}; + /// An attribute representing a reference to a dense vector or tensor object.
struct DenseElementsAttributeStorage : public AttributeStorage { public: diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Types.h" @@ -35,11 +36,11 @@ //===----------------------------------------------------------------------===// void BuiltinDialect::registerAttributes() { - addAttributes(); + addAttributes(); } //===----------------------------------------------------------------------===// @@ -664,6 +665,247 @@ readBits(getData(), offset + storageWidth, bitWidth)}; } +//===----------------------------------------------------------------------===// +// DenseArrayAttr +//===----------------------------------------------------------------------===// + +detail::DenseArrayAttributeElementType +DenseArrayAttrBase::getElementType() const { + return getImpl()->eltType; +} + +const int8_t * +DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const int16_t * +DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const int32_t * +DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const int64_t * +DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const float *DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const double * +DenseArrayAttrBase::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} + +void DenseArrayAttrBase::print(AsmPrinter &printer) const { + print(printer.getStream()); +} + +void DenseArrayAttrBase::printWithoutBraces(raw_ostream &os)
const { + switch (getElementType()) { + case detail::DenseArrayAttributeElementType::I8: + this->cast().printWithoutBraces(os); + return; + case detail::DenseArrayAttributeElementType::I16: + this->cast().printWithoutBraces(os); + return; + case detail::DenseArrayAttributeElementType::I32: + this->cast().printWithoutBraces(os); + return; + case detail::DenseArrayAttributeElementType::I64: + this->cast().printWithoutBraces(os); + return; + case detail::DenseArrayAttributeElementType::F32: + this->cast().printWithoutBraces(os); + return; + case detail::DenseArrayAttributeElementType::F64: + this->cast().printWithoutBraces(os); + return; + } + llvm_unreachable("unknown DenseArrayAttributeElementType"); +} + +void DenseArrayAttrBase::print(raw_ostream &os) const { + os << "["; + printWithoutBraces(os); + os << "]"; +} + +template +void DenseArrayAttr::print(AsmPrinter &printer) const { + print(printer.getStream()); +} + +/// Print the elements as a comma-separated list, e.g. `1, 2, 3`. +/// NOTE(review): floating-point elements use raw_ostream's default formatting +/// (six fractional digits, e.g. `1.024000e+03`), which can lose precision, so +/// such values may not round-trip through the parser. +template +void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const { + ArrayRef values{*this}; + llvm::interleaveComma(values, os); +} + +/// Specialization for int8_t: print each element as a number, not a character. +template <> +void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const { + ArrayRef values{*this}; + llvm::interleaveComma(values, os, [&](int64_t v) { os << v; }); +} + +template +void DenseArrayAttr::print(raw_ostream &os) const { + os << "["; + printWithoutBraces(os); + os << "]"; +} + +/// Parse a single element: generic template for int types, specialized for +/// floating points below.
+template +static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) { + return parser.parseInteger(value); +} + +template <> +ParseResult parseDenseArrayAttrElt(AsmParser &parser, float &value) { + double doubleVal; + if (parser.parseFloat(doubleVal)) + return failure(); + value = static_cast<float>(doubleVal); + return success(); +} + +template <> +ParseResult parseDenseArrayAttrElt(AsmParser &parser, double &value) { + return parser.parseFloat(value); +} + +/// Parse a DenseArrayAttr without the braces: `1, 2, 3` +template +Attribute DenseArrayAttr::parseWithoutBraces(AsmParser &parser, + Type odsType) { + SmallVector data; + do { + T value; + if (parseDenseArrayAttrElt(parser, value)) + return {}; + data.push_back(value); + if (parser.parseOptionalComma()) + break; + } while (true); + return get(parser.getContext(), data); +} + +/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`, or `[]` for an empty array. +template +Attribute DenseArrayAttr::parse(AsmParser &parser, Type odsType) { + if (parser.parseLSquare()) + return {}; + // parseWithoutBraces requires at least one element, so handle `[]` here. + if (succeeded(parser.parseOptionalRSquare())) + return get(parser.getContext(), {}); + Attribute result = parseWithoutBraces(parser, odsType); + // Do not look for `]` after a failed element: one diagnostic is enough. + if (!result || parser.parseRSquare()) + return {}; + return result; +} + +/// Conversion from DenseArrayAttr to ArrayRef. +template +DenseArrayAttr::operator ArrayRef() const { + ArrayRef raw = getImpl()->elements; + assert((raw.size() % sizeof(T)) == 0 && + "raw element data size must be a multiple of sizeof(T)"); + return ArrayRef(reinterpret_cast(raw.data()), + raw.size() / sizeof(T)); +} + +namespace { +/// Mapping from C++ element type to MLIR DenseArrayAttr internals: the element +/// type tag and the shaped type of the attribute. The shaped type is a 1-D +/// tensor rather than a vector because vector types require strictly positive +/// dimension sizes, which would make empty arrays impossible to build.
+template +struct denseArrayAttrEltTypeBuilder; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::I8; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, IntegerType::get(context, 8)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::I16; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, IntegerType::get(context, 16)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::I32; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, IntegerType::get(context, 32)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::I64; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, IntegerType::get(context, 64)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::F32; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, Float32Type::get(context)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = detail::DenseArrayAttributeElementType::F64; + static ShapedType getShapedType(MLIRContext *context, int64_t shape) { + return RankedTensorType::get(shape, Float64Type::get(context)); + } +}; +} // namespace + +/// Builds a DenseArrayAttr from an ArrayRef.
+template +DenseArrayAttr DenseArrayAttr::get(MLIRContext *context, + ArrayRef content) { + auto shapedType = + denseArrayAttrEltTypeBuilder::getShapedType(context, content.size()); + auto eltType = denseArrayAttrEltTypeBuilder::eltType; + auto rawArray = ArrayRef(reinterpret_cast(content.data()), + content.size() * sizeof(T)); + return Base::get(context, shapedType, eltType, rawArray) + .template cast>(); +} + +template +bool DenseArrayAttr::classof(Attribute attr) { + return attr.isa() && + attr.cast().getElementType() == + denseArrayAttrEltTypeBuilder::eltType; +} + +namespace mlir { +namespace detail { +// Explicit instantiations for every supported DenseArrayAttr element type. +template class DenseArrayAttr; +template class DenseArrayAttr; +template class DenseArrayAttr; +template class DenseArrayAttr; +template class DenseArrayAttr; +template class DenseArrayAttr; +} // namespace detail +} // namespace mlir + //===----------------------------------------------------------------------===// // DenseElementsAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -11,9 +11,12 @@ //===----------------------------------------------------------------------===// #include "Parser.h" + +#include "AsmParserImpl.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Parser/AsmParserState.h" #include "llvm/ADT/StringExtras.h" @@ -35,6 +38,7 @@ /// | symbol-ref-id (`::` symbol-ref-id)* /// | `dense` `<` tensor-literal `>` `:` /// (tensor-type | vector-type) +/// | `[` `:` (integer-type | float-type) literal (`,` literal)* `]` /// | `sparse` `<` attribute-value `,` attribute-value `>` /// `:` (tensor-type | vector-type) /// | `opaque` `<` dialect-namespace `,` hex-string-literal
@@ -67,14 +71,22 @@ // Parse an array attribute. case Token::l_square: { + consumeToken(Token::l_square); + if (consumeIf(Token::colon)) + return parseDenseArrayAttr(); SmallVector elements; auto parseElt = [&]() -> ParseResult { elements.push_back(parseAttribute()); return elements.back() ? success() : failure(); }; - if (parseCommaSeparatedList(Delimiter::Square, parseElt)) + if (!getToken().is(Token::r_square) && + parseCommaSeparatedList(Delimiter::None, parseElt)) + return nullptr; + if (!consumeIf(Token::r_square)) { + emitWrongTokenError("expected closing ']' for array attribute"); return nullptr; + } return builder.getArrayAttr(elements); } @@ -812,6 +824,72 @@ // ElementsAttr Parser //===----------------------------------------------------------------------===// +namespace { +/// This class provides an implementation of AsmParser, allowing the parser to +/// call back into attribute parsing code defined in libMLIRIR, such as +/// DenseArrayAttr::parseWithoutBraces. +class CustomAsmParser : public AsmParserImpl { +public: + explicit CustomAsmParser(Parser &parser) + : AsmParserImpl(parser.getToken().getLoc(), parser) {} +}; +} // namespace + +/// Parse the remainder of a dense array attribute such as `[:i32 1, 2, 3]`, +/// i.e. everything after the leading `[` `:`, which the caller has consumed. +/// NOTE(review): at least one element is required, so an empty array (printed +/// as `[:i32 ]`) is rejected and does not round-trip.
+Attribute Parser::parseDenseArrayAttr() { + auto typeLoc = getToken().getLoc(); + auto type = parseType(); + if (!type) + return {}; + CustomAsmParser parser(*this); + Attribute result; + if (auto intType = type.dyn_cast()) { + switch (type.getIntOrFloatBitWidth()) { + case 8: + result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{}); + break; + case 16: + result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{}); + break; + case 32: + result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{}); + break; + case 64: + result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{}); + break; + default: + emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type; + return {}; + } + } else if (auto floatType = type.dyn_cast()) { + switch (type.getIntOrFloatBitWidth()) { + case 32: + result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{}); + break; + case 64: + result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{}); + break; + default: + emitError(typeLoc, "expected f32 or f64 but got: ") << type; + return {}; + } + } else { + emitError(typeLoc, "expected integer or float type, got: ") << type; + return {}; + } + // On failure, the element parser has already emitted a diagnostic. + if (!result) + return {}; + if (!consumeIf(Token::r_square)) { + emitError("expected ']' to close an array attribute"); + return {}; + } + return result; +} /// Parse a dense elements attribute. Attribute Parser::parseDenseElementsAttr(Type attrType) { auto attribLoc = getToken().getLoc(); diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h
@@ -264,6 +264,9 @@ Attribute parseDenseElementsAttr(Type attrType); ShapedType parseElementsLiteralType(Type type); + /// Parse a DenseArrayAttr; the leading `[` `:` must already be consumed. + Attribute parseDenseArrayAttr(); + /// Parse a sparse elements attribute. Attribute parseSparseElementsAttr(Type attrType); diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -513,6 +513,46 @@ return } + +// ----- + +//===----------------------------------------------------------------------===// +// Test DenseArrayAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @dense_array_attr +func.func @dense_array_attr() attributes{ +// CHECK-SAME: f32attr = [:f32 1.024000e+03, 4.530000e+02, -6.435000e+03], + f32attr = [:f32 1024., 453., -6435.], +// CHECK-SAME: f64attr = [:f64 -1.420000e+02], + f64attr = [:f64 -142.], +// CHECK-SAME: i16attr = [:i16 3, 5, -4, 10], + i16attr = [:i16 3, 5, -4, 10], +// CHECK-SAME: i32attr = [:i32 1024, 453, -6435], + i32attr = [:i32 1024, 453, -6435], +// CHECK-SAME: i64attr = [:i64 -142], + i64attr = [:i64 -142], +// CHECK-SAME: i8attr = [:i8 1, -2, 3] + i8attr = [:i8 1, -2, 3] + } { +// CHECK: test.dense_array_attr + test.dense_array_attr +// CHECK-SAME: i8attr = [1, -2, 3] + i8attr = [1, -2, 3] +// CHECK-SAME: i16attr = [3, 5, -4, 10] + i16attr = [3, 5, -4, 10] +// CHECK-SAME: i32attr = [1024, 453, -6435] + i32attr = [1024, 453, -6435] +// CHECK-SAME: i64attr = [-142] + i64attr = [-142] +// CHECK-SAME: f32attr = [1.024000e+03, 4.530000e+02, -6.435000e+03] + f32attr = [1024., 453., -6435.] +// CHECK-SAME: f64attr = [-1.420000e+02] + f64attr = [-142.]
+ + return +} + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir --- a/mlir/test/IR/elements-attr-interface.mlir +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -5,23 +5,41 @@ // This tests that the abstract iteration of ElementsAttr works properly, and // is properly failable when necessary. +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}} arith.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64> +// expected-error@below {{Test iterating `int64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}} arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64> +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} // expected-error@below {{Test iterating `uint64_t`: unable to iterate type}} // expected-error@below {{Test iterating `APInt`: unable to iterate type}} // expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}} arith.constant opaque<"_", "0xDEADBEEF"> : tensor<5xi64> // Check that we don't crash on empty element attributes.
+// expected-error@below {{Test iterating `int64_t`: }} // expected-error@below {{Test iterating `uint64_t`: }} // expected-error@below {{Test iterating `APInt`: }} // expected-error@below {{Test iterating `IntegerAttr`: }} arith.constant dense<> : tensor<0xi64> + +// Dense array attributes are iterated as their own C++ element type. +// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}} +arith.constant [:i8 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}} +arith.constant [:i16 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}} +arith.constant [:i32 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}} +arith.constant [:i64 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}} +arith.constant [:f32 10., 11., -12., 13., 14.] +// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}} +arith.constant [:f64 10., 11., -12., 13., 14.]
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -1654,7 +1654,7 @@ // ----- -// expected-error@+1 {{expected ']'}} +// expected-error@+1 {{expected closing ']' for array attribute}} "f"() { b = [@m: // ----- diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -270,6 +270,23 @@ ); } +def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> { + let arguments = (ins + DenseI8ArrayAttr:$i8attr, + DenseI16ArrayAttr:$i16attr, + DenseI32ArrayAttr:$i32attr, + DenseI64ArrayAttr:$i64attr, + DenseF32ArrayAttr:$f32attr, + DenseF64ArrayAttr:$f64attr + ); + let assemblyFormat = [{ + `i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr + `i64attr` `=` $i64attr `f32attr` `=` $f32attr `f64attr` `=` $f64attr + attr-dict + }]; +} + + //===----------------------------------------------------------------------===// // Test Enum Attributes //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp --- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp +++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp @@ -14,6 +14,17 @@ using namespace mlir; using namespace test; +// Helper to print one scalar value. The int8_t specialization widens the value +// first so that it prints as an integer rather than as a character.
+template +static void printOneElement(InFlightDiagnostic &os, T value) { + os << llvm::formatv("{0}", value).str(); +} +template <> +void printOneElement(InFlightDiagnostic &os, int8_t value) { + os << llvm::formatv("{0}", static_cast(value)).str(); +} + namespace { struct TestElementsAttrInterface : public PassWrapper> { @@ -29,6 +40,31 @@ auto elementsAttr = attr.getValue().dyn_cast(); if (!elementsAttr) continue; + if (auto concreteAttr = + attr.getValue().dyn_cast()) { + switch (concreteAttr.getElementType()) { + case mlir::detail::DenseArrayAttributeElementType::I8: + testElementsAttrIteration(op, elementsAttr, "int8_t"); + break; + case mlir::detail::DenseArrayAttributeElementType::I16: + testElementsAttrIteration(op, elementsAttr, "int16_t"); + break; + case mlir::detail::DenseArrayAttributeElementType::I32: + testElementsAttrIteration(op, elementsAttr, "int32_t"); + break; + case mlir::detail::DenseArrayAttributeElementType::I64: + testElementsAttrIteration(op, elementsAttr, "int64_t"); + break; + case mlir::detail::DenseArrayAttributeElementType::F32: + testElementsAttrIteration(op, elementsAttr, "float"); + break; + case mlir::detail::DenseArrayAttributeElementType::F64: + testElementsAttrIteration(op, elementsAttr, "double"); + break; + } + continue; + } + testElementsAttrIteration(op, elementsAttr, "int64_t"); testElementsAttrIteration(op, elementsAttr, "uint64_t"); testElementsAttrIteration(op, elementsAttr, "APInt"); testElementsAttrIteration(op, elementsAttr, "IntegerAttr"); @@ -48,9 +84,8 @@ return; } - llvm::interleaveComma(*values, diag, [&](T value) { - diag << llvm::formatv("{0}", value).str(); - }); + llvm::interleaveComma(*values, diag, + [&](T value) { printOneElement(diag, value); }); } }; } // namespace