diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -432,6 +432,11 @@ if (!values) return std::nullopt; + // Check that an element can be cast to 'T', which is a derived attribute. + if (!empty()) + if (!(*(value_begin())).dyn_cast()) + return std::nullopt; + auto castFn = [](Attribute attr) { return ::llvm::cast(attr); }; return DerivedAttrValueIteratorRange( getShapedType(), diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -982,10 +982,11 @@ template auto SparseElementsAttr::try_value_begin_impl(OverloadToken) const -> FailureOr> { - auto zeroValue = getZeroValue(); auto valueIt = getValues().try_value_begin(); if (failed(valueIt)) return failure(); + + auto zeroValue = getZeroValue(); const std::vector flatSparseIndices(getFlattenedSparseIndices()); std::function mapFn = [flatSparseIndices{flatSparseIndices}, valueIt{std::move(*valueIt)}, diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir --- a/mlir/test/IR/elements-attr-interface.mlir +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -14,6 +14,12 @@ // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}} arith.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64> +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} +// expected-error@below {{Test iterating `uint64_t`: unable to iterate type}} +// expected-error@below {{Test iterating `APInt`: unable to iterate type}} +// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}} +arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<5xf32> + // expected-error@below {{Test iterating 
`int64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}