diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -459,6 +459,10 @@
 
   // Verify that the rank of the indices matches the held type.
   auto rank = type.getRank();
+  // A rank-0 (scalar) shaped type has a single element; also accept the
+  // one-element index {0} as addressing it, even though its size does not
+  // match the rank and would otherwise be rejected by the check below.
+  if (rank == 0 && index.size() == 1 && index[0] == 0)
+    return true;
   if (rank != static_cast<int64_t>(index.size()))
     return false;
 
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -192,4 +192,15 @@
   testSplat(complexType, value);
 }
 
+TEST(DenseScalarTest, ExtractZeroRankElement) {
+  MLIRContext context;
+  const int elementValue = 12;
+  IntegerType intTy = IntegerType::get(&context, 32);
+  Attribute value = IntegerAttr::get(intTy, elementValue);
+  RankedTensorType shape = RankedTensorType::get({}, intTy);
+
+  auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
+  EXPECT_TRUE(attr.getValue({0}) == value);
+}
+
 } // end namespace