diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -392,7 +392,45 @@ return getSplatValue().template cast(); } - /// Return the held element values as a range of integer or floating-point + /// Try to get an iterator of the given type to the start of the held element + /// values. Return failure if the type cannot be iterated. + template + auto try_value_begin() const { + auto range = tryGetValues(); + using iterator = decltype(range->begin()); + return failed(range) ? FailureOr(failure()) : range->begin(); + } + + /// Try to get an iterator of the given type to the end of the held element + /// values. Return failure if the type cannot be iterated. + template + auto try_value_end() const { + auto range = tryGetValues(); + using iterator = decltype(range->begin()); + return failed(range) ? FailureOr(failure()) : range->end(); + } + + /// Return the held element values as a range of the given type. + template + auto getValues() const { + auto range = tryGetValues(); + assert(succeeded(range) && "element type cannot be iterated"); + return std::move(*range); + } + + /// Get an iterator of the given type to the start of the held element values. + template + auto value_begin() const { + return getValues().begin(); + } + + /// Get an iterator of the given type to the end of the held element values. + template + auto value_end() const { + return getValues().end(); + } + + /// Try to get the held element values as a range of integer or floating-point /// values. template using IntFloatValueTemplateCheckT = @@ -400,28 +438,18 @@ std::numeric_limits::is_integer) || is_valid_cpp_fp_type::value>::type; template > - iterator_range_impl> getValues() const { - assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, - std::numeric_limits::is_signed)); + FailureOr>> tryGetValues() const { + if (!isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, + std::numeric_limits::is_signed)) + return failure(); const char *rawData = getRawData().data(); bool splat = isSplat(); - return {getType(), ElementIterator(rawData, splat, 0), - ElementIterator(rawData, splat, getNumElements())}; - } - template > - ElementIterator value_begin() const { - assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, - std::numeric_limits::is_signed)); - return ElementIterator(getRawData().data(), isSplat(), 0); - } - template > - ElementIterator value_end() const { - assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, - std::numeric_limits::is_signed)); - return ElementIterator(getRawData().data(), isSplat(), getNumElements()); + return iterator_range_impl>( + getType(), ElementIterator(rawData, splat, 0), + ElementIterator(rawData, splat, getNumElements())); } - /// Return the held element values as a range of std::complex. + /// Try to get the held element values as a range of std::complex. template using ComplexValueTemplateCheckT = typename std::enable_if::value && @@ -429,70 +457,45 @@ is_valid_cpp_fp_type::value)>::type; template > - iterator_range_impl> getValues() const { - assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, - std::numeric_limits::is_signed)); + FailureOr>> tryGetValues() const { + if (!isValidComplex(sizeof(T), std::numeric_limits::is_integer, + std::numeric_limits::is_signed)) + return failure(); const char *rawData = getRawData().data(); bool splat = isSplat(); - return {getType(), ElementIterator(rawData, splat, 0), - ElementIterator(rawData, splat, getNumElements())}; - } - template > - ElementIterator value_begin() const { - assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, - std::numeric_limits::is_signed)); - return ElementIterator(getRawData().data(), isSplat(), 0); - } - template > - ElementIterator value_end() const { - assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, - std::numeric_limits::is_signed)); - return ElementIterator(getRawData().data(), isSplat(), getNumElements()); + return iterator_range_impl>( + getType(), ElementIterator(rawData, splat, 0), + ElementIterator(rawData, splat, getNumElements())); } - /// Return the held element values as a range of StringRef. + /// Try to get the held element values as a range of StringRef. template using StringRefValueTemplateCheckT = typename std::enable_if::value>::type; template > - iterator_range_impl> getValues() const { + FailureOr>> + tryGetValues() const { auto stringRefs = getRawStringData(); const char *ptr = reinterpret_cast(stringRefs.data()); bool splat = isSplat(); - return {getType(), ElementIterator(ptr, splat, 0), - ElementIterator(ptr, splat, getNumElements())}; - } - template > - ElementIterator value_begin() const { - const char *ptr = reinterpret_cast(getRawStringData().data()); - return ElementIterator(ptr, isSplat(), 0); - } - template > - ElementIterator value_end() const { - const char *ptr = reinterpret_cast(getRawStringData().data()); - return ElementIterator(ptr, isSplat(), getNumElements()); + return iterator_range_impl>( + getType(), ElementIterator(ptr, splat, 0), + ElementIterator(ptr, splat, getNumElements())); } - /// Return the held element values as a range of Attributes. + /// Try to get the held element values as a range of Attributes. template using AttributeValueTemplateCheckT = typename std::enable_if::value>::type; template > - iterator_range_impl getValues() const { - return {getType(), value_begin(), value_end()}; - } - template > - AttributeElementIterator value_begin() const { - return AttributeElementIterator(*this, 0); - } - template > - AttributeElementIterator value_end() const { - return AttributeElementIterator(*this, getNumElements()); + FailureOr> + tryGetValues() const { + return iterator_range_impl( + getType(), AttributeElementIterator(*this, 0), + AttributeElementIterator(*this, getNumElements())); } - /// Return the held element values a range of T, where T is a derived + /// Try to get the held element values a range of T, where T is a derived /// attribute type. template using DerivedAttrValueTemplateCheckT = @@ -510,115 +513,71 @@ T mapElement(Attribute attr) const { return attr.cast(); } }; template > - iterator_range_impl> getValues() const { + FailureOr>> + tryGetValues() const { using DerivedIterT = DerivedAttributeElementIterator; - return {getType(), DerivedIterT(value_begin()), - DerivedIterT(value_end())}; - } - template > - DerivedAttributeElementIterator value_begin() const { - return {value_begin()}; - } - template > - DerivedAttributeElementIterator value_end() const { - return {value_end()}; + return iterator_range_impl( + getType(), DerivedIterT(value_begin()), + DerivedIterT(value_end())); } - /// Return the held element values as a range of bool. The element type of + /// Try to get the held element values as a range of bool. The element type of /// this attribute must be of integer type of bitwidth 1. template using BoolValueTemplateCheckT = typename std::enable_if::value>::type; template > - iterator_range_impl getValues() const { - assert(isValidBool() && "bool is not the value of this elements attribute"); - return {getType(), BoolElementIterator(*this, 0), - BoolElementIterator(*this, getNumElements())}; - } - template > - BoolElementIterator value_begin() const { - assert(isValidBool() && "bool is not the value of this elements attribute"); - return BoolElementIterator(*this, 0); - } - template > - BoolElementIterator value_end() const { - assert(isValidBool() && "bool is not the value of this elements attribute"); - return BoolElementIterator(*this, getNumElements()); + FailureOr> tryGetValues() const { + if (!isValidBool()) + return failure(); + return iterator_range_impl( + getType(), BoolElementIterator(*this, 0), + BoolElementIterator(*this, getNumElements())); } - /// Return the held element values as a range of APInts. The element type of - /// this attribute must be of integer type. + /// Try to get the held element values as a range of APInts. The element type + /// of this attribute must be of integer type. template using APIntValueTemplateCheckT = typename std::enable_if::value>::type; template > - iterator_range_impl getValues() const { - assert(getElementType().isIntOrIndex() && "expected integral type"); - return {getType(), raw_int_begin(), raw_int_end()}; - } - template > - IntElementIterator value_begin() const { - assert(getElementType().isIntOrIndex() && "expected integral type"); - return raw_int_begin(); - } - template > - IntElementIterator value_end() const { - assert(getElementType().isIntOrIndex() && "expected integral type"); - return raw_int_end(); + FailureOr> tryGetValues() const { + if (!getElementType().isIntOrIndex()) + return failure(); + return iterator_range_impl(getType(), raw_int_begin(), + raw_int_end()); } - /// Return the held element values as a range of complex APInts. The element - /// type of this attribute must be a complex of integer type. + /// Try to get the held element values as a range of complex APInts. The + /// element type of this attribute must be a complex of integer type. template using ComplexAPIntValueTemplateCheckT = typename std::enable_if< std::is_same>::value>::type; template > - iterator_range_impl getValues() const { - return getComplexIntValues(); - } - template > - ComplexIntElementIterator value_begin() const { - return complex_value_begin(); - } - template > - ComplexIntElementIterator value_end() const { - return complex_value_end(); + FailureOr> + tryGetValues() const { + return tryGetComplexIntValues(); } - /// Return the held element values as a range of APFloat. The element type of - /// this attribute must be of float type. + /// Try to get the held element values as a range of APFloat. The element type + /// of this attribute must be of float type. template using APFloatValueTemplateCheckT = typename std::enable_if::value>::type; template > - iterator_range_impl getValues() const { - return getFloatValues(); - } - template > - FloatElementIterator value_begin() const { - return float_value_begin(); - } - template > - FloatElementIterator value_end() const { - return float_value_end(); + FailureOr> tryGetValues() const { + return tryGetFloatValues(); } - /// Return the held element values as a range of complex APFloat. The element - /// type of this attribute must be a complex of float type. + /// Try to get the held element values as a range of complex APFloat. The + /// element type of this attribute must be a complex of float type. template using ComplexAPFloatValueTemplateCheckT = typename std::enable_if< std::is_same>::value>::type; template > - iterator_range_impl getValues() const { - return getComplexFloatValues(); - } - template > - ComplexFloatElementIterator value_begin() const { - return complex_float_value_begin(); - } - template > - ComplexFloatElementIterator value_end() const { - return complex_float_value_end(); + FailureOr> + tryGetValues() const { + return tryGetComplexFloatValues(); } /// Return the raw storage data held by this attribute. Users should generally @@ -687,16 +646,12 @@ IntElementIterator raw_int_end() const { return IntElementIterator(*this, getNumElements()); } - iterator_range_impl getComplexIntValues() const; - ComplexIntElementIterator complex_value_begin() const; - ComplexIntElementIterator complex_value_end() const; - iterator_range_impl getFloatValues() const; - FloatElementIterator float_value_begin() const; - FloatElementIterator float_value_end() const; - iterator_range_impl - getComplexFloatValues() const; - ComplexFloatElementIterator complex_float_value_begin() const; - ComplexFloatElementIterator complex_float_value_end() const; + FailureOr> + tryGetComplexIntValues() const; + FailureOr> + tryGetFloatValues() const; + FailureOr> + tryGetComplexFloatValues() const; /// Overload of the raw 'get' method that asserts that the given type is of /// complex type. This method is used to verify type invariants that the @@ -973,8 +928,8 @@ function_ref mapping) const; /// Iterator access to the float element values. - iterator begin() const { return float_value_begin(); } - iterator end() const { return float_value_end(); } + iterator begin() const { return tryGetFloatValues()->begin(); } + iterator end() const { return tryGetFloatValues()->end(); } /// Method for supporting type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr); @@ -1026,12 +981,15 @@ //===----------------------------------------------------------------------===// template -auto SparseElementsAttr::value_begin() const -> iterator { +auto SparseElementsAttr::try_value_begin_impl(OverloadToken) const + -> FailureOr> { auto zeroValue = getZeroValue(); - auto valueIt = getValues().value_begin(); + auto valueIt = getValues().try_value_begin(); + if (failed(valueIt)) + return failure(); const std::vector flatSparseIndices(getFlattenedSparseIndices()); std::function mapFn = - [flatSparseIndices{flatSparseIndices}, valueIt{std::move(valueIt)}, + [flatSparseIndices{flatSparseIndices}, valueIt{std::move(*valueIt)}, zeroValue{std::move(zeroValue)}](ptrdiff_t index) { // Try to map the current index to one of the sparse indices. for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i) diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -292,7 +292,7 @@ /// ElementsAttr. template auto try_value_begin_impl(OverloadToken) const { - return ::mlir::success(value_begin()); + return try_value_begin(); } /// Convert endianess of input ArrayRef for big-endian(BE) machines. All of @@ -422,7 +422,7 @@ /// ElementsAttr. template auto try_value_begin_impl(OverloadToken) const { - return ::mlir::success(value_begin()); + return try_value_begin(); } protected: @@ -889,23 +889,17 @@ StringRef >; using ElementsAttr::Trait::getValues; - - /// Provide a `try_value_begin_impl` to enable iteration within - /// ElementsAttr. - template - auto try_value_begin_impl(OverloadToken) const { - return ::mlir::success(value_begin()); - } + using ElementsAttr::Trait::value_begin; template using iterator = llvm::mapped_iterator(0, 0))::iterator, std::function>; - /// Return the values of this attribute in the form of the given type 'T'. - /// 'T' may be any of Attribute, APInt, APFloat, c++ integer/float types, - /// etc. - template iterator value_begin() const; + /// Provide a `try_value_begin_impl` to enable iteration within + /// ElementsAttr. + template + FailureOr> try_value_begin_impl(OverloadToken) const; private: /// Get a zero APFloat for the given sparse attribute. diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1160,68 +1160,42 @@ } /// Return if the given complex type has an integer element type. -LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) { +static bool isComplexOfIntType(Type type) { return type.cast().getElementType().isa(); } -auto DenseElementsAttr::getComplexIntValues() const - -> iterator_range_impl { - assert(isComplexOfIntType(getElementType()) && - "expected complex integral type"); - return {getType(), ComplexIntElementIterator(*this, 0), - ComplexIntElementIterator(*this, getNumElements())}; -} -auto DenseElementsAttr::complex_value_begin() const - -> ComplexIntElementIterator { - assert(isComplexOfIntType(getElementType()) && - "expected complex integral type"); - return ComplexIntElementIterator(*this, 0); -} -auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { - assert(isComplexOfIntType(getElementType()) && - "expected complex integral type"); - return ComplexIntElementIterator(*this, getNumElements()); -} - -/// Return the held element values as a range of APFloat. The element type of -/// this attribute must be of float type. -auto DenseElementsAttr::getFloatValues() const - -> iterator_range_impl { - auto elementType = getElementType().cast(); - const auto &elementSemantics = elementType.getFloatSemantics(); - return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()), - FloatElementIterator(elementSemantics, raw_int_end())}; -} -auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { - auto elementType = getElementType().cast(); - return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin()); -} -auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { - auto elementType = getElementType().cast(); - return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end()); -} - -auto DenseElementsAttr::getComplexFloatValues() const - -> iterator_range_impl { - Type eltTy = getElementType().cast().getElementType(); - assert(eltTy.isa() && "expected complex float type"); - const auto &semantics = eltTy.cast().getFloatSemantics(); - return {getType(), - {semantics, {*this, 0}}, - {semantics, {*this, static_cast(getNumElements())}}}; -} -auto DenseElementsAttr::complex_float_value_begin() const - -> ComplexFloatElementIterator { - Type eltTy = getElementType().cast().getElementType(); - assert(eltTy.isa() && "expected complex float type"); - return {eltTy.cast().getFloatSemantics(), {*this, 0}}; -} -auto DenseElementsAttr::complex_float_value_end() const - -> ComplexFloatElementIterator { - Type eltTy = getElementType().cast().getElementType(); - assert(eltTy.isa() && "expected complex float type"); - return {eltTy.cast().getFloatSemantics(), - {*this, static_cast(getNumElements())}}; +auto DenseElementsAttr::tryGetComplexIntValues() const + -> FailureOr> { + if (!isComplexOfIntType(getElementType())) + return failure(); + return iterator_range_impl( + getType(), ComplexIntElementIterator(*this, 0), + ComplexIntElementIterator(*this, getNumElements())); +} + +auto DenseElementsAttr::tryGetFloatValues() const + -> FailureOr> { + auto eltTy = getElementType().dyn_cast(); + if (!eltTy) + return failure(); + const auto &elementSemantics = eltTy.getFloatSemantics(); + return iterator_range_impl( + getType(), FloatElementIterator(elementSemantics, raw_int_begin()), + FloatElementIterator(elementSemantics, raw_int_end())); +} + +auto DenseElementsAttr::tryGetComplexFloatValues() const + -> FailureOr> { + auto complexTy = getElementType().dyn_cast(); + if (!complexTy) + return failure(); + auto eltTy = complexTy.getElementType().dyn_cast(); + if (!eltTy) + return failure(); + const auto &semantics = eltTy.getFloatSemantics(); + return iterator_range_impl( + getType(), {semantics, {*this, 0}}, + {semantics, {*this, static_cast(getNumElements())}}); } /// Return the raw storage data held by this attribute.