diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -791,8 +791,11 @@
   static bool classof(Attribute attr);
 };
 template <>
+void DenseArrayAttr<bool>::printWithoutBraces(raw_ostream &os) const;
+template <>
 void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const;
 
+extern template class DenseArrayAttr<bool>;
 extern template class DenseArrayAttr<int8_t>;
 extern template class DenseArrayAttr<int16_t>;
 extern template class DenseArrayAttr<int32_t>;
@@ -802,6 +805,7 @@
 } // namespace detail
 
 // Public name for all the supported DenseArrayAttr
+using DenseBoolArrayAttr = detail::DenseArrayAttr<bool>;
 using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
 using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
 using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -180,7 +180,7 @@
                         ArrayRefParameter<"char">:$elements);
   let extraClassDeclaration = [{
     // All possible supported element type.
-    enum class EltType { I8, I16, I32, I64, F32, F64 };
+    enum class EltType { I1, I8, I16, I32, I64, F32, F64 };
 
     /// Allow implicit conversion to ElementsAttr.
     operator ElementsAttr() const {
@@ -189,7 +189,8 @@
 
     /// ElementsAttr implementation.
     using ContiguousIterableTypesT =
-        std::tuple<int8_t, int16_t, int32_t, int64_t, float, double>;
+        std::tuple<bool, int8_t, int16_t, int32_t, int64_t, float, double>;
+    const bool *value_begin_impl(OverloadToken<bool>) const;
     const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
     const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
     const int32_t *value_begin_impl(OverloadToken<int32_t>) const;
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1282,6 +1282,7 @@
   let storageType = "::mlir::" # denseAttrName;
   let returnType = "::llvm::ArrayRef<" # cppType # ">";
 }
+def DenseBoolArrayAttr : DenseArrayAttrBase<"DenseBoolArrayAttr", "bool", "i1">;
 def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">;
 def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
 def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">;
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -845,6 +845,12 @@
 
   if (auto intType = type.dyn_cast<IntegerType>()) {
     switch (type.getIntOrFloatBitWidth()) {
+    case 1:
+      if (isEmptyList)
+        result = DenseBoolArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseBoolArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
     case 8:
       if (isEmptyList)
         result = DenseI8ArrayAttr::get(parser.getContext(), {});
@@ -870,7 +876,7 @@
         result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     default:
-      emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
+      emitError(typeLoc, "expected i1, i8, i16, i32, or i64 but got: ") << type;
       return {};
     }
   } else if (auto floatType = type.dyn_cast<FloatType>()) {
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -238,6 +238,15 @@
 
 /// Parse an optional integer value from the stream.
 OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
+  // Parse `false`/`true` as 0/1, 64 bits wide so the sign bit stays zero.
+  if (consumeIf(Token::kw_false)) {
+    result = APInt(/*numBits=*/64, 0);
+    return success();
+  } else if (consumeIf(Token::kw_true)) {
+    result = APInt(/*numBits=*/64, 1);
+    return success();
+  }
+
   Token curToken = getToken();
   if (curToken.isNot(Token::integer, Token::minus))
     return llvm::None;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1860,26 +1860,7 @@
     }
   } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
     typeElision = AttrTypeElision::Must;
-    switch (denseArrayAttr.getElementType()) {
-    case DenseArrayBaseAttr::EltType::I8:
-      os << "[:i8";
-      break;
-    case DenseArrayBaseAttr::EltType::I16:
-      os << "[:i16";
-      break;
-    case DenseArrayBaseAttr::EltType::I32:
-      os << "[:i32";
-      break;
-    case DenseArrayBaseAttr::EltType::I64:
-      os << "[:i64";
-      break;
-    case DenseArrayBaseAttr::EltType::F32:
-      os << "[:f32";
-      break;
-    case DenseArrayBaseAttr::EltType::F64:
-      os << "[:f64";
-      break;
-    }
+    os << "[:" << denseArrayAttr.getType().getElementType();
     if (denseArrayAttr.size())
       os << " ";
     denseArrayAttr.printWithoutBraces(os);
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -732,6 +732,9 @@
 
 ShapedType DenseArrayBaseAttr::getType() const { return getImpl()->type; }
 
+const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken<bool>) const {
+  return cast<DenseBoolArrayAttr>().asArrayRef().begin();
+}
 const int8_t *
 DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
   return cast<DenseI8ArrayAttr>().asArrayRef().begin();
@@ -762,6 +765,9 @@
 
 void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
   switch (getElementType()) {
+  case DenseArrayBaseAttr::EltType::I1:
+    this->cast<DenseBoolArrayAttr>().printWithoutBraces(os);
+    return;
   case DenseArrayBaseAttr::EltType::I8:
     this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
     return;
@@ -797,15 +803,20 @@
 
 template <typename T>
 void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
-  ArrayRef<T> values{*this};
-  llvm::interleaveComma(values, os);
+  llvm::interleaveComma(asArrayRef(), os);
+}
+
+/// Specialization for bool to print `true` or `false`.
+template <>
+void DenseArrayAttr<bool>::printWithoutBraces(raw_ostream &os) const {
+  llvm::interleaveComma(asArrayRef(), os,
+                        [&](bool v) { os << (v ? "true" : "false"); });
 }
 
 /// Specialization for int8_t for forcing printing as number instead of chars.
 template <>
 void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
