diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1526,7 +1526,12 @@
 
   // Check for the splat case.
   if (attr.isSplat()) {
-    processElt(*attr.begin(), /*index=*/0);
+    if (bitWidth == 1) {
+      // Bool splats use a special encoding: 0x00 for false, 0xFF for true.
+      data[0] = mapping(*attr.begin()).isZero() ? 0 : -1;
+    } else {
+      processElt(*attr.begin(), /*index=*/0);
+    }
     return newArrayType;
   }
 
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -211,4 +211,38 @@
 }
+
+TEST(DenseSplatMapValuesTest, I32ToTrue) {
+  MLIRContext context;
+  const int elementValue = 12;
+  IntegerType boolTy = IntegerType::get(&context, 1);
+  IntegerType intTy = IntegerType::get(&context, 32);
+  RankedTensorType shape = RankedTensorType::get({4}, intTy);
+
+  auto attr =
+      DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
+          .mapValues(boolTy, [](const APInt &x) {
+            return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+          });
+  EXPECT_EQ(attr.getNumElements(), 4);
+  EXPECT_TRUE(attr.isSplat());
+  EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
+}
+
+TEST(DenseSplatMapValuesTest, I32ToFalse) {
+  MLIRContext context;
+  const int elementValue = 0;
+  IntegerType boolTy = IntegerType::get(&context, 1);
+  IntegerType intTy = IntegerType::get(&context, 32);
+  RankedTensorType shape = RankedTensorType::get({4}, intTy);
+
+  auto attr =
+      DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
+          .mapValues(boolTy, [](const APInt &x) {
+            return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+          });
+  EXPECT_EQ(attr.getNumElements(), 4);
+  EXPECT_TRUE(attr.isSplat());
+  EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
+}
 } // namespace
 
 //===----------------------------------------------------------------------===//