diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -165,7 +165,7 @@ // functor recursively walks the dimensions of the constant shape, // generating a store when the recursion hits the base case. SmallVector indices; - auto valueIt = constantValue.getValues().begin(); + auto valueIt = constantValue.value_begin(); std::function storeElements = [&](uint64_t dimension) { // The last dimension is the base case of the recursion, at this point // we store the element at the given index. diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -164,7 +164,7 @@ // functor recursively walks the dimensions of the constant shape, // generating a store when the recursion hits the base case. SmallVector indices; - auto valueIt = constantValue.getValues().begin(); + auto valueIt = constantValue.value_begin(); std::function storeElements = [&](uint64_t dimension) { // The last dimension is the base case of the recursion, at this point // we store the element at the given index. diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -165,7 +165,7 @@ // functor recursively walks the dimensions of the constant shape, // generating a store when the recursion hits the base case. SmallVector indices; - auto valueIt = constantValue.getValues().begin(); + auto valueIt = constantValue.value_begin(); std::function storeElements = [&](uint64_t dimension) { // The last dimension is the base case of the recursion, at this point // we store the element at the given index. diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -58,8 +58,8 @@ auto lhs = operands[0].cast(); auto rhs = operands[1].cast(); - auto lhsIt = lhs.getValues().begin(); - auto rhsIt = rhs.getValues().begin(); + auto lhsIt = lhs.value_begin(); + auto rhsIt = rhs.value_begin(); SmallVector elementResults; elementResults.reserve(lhs.getNumElements()); for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -51,6 +51,9 @@ /// with static shape. ShapedType getType() const; + /// Return the element type of this ElementsAttr. + Type getElementType() const; + /// Return the value at the given index. The index is expected to refer to a /// valid element. Attribute getValue(ArrayRef index) const; @@ -65,8 +68,9 @@ /// Return the elements of this attribute as a value of type 'T'. Note: /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support /// iteration. - template - iterator_range getValues() const; + template iterator_range getValues() const; + template iterator value_begin() const; + template iterator value_end() const; /// Return if the given 'index' refers to a valid element in this attribute. bool isValidIndex(ArrayRef index) const; @@ -417,7 +421,7 @@ T>::type getSplatValue() const { assert(isSplat() && "expected the attribute to be a splat"); - return *getValues().begin(); + return *value_begin(); } /// Return the splat value for derived attribute element types. template @@ -436,15 +440,21 @@ template T getValue(ArrayRef index) const { // Skip to the element corresponding to the flattened index. - return *std::next(getValues().begin(), getFlattenedIndex(index)); + return getFlatValue(getFlattenedIndex(index)); + } + /// Return the value at the given flattened index. + template T getFlatValue(uint64_t index) const { + return *std::next(value_begin(), index); } /// Return the held element values as a range of integer or floating-point /// values. - template ::value && - std::numeric_limits::is_integer) || - is_valid_cpp_fp_type::value>::type> + template + using IntFloatValueTemplateCheckT = + typename std::enable_if<(!std::is_same::value && + std::numeric_limits::is_integer) || + is_valid_cpp_fp_type::value>::type; + template > llvm::iterator_range> getValues() const { assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); @@ -453,13 +463,27 @@ return {ElementIterator(rawData, splat, 0), ElementIterator(rawData, splat, getNumElements())}; } + template > + ElementIterator value_begin() const { + assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, + std::numeric_limits::is_signed)); + return ElementIterator(getRawData().data(), isSplat(), 0); + } + template > + ElementIterator value_end() const { + assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, + std::numeric_limits::is_signed)); + return ElementIterator(getRawData().data(), isSplat(), getNumElements()); + } /// Return the held element values as a range of std::complex. + template + using ComplexValueTemplateCheckT = + typename std::enable_if::value && + (std::numeric_limits::is_integer || + is_valid_cpp_fp_type::value)>::type; template ::value && - (std::numeric_limits::is_integer || - is_valid_cpp_fp_type::value)>::type> + typename = ComplexValueTemplateCheckT> llvm::iterator_range> getValues() const { assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); @@ -468,10 +492,26 @@ return {ElementIterator(rawData, splat, 0), ElementIterator(rawData, splat, getNumElements())}; } + template > + ElementIterator value_begin() const { + assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, + std::numeric_limits::is_signed)); + return ElementIterator(getRawData().data(), isSplat(), 0); + } + template > + ElementIterator value_end() const { + assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, + std::numeric_limits::is_signed)); + return ElementIterator(getRawData().data(), isSplat(), getNumElements()); + } /// Return the held element values as a range of StringRef. - template ::value>::type> + template + using StringRefValueTemplateCheckT = + typename std::enable_if::value>::type; + template > llvm::iterator_range> getValues() const { auto stringRefs = getRawStringData(); const char *ptr = reinterpret_cast(stringRefs.data()); @@ -479,80 +519,156 @@ return {ElementIterator(ptr, splat, 0), ElementIterator(ptr, splat, getNumElements())}; } + template > + ElementIterator value_begin() const { + const char *ptr = reinterpret_cast(getRawStringData().data()); + return ElementIterator(ptr, isSplat(), 0); + } + template > + ElementIterator value_end() const { + const char *ptr = reinterpret_cast(getRawStringData().data()); + return ElementIterator(ptr, isSplat(), getNumElements()); + } /// Return the held element values as a range of Attributes. - llvm::iterator_range getAttributeValues() const; - template ::value>::type> + template + using AttributeValueTemplateCheckT = + typename std::enable_if::value>::type; + template > llvm::iterator_range getValues() const { - return getAttributeValues(); + return {value_begin(), value_end()}; + } + template > + AttributeElementIterator value_begin() const { + return AttributeElementIterator(*this, 0); + } + template > + AttributeElementIterator value_end() const { + return AttributeElementIterator(*this, getNumElements()); } - AttributeElementIterator attr_value_begin() const; - AttributeElementIterator attr_value_end() const; /// Return the held element values a range of T, where T is a derived /// attribute type. template + using DerivedAttrValueTemplateCheckT = + typename std::enable_if::value && + !std::is_same::value>::type; + template using DerivedAttributeElementIterator = llvm::mapped_iterator; - template ::value && - !std::is_same::value>::type> + template > llvm::iterator_range> getValues() const { auto castFn = [](Attribute attr) { return attr.template cast(); }; - return llvm::map_range(getAttributeValues(), + return llvm::map_range(getValues(), static_cast(castFn)); } + template > + DerivedAttributeElementIterator value_begin() const { + auto castFn = [](Attribute attr) { return attr.template cast(); }; + return {value_begin(), static_cast(castFn)}; + } + template > + DerivedAttributeElementIterator value_end() const { + auto castFn = [](Attribute attr) { return attr.template cast(); }; + return {value_end(), static_cast(castFn)}; + } /// Return the held element values as a range of bool. The element type of /// this attribute must be of integer type of bitwidth 1. - llvm::iterator_range getBoolValues() const; - template ::value>::type> + template + using BoolValueTemplateCheckT = + typename std::enable_if::value>::type; + template > llvm::iterator_range getValues() const { - return getBoolValues(); + assert(isValidBool() && "bool is not the value of this elements attribute"); + return {BoolElementIterator(*this, 0), + BoolElementIterator(*this, getNumElements())}; + } + template > + BoolElementIterator value_begin() const { + assert(isValidBool() && "bool is not the value of this elements attribute"); + return BoolElementIterator(*this, 0); + } + template > + BoolElementIterator value_end() const { + assert(isValidBool() && "bool is not the value of this elements attribute"); + return BoolElementIterator(*this, getNumElements()); } /// Return the held element values as a range of APInts. The element type of /// this attribute must be of integer type. - llvm::iterator_range getIntValues() const; - template ::value>::type> + template + using APIntValueTemplateCheckT = + typename std::enable_if::value>::type; + template > llvm::iterator_range getValues() const { - return getIntValues(); + assert(getElementType().isIntOrIndex() && "expected integral type"); + return {raw_int_begin(), raw_int_end()}; + } + template > + IntElementIterator value_begin() const { + assert(getElementType().isIntOrIndex() && "expected integral type"); + return raw_int_begin(); + } + template > + IntElementIterator value_end() const { + assert(getElementType().isIntOrIndex() && "expected integral type"); + return raw_int_end(); } - IntElementIterator int_value_begin() const; - IntElementIterator int_value_end() const; /// Return the held element values as a range of complex APInts. The element /// type of this attribute must be a complex of integer type. - llvm::iterator_range getComplexIntValues() const; - template >::value>::type> + template + using ComplexAPIntValueTemplateCheckT = typename std::enable_if< + std::is_same>::value>::type; + template > llvm::iterator_range getValues() const { return getComplexIntValues(); } + template > + ComplexIntElementIterator value_begin() const { + return complex_value_begin(); + } + template > + ComplexIntElementIterator value_end() const { + return complex_value_end(); + } /// Return the held element values as a range of APFloat. The element type of /// this attribute must be of float type. - llvm::iterator_range getFloatValues() const; - template ::value>::type> + template + using APFloatValueTemplateCheckT = + typename std::enable_if::value>::type; + template > llvm::iterator_range getValues() const { return getFloatValues(); } - FloatElementIterator float_value_begin() const; - FloatElementIterator float_value_end() const; + template > + FloatElementIterator value_begin() const { + return float_value_begin(); + } + template > + FloatElementIterator value_end() const { + return float_value_end(); + } /// Return the held element values as a range of complex APFloat. The element /// type of this attribute must be a complex of float type. - llvm::iterator_range - getComplexFloatValues() const; - template >::value>::type> + template + using ComplexAPFloatValueTemplateCheckT = typename std::enable_if< + std::is_same>::value>::type; + template > llvm::iterator_range getValues() const { return getComplexFloatValues(); } + template > + ComplexFloatElementIterator value_begin() const { + return complex_float_value_begin(); + } + template > + ComplexFloatElementIterator value_end() const { + return complex_float_value_end(); + } /// Return the raw storage data held by this attribute. Users should generally /// not use this directly, as the internal storage format is not always in the @@ -590,13 +706,25 @@ function_ref mapping) const; protected: - /// Get iterators to the raw APInt values for each element in this attribute. + /// Iterators to various elements that require out-of-line definition. These + /// are hidden from the user to encourage consistent use of the + /// getValues/value_begin/value_end API. IntElementIterator raw_int_begin() const { return IntElementIterator(*this, 0); } IntElementIterator raw_int_end() const { return IntElementIterator(*this, getNumElements()); } + llvm::iterator_range getComplexIntValues() const; + ComplexIntElementIterator complex_value_begin() const; + ComplexIntElementIterator complex_value_end() const; + llvm::iterator_range getFloatValues() const; + FloatElementIterator float_value_begin() const; + FloatElementIterator float_value_end() const; + llvm::iterator_range + getComplexFloatValues() const; + ComplexFloatElementIterator complex_float_value_begin() const; + ComplexFloatElementIterator complex_float_value_end() const; /// Overload of the raw 'get' method that asserts that the given type is of /// complex type. This method is used to verify type invariants that the @@ -616,11 +744,8 @@ /// Check the information for a C++ data type, check if this type is valid for /// the current attribute. This method is used to verify specific type /// invariants that the templatized 'getValues' method cannot. + bool isValidBool() const { return getElementType().isInteger(1); } bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const; - - /// Check the information for a C++ data type, check if this type is valid for - /// the current attribute. This method is used to verify specific type - /// invariants that the templatized 'getValues' method cannot. bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const; }; @@ -806,7 +931,7 @@ auto SparseElementsAttr::getValues() const -> llvm::iterator_range> { auto zeroValue = getZeroValue(); - auto valueIt = getValues().getValues().begin(); + auto valueIt = getValues().value_begin(); const std::vector flatSparseIndices(getFlattenedSparseIndices()); std::function mapFn = [flatSparseIndices{std::move(flatSparseIndices)}, @@ -821,6 +946,14 @@ }; return llvm::map_range(llvm::seq(0, getNumElements()), mapFn); } +template +auto SparseElementsAttr::value_begin() const -> iterator { + return getValues().begin(); +} +template +auto SparseElementsAttr::value_end() const -> iterator { + return getValues().end(); +} namespace detail { /// This class represents a general iterator over the values of an ElementsAttr. @@ -833,8 +966,7 @@ // NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype' // inside of a conversion operator. using DenseIteratorT = typename std::enable_if< - true, - decltype(std::declval().getValues().begin())>::type; + true, decltype(std::declval().value_begin())>::type; using SparseIteratorT = SparseElementsAttr::iterator; /// A union containing the specific iterators for each derived attribute kind. @@ -960,6 +1092,21 @@ llvm_unreachable("unexpected attribute kind"); } +template auto ElementsAttr::value_begin() const -> iterator { + if (DenseElementsAttr denseAttr = dyn_cast()) + return iterator(*this, denseAttr.value_begin()); + if (SparseElementsAttr sparseAttr = dyn_cast()) + return iterator(*this, sparseAttr.value_begin()); + llvm_unreachable("unexpected attribute kind"); +} +template auto ElementsAttr::value_end() const -> iterator { + if (DenseElementsAttr denseAttr = dyn_cast()) + return iterator(*this, denseAttr.value_end()); + if (SparseElementsAttr sparseAttr = dyn_cast()) + return iterator(*this, sparseAttr.value_end()); + llvm_unreachable("unexpected attribute kind"); +} + } // end namespace mlir. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -721,6 +721,8 @@ /// 'T' may be any of Attribute, APInt, APFloat, c++ integer/float types, /// etc. template llvm::iterator_range> getValues() const; + template iterator value_begin() const; + template iterator value_end() const; /// Return the value of the element at the given index. The 'index' is /// expected to refer to a valid element. diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -505,48 +505,36 @@ // Indexed accessors. bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { - return *( - unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { - return *( - unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( - *(unwrap(attr).cast().getValues().begin() + - pos)); + unwrap(attr).cast().getFlatValue(pos)); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -127,7 +127,7 @@ if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0)) return failure(); - if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1), + if (llvm::any_of(llvm::drop_begin(localSize.getValues(), 1), [](const APInt &size) { return !size.isOneValue(); })) return failure(); diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -558,9 +558,9 @@ if (srcElemType != dstElemType) { SmallVector elements; if (srcElemType.isa()) { - for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) { - FloatAttr dstAttr = convertFloatAttr( - srcAttr.cast(), dstElemType.cast(), rewriter); + for (FloatAttr srcAttr : dstElementsAttr.getValues()) { + FloatAttr dstAttr = + convertFloatAttr(srcAttr, dstElemType.cast(), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -568,10 +568,9 @@ } else if (srcElemType.isInteger(1)) { return failure(); } else { - for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) { - IntegerAttr dstAttr = - convertIntegerAttr(srcAttr.cast(), - dstElemType.cast(), rewriter); + for (IntegerAttr srcAttr : dstElementsAttr.getValues()) { + IntegerAttr dstAttr = convertIntegerAttr( + srcAttr, dstElemType.cast(), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1610,7 +1610,7 @@ SmallVector inputExprs; inputExprs.resize(resultTy.getRank()); - for (auto permutation : llvm::enumerate(perms.getIntValues())) { + for (auto permutation : llvm::enumerate(perms.getValues())) { inputExprs[permutation.value().getZExtValue()] = rewriter.getAffineDimExpr(permutation.index()); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -337,11 +337,12 @@ auto attrName = OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(); auto sizeAttr = op->template getAttrOfType(attrName); + + // Async dependencies is the only variadic operand. if (!sizeAttr) - return; // Async dependencies is the only variadic operand. - SmallVector sizes; - for (auto size : sizeAttr.getIntValues()) - sizes.push_back(size.getSExtValue()); + return; + + SmallVector sizes(sizeAttr.getValues()); ++sizes.front(); op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes)); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1825,8 +1825,9 @@ // and hence was replaced. if (complexElementType.isa()) { bool isSigned = !complexElementType.isUnsignedInteger(); + auto valueIt = attr.value_begin>(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { - auto complexValue = *(attr.getComplexIntValues().begin() + index); + auto complexValue = *(valueIt + index); os << "("; printDenseIntElement(complexValue.real(), os, isSigned); os << ","; @@ -1834,8 +1835,9 @@ os << ")"; }); } else { + auto valueIt = attr.value_begin>(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { - auto complexValue = *(attr.getComplexFloatValues().begin() + index); + auto complexValue = *(valueIt + index); os << "("; printFloatValue(complexValue.real(), os); os << ","; @@ -1845,15 +1847,15 @@ } } else if (elementType.isIntOrIndex()) { bool isSigned = !elementType.isUnsignedInteger(); - auto intValues = attr.getIntValues(); + auto valueIt = attr.value_begin(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { - printDenseIntElement(*(intValues.begin() + index), os, isSigned); + printDenseIntElement(*(valueIt + index), os, isSigned); }); } else { assert(elementType.isa() && "unexpected element type"); - auto floatValues = attr.getFloatValues(); + auto valueIt = attr.value_begin(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { - printFloatValue(*(floatValues.begin() + index), os); + printFloatValue(*(valueIt + index), os); }); } } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -390,6 +390,8 @@ return Attribute::getType().cast(); } +Type ElementsAttr::getElementType() const { return getType().getElementType(); } + /// Returns the number of elements held by this attribute. int64_t ElementsAttr::getNumElements() const { return getType().getNumElements(); @@ -635,7 +637,7 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { auto owner = getFromOpaquePointer(base).cast(); - Type eltTy = owner.getType().getElementType(); + Type eltTy = owner.getElementType(); if (auto intEltTy = eltTy.dyn_cast()) return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); if (eltTy.isa()) @@ -690,7 +692,7 @@ DenseElementsAttr attr, size_t dataIndex) : DenseElementIndexedIteratorImpl( attr.getRawData().data(), attr.isSplat(), dataIndex), - bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} + bitWidth(getDenseElementBitWidth(attr.getElementType())) {} APInt DenseElementsAttr::IntElementIterator::operator*() const { return readBits(getData(), @@ -707,7 +709,7 @@ std::complex, std::complex, std::complex>( attr.getRawData().data(), attr.isSplat(), dataIndex) { - auto complexType = attr.getType().getElementType().cast(); + auto complexType = attr.getElementType().cast(); bitWidth = getDenseElementBitWidth(complexType.getElementType()); } @@ -930,21 +932,15 @@ isInt, isSigned); } -/// A method used to verify specific type invariants that the templatized 'get' -/// method cannot. bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const { - return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, - isSigned); + return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); } - -/// Check the information for a C++ data type, check if this type is valid for -/// the current attribute. bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const { return ::isValidIntOrFloat( - getType().getElementType().cast().getElementType(), - dataEltSize / 2, isInt, isSigned); + getElementType().cast().getElementType(), dataEltSize / 2, + isInt, isSigned); } /// Returns true if this attribute corresponds to a splat, i.e. if all element @@ -953,76 +949,69 @@ return static_cast(impl)->isSplat; } -/// Return the held element values as a range of Attributes. -auto DenseElementsAttr::getAttributeValues() const - -> llvm::iterator_range { - return {attr_value_begin(), attr_value_end()}; -} -auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { - return AttributeElementIterator(*this, 0); -} -auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { - return AttributeElementIterator(*this, getNumElements()); +/// Return if the given complex type has an integer element type. +static bool isComplexOfIntType(Type type) { + return type.cast().getElementType().isa(); } -/// Return the held element values as a range of bool. The element type of -/// this attribute must be of integer type of bitwidth 1. -auto DenseElementsAttr::getBoolValues() const - -> llvm::iterator_range { - auto eltType = getType().getElementType().dyn_cast(); - assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); - (void)eltType; - return {BoolElementIterator(*this, 0), - BoolElementIterator(*this, getNumElements())}; -} - -/// Return the held element values as a range of APInts. The element type of -/// this attribute must be of integer type. -auto DenseElementsAttr::getIntValues() const - -> llvm::iterator_range { - assert(getType().getElementType().isIntOrIndex() && "expected integral type"); - return {raw_int_begin(), raw_int_end()}; -} -auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { - assert(getType().getElementType().isIntOrIndex() && "expected integral type"); - return raw_int_begin(); -} -auto DenseElementsAttr::int_value_end() const -> IntElementIterator { - assert(getType().getElementType().isIntOrIndex() && "expected integral type"); - return raw_int_end(); -} auto DenseElementsAttr::getComplexIntValues() const -> llvm::iterator_range { - Type eltTy = getType().getElementType().cast().getElementType(); - (void)eltTy; - assert(eltTy.isa() && "expected complex integral type"); + assert(isComplexOfIntType(getElementType()) && + "expected complex integral type"); return {ComplexIntElementIterator(*this, 0), ComplexIntElementIterator(*this, getNumElements())}; } +auto DenseElementsAttr::complex_value_begin() const + -> ComplexIntElementIterator { + assert(isComplexOfIntType(getElementType()) && + "expected complex integral type"); + return ComplexIntElementIterator(*this, 0); +} +auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { + assert(isComplexOfIntType(getElementType()) && + "expected complex integral type"); + return ComplexIntElementIterator(*this, getNumElements()); +} /// Return the held element values as a range of APFloat. The element type of /// this attribute must be of float type. auto DenseElementsAttr::getFloatValues() const -> llvm::iterator_range { - auto elementType = getType().getElementType().cast(); + auto elementType = getElementType().cast(); const auto &elementSemantics = elementType.getFloatSemantics(); return {FloatElementIterator(elementSemantics, raw_int_begin()), FloatElementIterator(elementSemantics, raw_int_end())}; } auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { - return getFloatValues().begin(); + auto elementType = getElementType().cast(); + return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin()); } auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { - return getFloatValues().end(); + auto elementType = getElementType().cast(); + return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end()); } + auto DenseElementsAttr::getComplexFloatValues() const -> llvm::iterator_range { - Type eltTy = getType().getElementType().cast().getElementType(); + Type eltTy = getElementType().cast().getElementType(); assert(eltTy.isa() && "expected complex float type"); const auto &semantics = eltTy.cast().getFloatSemantics(); return {{semantics, {*this, 0}}, {semantics, {*this, static_cast(getNumElements())}}}; } +auto DenseElementsAttr::complex_float_value_begin() const + -> ComplexFloatElementIterator { + Type eltTy = getElementType().cast().getElementType(); + assert(eltTy.isa() && "expected complex float type"); + return {eltTy.cast().getFloatSemantics(), {*this, 0}}; +} +auto DenseElementsAttr::complex_float_value_end() const + -> ComplexFloatElementIterator { + Type eltTy = getElementType().cast().getElementType(); + assert(eltTy.isa() && "expected complex float type"); + return {eltTy.cast().getFloatSemantics(), + {*this, static_cast(getNumElements())}}; +} /// Return the raw storage data held by this attribute. ArrayRef DenseElementsAttr::getRawData() const { @@ -1374,19 +1363,19 @@ /// Get a zero APFloat for the given sparse attribute. APFloat SparseElementsAttr::getZeroAPFloat() const { - auto eltType = getType().getElementType().cast(); + auto eltType = getElementType().cast(); return APFloat(eltType.getFloatSemantics()); } /// Get a zero APInt for the given sparse attribute. APInt SparseElementsAttr::getZeroAPInt() const { - auto eltType = getType().getElementType().cast(); + auto eltType = getElementType().cast(); return APInt::getZero(eltType.getWidth()); } /// Get a zero attribute for the given attribute type. Attribute SparseElementsAttr::getZeroAttr() const { - auto eltType = getType().getElementType(); + auto eltType = getElementType(); // Handle floating point elements. if (eltType.isa()) diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1024,7 +1024,7 @@ return op->emitOpError("requires 1D i32 elements attribute '") << attrName << "'"; - if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) { + if (llvm::any_of(sizeAttr.getValues(), [](const APInt &element) { return !element.isNonNegative(); })) return op->emitOpError("'") diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -51,7 +51,7 @@ auto dattr = attr.cast(); res.clear(); res.reserve(dattr.size()); - for (auto it : dattr.getIntValues()) + for (auto it : dattr.getValues()) res.push_back(it.getSExtValue()); } else { auto vals = val.get()->getDims(); @@ -71,7 +71,7 @@ return t.cast().getDimSize(index); if (auto attr = val.dyn_cast()) return attr.cast() - .getValue({static_cast(index)}) + .getFlatValue(index) .getSExtValue(); auto *stc = val.get(); return stc->getDims()[index]; @@ -94,7 +94,7 @@ return t.cast().hasStaticShape(); if (auto attr = val.dyn_cast()) { auto dattr = attr.cast(); - for (auto index : dattr.getIntValues()) + for (auto index : dattr.getValues()) if (ShapedType::isDynamic(index.getSExtValue())) return false; return true; @@ -115,7 +115,7 @@ if (auto attr = val.dyn_cast()) { auto dattr = attr.cast(); int64_t num = 1; - for (auto index : dattr.getIntValues()) { + for (auto index : dattr.getValues()) { num *= index.getZExtValue(); assert(num >= 0 && "integer overflow in element count computation"); } diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -294,7 +294,8 @@ if (!nested) return nullptr; - values.append(nested.attr_value_begin(), nested.attr_value_end()); + values.append(nested.value_begin(), + nested.value_end()); } return DenseElementsAttr::get(outerType, values); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -83,12 +83,14 @@ auto sizeAttr = (*this)->getAttr({0}).cast<::mlir::DenseIntElementsAttr>(); )"; const char *attrSizedSegmentValueRangeCalcCode = R"( - auto sizeAttrValues = sizeAttr.getValues(); + const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin(); + if (sizeAttr.isSplat()) + return {*sizeAttrValueIt * index, *sizeAttrValueIt}; + unsigned start = 0; for (unsigned i = 0; i < index; ++i) - start += *(sizeAttrValues.begin() + i); - unsigned size = *(sizeAttrValues.begin() + index); - return {start, size}; + start += sizeAttrValueIt[i]; + return {start, sizeAttrValueIt[index]}; )"; // The logic to calculate the actual value range for a declared operand // of an op with variadic of variadic operands within the OpAdaptor. diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp --- a/mlir/unittests/TableGen/StructsGenTest.cpp +++ b/mlir/unittests/TableGen/StructsGenTest.cpp @@ -158,7 +158,7 @@ auto denseAttr = returnedAttr.dyn_cast(); ASSERT_TRUE(denseAttr); - for (const auto &valIndexIt : llvm::enumerate(denseAttr.getIntValues())) { + for (const auto &valIndexIt : llvm::enumerate(denseAttr.getValues())) { EXPECT_EQ(valIndexIt.value(), valIndexIt.index() + 1); } }