diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -76,22 +76,9 @@
 
   /// Compare this storage instance with the provided key.
   bool operator==(const KeyTy &key) const {
-    if (key.type != type)
-      return false;
-
-    // For boolean splats we need to explicitly check that the first bit is the
-    // same. Boolean values are packed at the bit level, and even though a splat
-    // is detected the rest of the bits in the first byte may differ from the
-    // splat value.
-    if (key.type.getElementType().isInteger(1)) {
-      if (key.isSplat != isSplat)
-        return false;
-      if (isSplat)
-        return (key.data.front() & 1) == data.front();
-    }
-
-    // Otherwise, we can default to just checking the data.
-    return key.data == data;
+    // Boolean splat keys always carry a canonical splat byte (see
+    // getKeyForSplatBoolData), so a plain data comparison is sufficient.
+    return key.type == type && key.data == data;
   }
 
   /// Construct a key from a shaped type, raw data buffer, and a flag that
@@ -105,8 +92,14 @@
 
     // If the data is already known to be a splat, the key hash value is
     // directly the data buffer.
-    if (isKnownSplat)
+    bool isBoolData = ty.getElementType().isInteger(1);
+    if (isKnownSplat) {
+      // Booleans are bit-packed, so the splat value is bit 0; canonicalize it
+      // so that all encodings of the same splat unique to one attribute.
+      if (isBoolData)
+        return getKeyForSplatBoolData(ty, (data.front() & 1) != 0);
       return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
+    }
 
     // Otherwise, we need to check if the data corresponds to a splat or not.
 
@@ -115,7 +108,7 @@
     assert(numElements != 1 && "splat of 1 element should already be detected");
 
     // Handle boolean values directly as they are packed to 1-bit.
-    if (ty.getElementType().isInteger(1) == 1)
+    if (isBoolData)
       return getKeyForBoolData(ty, data, numElements);
 
     size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
@@ -144,12 +137,9 @@
     ArrayRef<char> splatData = data;
     bool splatValue = splatData.front() & 1;
 
-    // Helper functor to generate a KeyTy for a boolean splat value.
-    auto generateSplatKey = [=] {
-      return KeyTy(ty, data.take_front(1),
-                   llvm::hash_value(ArrayRef<char>(splatValue ? 1 : 0)),
-                   /*isSplat=*/true);
-    };
+    // Check the simple case where the data is exactly the canonical splat byte.
+    if (splatData == ArrayRef<char>(getSplatBoolByte(splatValue)))
+      return getKeyForSplatBoolData(ty, splatValue);
 
     // Handle the case where the potential splat value is 1 and the number of
     // elements is non 8-bit aligned.
@@ -162,17 +152,36 @@
 
       // If this is the only element, the data is known to be a splat.
       if (splatData.size() == 1)
-        return generateSplatKey();
+        return getKeyForSplatBoolData(ty, splatValue);
       splatData = splatData.drop_back();
     }
 
     // Check that the data buffer corresponds to a splat of the proper mask.
     char mask = splatValue ? ~0 : 0;
     return llvm::all_of(splatData, [mask](char c) { return c == mask; })
-               ? generateSplatKey()
+               ? getKeyForSplatBoolData(ty, splatValue)
                : KeyTy(ty, data, llvm::hash_value(data));
   }
 
+  /// Return a reference to the canonical byte used to encode a boolean splat:
+  /// all bits set for `true`, all bits clear for `false`. Boolean splat keys
+  /// point at this byte, so it needs static storage duration. A function-local
+  /// static is used instead of a `static constexpr` data member because
+  /// binding a reference to such a member odr-uses it, which before C++17
+  /// requires an out-of-line definition in exactly one translation unit.
+  static const char &getSplatBoolByte(bool splatValue) {
+    static const char kSplatTrue = ~0;
+    static const char kSplatFalse = 0;
+    return splatValue ? kSplatTrue : kSplatFalse;
+  }
+
+  /// Return a key to use for a boolean splat of the given value.
+  static KeyTy getKeyForSplatBoolData(ShapedType type, bool splatValue) {
+    const char &splatData = getSplatBoolByte(splatValue);
+    return KeyTy(type, splatData, llvm::hash_value(splatData),
+                 /*isSplat=*/true);
+  }
+
   /// Hash the key for the storage.
   static llvm::hash_code hashKey(const KeyTy &key) {
     return llvm::hash_combine(key.type, key.hashCode);
@@ -188,10 +197,6 @@
       char *rawData = reinterpret_cast<char *>(
           allocator.allocate(data.size(), alignof(uint64_t)));
       std::memcpy(rawData, data.data(), data.size());
-
-      // If this is a boolean splat, make sure only the first bit is used.
-      if (key.isSplat && key.type.getElementType().isInteger(1))
-        rawData[0] &= 1;
       copy = ArrayRef<char>(rawData, data.size());
     }
 
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -58,6 +58,21 @@
   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
   EXPECT_EQ(detectedSplat, falseSplat);
 }
+
+TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
+  MLIRContext context;
+  IntegerType boolTy = IntegerType::get(&context, 1);
+  RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
+
+  // Check that splat booleans properly round trip via the raw API.
+  DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
+  EXPECT_TRUE(trueSplat.isSplat());
+  DenseElementsAttr trueSplatFromRaw =
+      DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData());
+  EXPECT_TRUE(trueSplatFromRaw.isSplat());
+
+  EXPECT_EQ(trueSplat, trueSplatFromRaw);
+}
 
 TEST(DenseSplatTest, LargeBoolSplat) {
   constexpr int64_t boolCount = 56;