diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -782,22 +782,28 @@ void printWithoutBraces(raw_ostream &os) const; /// Parse the short form `[42, 100, -1]` without any type prefix. - static Attribute parse(AsmParser &parser, Type odsType); + static Attribute parse(AsmParser &parser, Type type); /// Parse the short form `42, 100, -1` without any type prefix or braces. - static Attribute parseWithoutBraces(AsmParser &parser, Type odsType); + static Attribute parseWithoutBraces(AsmParser &parser, Type type); /// Support for isa<>/cast<>. static bool classof(Attribute attr); }; template <> void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const; +template <> +void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const; extern template class DenseArrayAttr; extern template class DenseArrayAttr; extern template class DenseArrayAttr; extern template class DenseArrayAttr; extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; +extern template class DenseArrayAttr; extern template class DenseArrayAttr; extern template class DenseArrayAttr; } // namespace detail @@ -808,6 +814,10 @@ using DenseI16ArrayAttr = detail::DenseArrayAttr; using DenseI32ArrayAttr = detail::DenseArrayAttr; using DenseI64ArrayAttr = detail::DenseArrayAttr; +using DenseUI8ArrayAttr = detail::DenseArrayAttr; +using DenseUI16ArrayAttr = detail::DenseArrayAttr; +using DenseUI32ArrayAttr = detail::DenseArrayAttr; +using DenseUI64ArrayAttr = detail::DenseArrayAttr; using DenseF32ArrayAttr = detail::DenseArrayAttr; using DenseF64ArrayAttr = detail::DenseArrayAttr; diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -180,7 +180,9 @@ ArrayRefParameter<"char">:$elements); let extraClassDeclaration = [{ // All possible supported element type. - enum class EltType { I1, I8, I16, I32, I64, F32, F64 }; + enum class EltType { + I1, I8, I16, I32, I64, UI8, UI16, UI32, UI64, F32, F64 + }; /// Allow implicit conversion to ElementsAttr. operator ElementsAttr() const { @@ -189,12 +191,17 @@ /// ElementsAttr implementation. using ContiguousIterableTypesT = - std::tuple; + std::tuple; const bool *value_begin_impl(OverloadToken) const; const int8_t *value_begin_impl(OverloadToken) const; const int16_t *value_begin_impl(OverloadToken) const; const int32_t *value_begin_impl(OverloadToken) const; const int64_t *value_begin_impl(OverloadToken) const; + const uint8_t *value_begin_impl(OverloadToken) const; + const uint16_t *value_begin_impl(OverloadToken) const; + const uint32_t *value_begin_impl(OverloadToken) const; + const uint64_t *value_begin_impl(OverloadToken) const; const float *value_begin_impl(OverloadToken) const; const double *value_begin_impl(OverloadToken) const; diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -852,28 +852,56 @@ result = DenseI1ArrayAttr::parseWithoutBraces(parser, Type{}); break; case 8: - if (isEmptyList) - result = DenseI8ArrayAttr::get(parser.getContext(), {}); - else - result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{}); + if (type.isUnsignedInteger()) { + if (isEmptyList) + result = DenseUI8ArrayAttr::get(parser.getContext(), {}); + else + result = DenseUI8ArrayAttr::parseWithoutBraces(parser, Type{}); + } else { + if (isEmptyList) + result = DenseI8ArrayAttr::get(parser.getContext(), {}); + else + result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{}); + } break; case 16: - if (isEmptyList) - result = DenseI16ArrayAttr::get(parser.getContext(), {}); - else - result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{}); + if (type.isUnsignedInteger()) { + if (isEmptyList) + result = DenseUI16ArrayAttr::get(parser.getContext(), {}); + else + result = DenseUI16ArrayAttr::parseWithoutBraces(parser, Type{}); + } else { + if (isEmptyList) + result = DenseI16ArrayAttr::get(parser.getContext(), {}); + else + result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{}); + } break; case 32: - if (isEmptyList) - result = DenseI32ArrayAttr::get(parser.getContext(), {}); - else - result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{}); + if (type.isUnsignedInteger()) { + if (isEmptyList) + result = DenseUI32ArrayAttr::get(parser.getContext(), {}); + else + result = DenseUI32ArrayAttr::parseWithoutBraces(parser, Type{}); + } else { + if (isEmptyList) + result = DenseI32ArrayAttr::get(parser.getContext(), {}); + else + result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{}); + } break; case 64: - if (isEmptyList) - result = DenseI64ArrayAttr::get(parser.getContext(), {}); - else - result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{}); + if (type.isUnsignedInteger()) { + if (isEmptyList) + result = DenseUI64ArrayAttr::get(parser.getContext(), {}); + else + result = DenseUI64ArrayAttr::parseWithoutBraces(parser, Type{}); + } else { + if (isEmptyList) + result = DenseI64ArrayAttr::get(parser.getContext(), {}); + else + result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{}); + } break; default: emitError(typeLoc, "expected i1, i8, i16, i32, or i64 but got: ") << type; diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -735,7 +735,6 @@ const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { return cast().asArrayRef().begin(); } - const int8_t * DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { return cast().asArrayRef().begin(); @@ -752,6 +751,22 @@ DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { return cast().asArrayRef().begin(); } +const uint8_t * +DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const uint16_t * +DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const uint32_t * +DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} +const uint64_t * +DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { + return cast().asArrayRef().begin(); +} const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { return cast().asArrayRef().begin(); } @@ -781,6 +796,18 @@ case DenseArrayBaseAttr::EltType::I64: this->cast().printWithoutBraces(os); return; + case DenseArrayBaseAttr::EltType::UI8: + this->cast().printWithoutBraces(os); + return; + case DenseArrayBaseAttr::EltType::UI16: + this->cast().printWithoutBraces(os); + return; + case DenseArrayBaseAttr::EltType::UI32: + this->cast().printWithoutBraces(os); + return; + case DenseArrayBaseAttr::EltType::UI64: + this->cast().printWithoutBraces(os); + return; case DenseArrayBaseAttr::EltType::F32: this->cast().printWithoutBraces(os); return; @@ -815,6 +842,13 @@ llvm::interleaveComma(values, os, [&](int64_t v) { os << v; }); } +/// Specialization for uint8_t for forcing printing as number instead of chars. +template <> +void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const { + ArrayRef values{*this}; + llvm::interleaveComma(values, os, [&](int64_t v) { os << v; }); +} + template void DenseArrayAttr::print(raw_ostream &os) const { os << "["; @@ -845,8 +879,7 @@ /// Parse a DenseArrayAttr without the braces: `1, 2, 3` template -Attribute DenseArrayAttr::parseWithoutBraces(AsmParser &parser, - Type odsType) { +Attribute DenseArrayAttr::parseWithoutBraces(AsmParser &parser, Type type) { SmallVector data; if (failed(parser.parseCommaSeparatedList([&]() { T value; @@ -861,13 +894,13 @@ /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]` template -Attribute DenseArrayAttr::parse(AsmParser &parser, Type odsType) { +Attribute DenseArrayAttr::parse(AsmParser &parser, Type type) { if (parser.parseLSquare()) return {}; // Handle empty list case. if (succeeded(parser.parseOptionalRSquare())) return get(parser.getContext(), {}); - Attribute result = parseWithoutBraces(parser, odsType); + Attribute result = parseWithoutBraces(parser, type); if (parser.parseRSquare()) return {}; return result; @@ -927,6 +960,46 @@ } }; template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = DenseArrayBaseAttr::EltType::UI8; + static ShapedType getShapedType(MLIRContext *context, + ArrayRef shape) { + return RankedTensorType::get( + shape, IntegerType::get(context, 8, + IntegerType::SignednessSemantics::Unsigned)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = DenseArrayBaseAttr::EltType::UI16; + static ShapedType getShapedType(MLIRContext *context, + ArrayRef shape) { + return RankedTensorType::get( + shape, IntegerType::get(context, 16, + IntegerType::SignednessSemantics::Unsigned)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = DenseArrayBaseAttr::EltType::UI32; + static ShapedType getShapedType(MLIRContext *context, + ArrayRef shape) { + return RankedTensorType::get( + shape, IntegerType::get(context, 32, + IntegerType::SignednessSemantics::Unsigned)); + } +}; +template <> +struct denseArrayAttrEltTypeBuilder { + constexpr static auto eltType = DenseArrayBaseAttr::EltType::UI64; + static ShapedType getShapedType(MLIRContext *context, + ArrayRef shape) { + return RankedTensorType::get( + shape, IntegerType::get(context, 64, + IntegerType::SignednessSemantics::Unsigned)); + } +}; +template <> struct denseArrayAttrEltTypeBuilder { constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32; static ShapedType getShapedType(MLIRContext *context, @@ -973,6 +1046,10 @@ template class DenseArrayAttr; template class DenseArrayAttr; template class DenseArrayAttr; +template class DenseArrayAttr; +template class DenseArrayAttr; +template class DenseArrayAttr; +template class DenseArrayAttr; template class DenseArrayAttr; template class DenseArrayAttr; } // namespace detail diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir --- a/mlir/test/IR/elements-attr-interface.mlir +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -37,6 +37,14 @@ arith.constant [:i32 10, 11, -12, 13, 14] // expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}} arith.constant [:i64 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `uint8_t`: 10, 11, 244, 13, 14}} +arith.constant [:ui8 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `uint16_t`: 10, 11, 65524, 13, 14}} +arith.constant [:ui16 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `uint32_t`: 10, 11, 4294967284, 13, 14}} +arith.constant [:ui32 10, 11, -12, 13, 14] +// expected-error@below {{Test iterating `uint64_t`: 10, 11, 18446744073709551604, 13, 14}} +arith.constant [:ui64 10, 11, -12, 13, 14] // expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}} arith.constant [:f32 10., 11., -12., 13., 14.] // expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}} diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp --- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp +++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp @@ -62,6 +62,18 @@ case DenseArrayBaseAttr::EltType::I64: testElementsAttrIteration(op, elementsAttr, "int64_t"); break; + case DenseArrayBaseAttr::EltType::UI8: + testElementsAttrIteration(op, elementsAttr, "uint8_t"); + break; + case DenseArrayBaseAttr::EltType::UI16: + testElementsAttrIteration(op, elementsAttr, "uint16_t"); + break; + case DenseArrayBaseAttr::EltType::UI32: + testElementsAttrIteration(op, elementsAttr, "uint32_t"); + break; + case DenseArrayBaseAttr::EltType::UI64: + testElementsAttrIteration(op, elementsAttr, "uint64_t"); + break; case DenseArrayBaseAttr::EltType::F32: testElementsAttrIteration(op, elementsAttr, "float"); break;