diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1336,8 +1336,13 @@
   if (eltType.isa<FloatType>())
     return FloatAttr::get(eltType, 0);
 
+  // Handle string elements. The element type of a string tensor may be an
+  // arbitrary (e.g. dialect-specific opaque) type, so detect strings by the
+  // storage kind of the values attribute rather than by the element type.
+  if (getValues().isa<DenseStringElementsAttr>())
+    return StringAttr::get("", eltType);
+
   // Otherwise, this is an integer.
-  // TODO: Handle StringAttr here.
   return IntegerAttr::get(eltType, 0);
 }
 
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -205,4 +205,51 @@
   EXPECT_TRUE(attr.getValue({0}) == value);
 }
 
+TEST(SparseElementsAttrTest, GetZero) {
+  MLIRContext context;
+  context.allowUnregisteredDialects();
+
+  IntegerType intTy = IntegerType::get(&context, 32);
+  FloatType floatTy = FloatType::getF32(&context);
+  // String elements use an opaque type from the unregistered "test" dialect.
+  Type stringTy = OpaqueType::get(Identifier::get("test", &context), "string");
+
+  ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
+  ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
+  ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
+
+  auto indicesType =
+      RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
+  auto indices =
+      DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
+
+  RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
+  auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
+
+  RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
+  auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
+
+  RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
+  auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
+
+  auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
+  auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
+  auto sparseString =
+      SparseElementsAttr::get(tensorString, indices, stringValue);
+
+  // Only index (0, 0) contains an element, others are supposed to return
+  // the zero/empty value.
+  auto zeroIntValue = sparseInt.getValue({1, 1});
+  EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0);
+  EXPECT_TRUE(zeroIntValue.getType() == intTy);
+
+  auto zeroFloatValue = sparseFloat.getValue({1, 1});
+  EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0);
+  EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
+
+  auto zeroStringValue = sparseString.getValue({1, 1});
+  EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty());
+  EXPECT_TRUE(zeroStringValue.getType() == stringTy);
+}
+
 } // end namespace