diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -54,33 +54,36 @@ using NonContiguousIterableTypesT = std::tuple; ``` - * Provide a `iterator value_begin_impl(OverloadToken) const` overload for - each iterable type + * Provide a `FailureOr try_value_begin_impl(OverloadToken) const` + overload for each iterable type These overloads should return an iterator to the start of the range for the - respective iterable type. Consider the example i64 elements attribute - described in the previous section. This attribute may define the - value_begin_impl overloads like so: + respective iterable type or fail if the type cannot be iterated. Consider + the example i64 elements attribute described in the previous section. This + attribute may define the try_value_begin_impl overloads like so: ```c++ /// Provide begin iterators for the various iterable types.
/// * uint64_t - auto value_begin_impl(OverloadToken) const { + FailureOr + try_value_begin_impl(OverloadToken) const { return getElements().begin(); } /// * APInt - auto value_begin_impl(OverloadToken) const { - return llvm::map_range(getElements(), [=](uint64_t value) { + auto try_value_begin_impl(OverloadToken) const { + auto it = llvm::map_range(getElements(), [=](uint64_t value) { return llvm::APInt(/*numBits=*/64, value); }).begin(); + return FailureOr(std::move(it)); } /// * Attribute - auto value_begin_impl(OverloadToken) const { + auto try_value_begin_impl(OverloadToken) const { mlir::Type elementType = getType().getElementType(); - return llvm::map_range(getElements(), [=](uint64_t value) { + auto it = llvm::map_range(getElements(), [=](uint64_t value) { return mlir::IntegerAttr::get(elementType, llvm::APInt(/*numBits=*/64, value)); }).begin(); + return FailureOr(std::move(it)); } ``` @@ -244,18 +247,22 @@ /*isSplat=*/false, nullptr); } - auto valueIt = $_attr.value_begin_impl(OverloadToken()); + auto valueIt = $_attr.try_value_begin_impl(OverloadToken()); + if (::mlir::failed(valueIt)) + return ::mlir::failure(); return ::mlir::detail::ElementsAttrIndexer::contiguous( - $_attr.isSplat(), &*valueIt); + $_attr.isSplat(), &**valueIt); } /// Build an indexer for the given type `T`, which is represented via a /// non-contiguous range. template ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult( /*isContiguous*/std::false_type) const { - auto valueIt = $_attr.value_begin_impl(OverloadToken()); + auto valueIt = $_attr.try_value_begin_impl(OverloadToken()); + if (::mlir::failed(valueIt)) + return ::mlir::failure(); return ::mlir::detail::ElementsAttrIndexer::nonContiguous( - $_attr.isSplat(), valueIt); + $_attr.isSplat(), *valueIt); } public: @@ -275,7 +282,8 @@ /// type `T`. + /// Requires `T` to be iterable: the FailureOr is dereferenced unchecked. template auto value_begin() const { - return $_attr.value_begin_impl(OverloadToken()); + return *$_attr.try_value_begin_impl(OverloadToken()); } /// Return the elements of this attribute as a value of type 'T'.
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -200,13 +200,20 @@ /// ElementsAttr implementation. using ContiguousIterableTypesT = std::tuple; - const bool *value_begin_impl(OverloadToken) const; - const int8_t *value_begin_impl(OverloadToken) const; - const int16_t *value_begin_impl(OverloadToken) const; - const int32_t *value_begin_impl(OverloadToken) const; - const int64_t *value_begin_impl(OverloadToken) const; - const float *value_begin_impl(OverloadToken) const; - const double *value_begin_impl(OverloadToken) const; + FailureOr + try_value_begin_impl(OverloadToken) const; + FailureOr + try_value_begin_impl(OverloadToken) const; + FailureOr + try_value_begin_impl(OverloadToken) const; + FailureOr + try_value_begin_impl(OverloadToken) const; + FailureOr + try_value_begin_impl(OverloadToken) const; + FailureOr + try_value_begin_impl(OverloadToken) const; + FailureOr + try_value_begin_impl(OverloadToken) const; /// Printer for the short form: will dispatch to the appropriate subclass. void print(AsmPrinter &printer) const; @@ -281,10 +288,11 @@ APFloat, std::complex >; - /// Provide a `value_begin_impl` to enable iteration within ElementsAttr. + /// Provide a `try_value_begin_impl` to enable iteration within + /// ElementsAttr. This overload never returns failure. template - auto value_begin_impl(OverloadToken) const { - return value_begin(); + auto try_value_begin_impl(OverloadToken) const { + return ::mlir::success(value_begin()); } /// Convert endianess of input ArrayRef for big-endian(BE) machines. All of @@ -410,10 +418,11 @@ using ContiguousIterableTypesT = std::tuple; using NonContiguousIterableTypesT = std::tuple; - /// Provide a `value_begin_impl` to enable iteration within ElementsAttr. + /// Provide a `try_value_begin_impl` to enable iteration within + /// ElementsAttr. This overload never returns failure.
template - auto value_begin_impl(OverloadToken) const { - return value_begin(); + auto try_value_begin_impl(OverloadToken) const { + return ::mlir::success(value_begin()); } protected: @@ -881,10 +890,11 @@ >; using ElementsAttr::Trait::getValues; - /// Provide a `value_begin_impl` to enable iteration within ElementsAttr. + /// Provide a `try_value_begin_impl` to enable iteration within + /// ElementsAttr. This overload never returns failure. template - auto value_begin_impl(OverloadToken) const { - return value_begin(); + auto try_value_begin_impl(OverloadToken) const { + return ::mlir::success(value_begin()); } template diff --git a/mlir/include/mlir/Support/LogicalResult.h b/mlir/include/mlir/Support/LogicalResult.h --- a/mlir/include/mlir/Support/LogicalResult.h +++ b/mlir/include/mlir/Support/LogicalResult.h @@ -99,6 +99,13 @@ using Optional::has_value; }; +/// Wrap a value on the success path in a FailureOr of the same value type. +template >> +inline auto success(T &&t) { + return FailureOr>(std::forward(t)); +} + /// This class represents success/failure for parsing-like operations that find /// it important to chain together failable operations with `||`.
This is an /// extended version of `LogicalResult` that allows for explicit conversion to diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -687,31 +687,47 @@ // DenseArrayAttr //===----------------------------------------------------------------------===// -const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const int8_t * -DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const int16_t * -DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const int32_t * -DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const int64_t * -DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const double * -DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); +FailureOr +DenseArrayBaseAttr::try_value_begin_impl(OverloadToken) const { + if (auto attr = dyn_cast()) + return attr.asArrayRef().begin(); + return failure(); +} +FailureOr +DenseArrayBaseAttr::try_value_begin_impl(OverloadToken) const { + if (auto attr = dyn_cast()) + return attr.asArrayRef().begin(); + return failure(); +} +FailureOr +DenseArrayBaseAttr::try_value_begin_impl(OverloadToken) const { + if (auto attr = dyn_cast()) + return attr.asArrayRef().begin(); + return failure(); +} +FailureOr +DenseArrayBaseAttr::try_value_begin_impl(OverloadToken) const { + if (auto attr = dyn_cast()) + return attr.asArrayRef().begin(); + return failure(); +} +FailureOr +DenseArrayBaseAttr::try_value_begin_impl(OverloadToken) const { + if (auto attr = dyn_cast()) + return attr.asArrayRef().begin(); + 
return failure(); +} +FailureOr +DenseArrayBaseAttr::try_value_begin_impl(OverloadToken) const { + if (auto attr = dyn_cast()) + return attr.asArrayRef().begin(); + return failure(); +} +FailureOr +DenseArrayBaseAttr::try_value_begin_impl(OverloadToken) const { + if (auto attr = dyn_cast()) + return attr.asArrayRef().begin(); + return failure(); } void DenseArrayBaseAttr::print(AsmPrinter &printer) const { diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir --- a/mlir/test/IR/elements-attr-interface.mlir +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -28,18 +28,24 @@ arith.constant dense<> : tensor<0xi64> // expected-error@below {{Test iterating `bool`: true, false, true, false, true, false}} +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} arith.constant array // expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}} +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} arith.constant array // expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}} +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} arith.constant array // expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}} +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} arith.constant array // expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}} arith.constant array // expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}} +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} arith.constant array // expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}} +// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} arith.constant array // Check that we handle an external constant parsed from the config.
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -94,22 +94,23 @@ /// Provide begin iterators for the various iterable types. // * uint64_t - auto value_begin_impl(OverloadToken) const { + mlir::FailureOr + try_value_begin_impl(OverloadToken) const { return getElements().begin(); } // * Attribute - auto value_begin_impl(OverloadToken) const { + auto try_value_begin_impl(OverloadToken) const { mlir::Type elementType = getType().getElementType(); - return llvm::map_range(getElements(), [=](uint64_t value) { + return mlir::success(llvm::map_range(getElements(), [=](uint64_t value) { return mlir::IntegerAttr::get(elementType, llvm::APInt(/*numBits=*/64, value)); - }).begin(); + }).begin()); } // * APInt - auto value_begin_impl(OverloadToken) const { - return llvm::map_range(getElements(), [=](uint64_t value) { + auto try_value_begin_impl(OverloadToken) const { + return mlir::success(llvm::map_range(getElements(), [=](uint64_t value) { return llvm::APInt(/*numBits=*/64, value); - }).begin(); + }).begin()); } }]; let genVerifyDecl = 1; @@ -257,7 +258,8 @@ /// Provide begin iterators for the various iterable types. // * uint64_t - auto value_begin_impl(OverloadToken) const { + mlir::FailureOr + try_value_begin_impl(OverloadToken) const { return getElements().begin(); } }]; diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp --- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp +++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp @@ -65,6 +65,7 @@ .Case([&](DenseF64ArrayAttr attr) { testElementsAttrIteration(op, attr, "double"); }); + testElementsAttrIteration(op, elementsAttr, "int64_t"); continue; } testElementsAttrIteration(op, elementsAttr, "int64_t");