-  ArrayRef<int8_t> values{*this};
-  llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
+  llvm::interleaveComma(asArrayRef(), os, [&](int64_t v) { os << v; });
 }
 
 template <typename T>
@@ -816,7 +827,7 @@
 }
 
 /// Parse a single element: generic template for int types, specialized for
-/// floating points below.
+/// floating points below. `bool` also uses the generic integer template.
 template <typename T>
 static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
   return parser.parseInteger(value);
@@ -880,6 +891,14 @@
 template <typename T>
 struct denseArrayAttrEltTypeBuilder;
 template <>
+struct denseArrayAttrEltTypeBuilder<bool> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I1;
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
+    return RankedTensorType::get(shape, IntegerType::get(context, 1));
+  }
+};
+template <>
 struct denseArrayAttrEltTypeBuilder<int8_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
   static ShapedType getShapedType(MLIRContext *context,
@@ -953,6 +972,7 @@
 namespace mlir {
 namespace detail {
 // Explicit instantiation for all the supported DenseArrayAttr.
+template class DenseArrayAttr<bool>;
 template class DenseArrayAttr<int8_t>;
 template class DenseArrayAttr<int16_t>;
 template class DenseArrayAttr<int32_t>;
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -521,13 +521,15 @@
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: func @dense_array_attr
-func.func @dense_array_attr() attributes{
+func.func @dense_array_attr() attributes {
 // CHECK-SAME: emptyf32attr = [:f32],
                emptyf32attr = [:f32],
 // CHECK-SAME: emptyf64attr = [:f64],
                emptyf64attr = [:f64],
 // CHECK-SAME: emptyi16attr = [:i16],
                emptyi16attr = [:i16],
+// CHECK-SAME: emptyi1attr = [:i1],
+               emptyi1attr = [:i1],
 // CHECK-SAME: emptyi32attr = [:i32],
                emptyi32attr = [:i32],
 // CHECK-SAME: emptyi64attr = [:i64],
@@ -540,6 +542,8 @@
                f64attr = [:f64 -142.],
 // CHECK-SAME: i16attr = [:i16 3, 5, -4, 10],
                i16attr = [:i16 3, 5, -4, 10],
+// CHECK-SAME: i1attr = [:i1 true, false, true],
+               i1attr = [:i1 true, false, true],
 // CHECK-SAME: i32attr = [:i32 1024, 453, -6435],
                i32attr = [:i32 1024, 453, -6435],
 // CHECK-SAME: i64attr = [:i64 -142],
@@ -549,6 +553,8 @@
 } {
 // CHECK: test.dense_array_attr
   test.dense_array_attr
+// CHECK-SAME: i1attr = [true, false, true]
+               i1attr = [true, false, true]
 // CHECK-SAME: i8attr = [1, -2, 3]
                i8attr = [1, -2, 3]
 // CHECK-SAME: i16attr = [3, 5, -4, 10]
diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir
--- a/mlir/test/IR/elements-attr-interface.mlir
+++ b/mlir/test/IR/elements-attr-interface.mlir
@@ -27,6 +27,8 @@
 // expected-error@below {{Test iterating `IntegerAttr`: }}
 arith.constant dense<> : tensor<0xi64>
 
+// expected-error@below {{Test iterating `bool`: true, false, true, false, true, false}}
+arith.constant [:i1 true, false, true, false, true, false]
 // expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
 arith.constant [:i8 10, 11, -12, 13, 14]
 // expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -272,6 +272,7 @@
 
 def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
   let arguments = (ins
+    DenseBoolArrayAttr:$i1attr,
     DenseI8ArrayAttr:$i8attr,
     DenseI16ArrayAttr:$i16attr,
     DenseI32ArrayAttr:$i32attr,
@@ -281,10 +282,9 @@
     DenseI32ArrayAttr:$emptyattr
   );
   let assemblyFormat = [{
-    `i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
-    `i64attr` `=` $i64attr `f32attr` `=` $f32attr `f64attr` `=` $f64attr
-    `emptyattr` `=` $emptyattr
-    attr-dict
+    `i1attr` `=` $i1attr `i8attr` `=` $i8attr `i16attr` `=` $i16attr
+    `i32attr` `=` $i32attr `i64attr` `=` $i64attr `f32attr` `=` $f32attr
+    `f64attr` `=` $f64attr `emptyattr` `=` $emptyattr attr-dict
   }];
 }
 
diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -43,6 +43,9 @@
         if (auto concreteAttr =
                 attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
           switch (concreteAttr.getElementType()) {
+          case DenseArrayBaseAttr::EltType::I1:
+            testElementsAttrIteration<bool>(op, elementsAttr, "bool");
+            break;
           case DenseArrayBaseAttr::EltType::I8:
             testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
             break;