diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -417,6 +417,8 @@ /// return the iterable range. Otherwise, return std::nullopt. template DefaultValueCheckT>> tryGetValues() const { + if (empty()) + return std::nullopt; if (std::optional> beginIt = try_value_begin()) return iterator_range(getShapedType(), *beginIt, value_end()); return std::nullopt; @@ -429,9 +431,15 @@ template > std::optional> tryGetValues() const { auto values = tryGetValues(); + if (empty()) + return std::nullopt; if (!values) return std::nullopt; + // Check that an element can be cast to 'T', which is a derived attribute. + if (!(*(value_begin())).dyn_cast()) + return std::nullopt; + auto castFn = [](Attribute attr) { return ::llvm::cast(attr); }; return DerivedAttrValueIteratorRange( getShapedType(), diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -982,10 +982,11 @@ template auto SparseElementsAttr::try_value_begin_impl(OverloadToken) const -> FailureOr> { - auto zeroValue = getZeroValue(); auto valueIt = getValues().try_value_begin(); if (failed(valueIt)) return failure(); + + auto zeroValue = getZeroValue(); const std::vector flatSparseIndices(getFlattenedSparseIndices()); std::function mapFn = [flatSparseIndices{flatSparseIndices}, valueIt{std::move(*valueIt)}, diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir --- a/mlir/test/IR/elements-attr-interface.mlir +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -14,6 +14,12 @@ // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}} arith.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64> +// 
expected-error@below {{Test iterating `int64_t`: unable to iterate type}} +// expected-error@below {{Test iterating `uint64_t`: unable to iterate type}} +// expected-error@below {{Test iterating `APInt`: unable to iterate type}} +// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}} +arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<5xf32> + // expected-error@below {{Test iterating `int64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}} // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}} @@ -21,10 +27,10 @@ arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64> // Check that we don't crash on empty element attributes. -// expected-error@below {{Test iterating `int64_t`: }} -// expected-error@below {{Test iterating `uint64_t`: }} -// expected-error@below {{Test iterating `APInt`: }} -// expected-error@below {{Test iterating `IntegerAttr`: }} +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} +// expected-error@below {{Test iterating `uint64_t`: unable to iterate type}} +// expected-error@below {{Test iterating `APInt`: unable to iterate type}} +// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}} arith.constant dense<> : tensor<0xi64> // Check that we handle an external constant parsed from the config.