diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -193,8 +193,11 @@
   ArrayRef<char> data;
 
   /// The values used to denote a boolean splat value.
-  static constexpr char kSplatTrue = ~0;
-  static constexpr char kSplatFalse = 0;
+  // Not declared constexpr: MSVC failed to compile that form because it
+  // inlined these values, making it unsafe to refer to them by reference in
+  // KeyTy. The out-of-line definitions live in BuiltinAttributes.cpp.
+  static const char kSplatTrue;
+  static const char kSplatFalse;
 };
 
 /// An attribute representing a reference to a dense vector or tensor object
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -438,6 +438,10 @@
 // DenseElementsAttr Utilities
 //===----------------------------------------------------------------------===//
 
+// Out-of-line definitions of the boolean splat markers; see AttributeDetail.h.
+const char DenseIntOrFPElementsAttrStorage::kSplatTrue = ~0;
+const char DenseIntOrFPElementsAttrStorage::kSplatFalse = 0;
+
 /// Get the bitwidth of a dense element type within the buffer.
 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
 static size_t getDenseElementStorageWidth(size_t origWidth) {
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -74,6 +74,20 @@
   EXPECT_EQ(trueSplat, trueSplatFromRaw);
 }
 
+TEST(DenseSplatTest, BoolSplatSmall) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Check that splats that don't fill an entire byte are handled properly.
+  auto tensorType = RankedTensorType::get({4}, builder.getI1Type());
+  std::vector<char> data{0b00001111};
+  auto trueSplatFromRaw =
+      DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data);
+  EXPECT_TRUE(trueSplatFromRaw.isSplat());
+  DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true);
+  EXPECT_EQ(trueSplat, trueSplatFromRaw);
+}
+
 TEST(DenseSplatTest, LargeBoolSplat) {
   constexpr int64_t boolCount = 56;
 