diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h @@ -227,6 +227,33 @@ ElementsAttrIndexer indexer; ptrdiff_t index; }; + +/// This class provides iterator utilities for an ElementsAttr range. +template +class ElementsAttrRange : public llvm::iterator_range { +public: + using reference = typename IteratorT::reference; + + ElementsAttrRange(Type shapeType, + const llvm::iterator_range &range) + : llvm::iterator_range(range), shapeType(shapeType) {} + ElementsAttrRange(Type shapeType, IteratorT beginIt, IteratorT endIt) + : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {} + + /// Return the value at the given index. + reference operator[](ArrayRef index) const; + reference operator[](uint64_t index) const { + return *std::next(this->begin(), index); + } + + /// Return the size of this range. + size_t size() const { return llvm::size(*this); } + +private: + /// The shaped type of the parent ElementsAttr. + Type shapeType; +}; + } // namespace detail //===----------------------------------------------------------------------===// @@ -256,6 +283,16 @@ //===----------------------------------------------------------------------===// namespace mlir { +namespace detail { +/// Return the value at the given index. +template +auto ElementsAttrRange::operator[](ArrayRef index) const + -> reference { + // Skip to the element corresponding to the flattened index. + return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)]; +} +} // namespace detail + /// Return the elements of this attribute as a value of type 'T'. template auto ElementsAttr::value_begin() const -> DefaultValueCheckT> { diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -158,27 +158,6 @@ ]; string ElementsAttrInterfaceAccessors = [{ - /// Return the attribute value at the given index. The index is expected to - /// refer to a valid element. - Attribute getValue(ArrayRef index) const { - return getValue(index); - } - - /// Return the value of type 'T' at the given index, where 'T' corresponds - /// to an Attribute type. - template - std::enable_if_t::value && - std::is_base_of::value> - getValue(ArrayRef index) const { - return getValue(index).template dyn_cast_or_null(); - } - - /// Return the value of type 'T' at the given index. - template - T getValue(ArrayRef index) const { - return getFlatValue(getFlattenedIndex(index)); - } - /// Return the number of elements held by this attribute. int64_t size() const { return getNumElements(); } @@ -294,17 +273,13 @@ auto beginIt = $_attr.template value_begin(); return llvm::make_range(beginIt, std::next(beginIt, size())); } - /// Return the value at the given flattened index. - template T getFlatValue(uint64_t index) const { - return *std::next($_attr.template value_begin(), index); - } }] # ElementsAttrInterfaceAccessors; let extraClassDeclaration = [{ template using iterator = detail::ElementsAttrIterator; template - using iterator_range = llvm::iterator_range>; + using iterator_range = detail::ElementsAttrRange>; //===------------------------------------------------------------------===// // Accessors @@ -329,8 +304,12 @@ uint64_t getFlattenedIndex(ArrayRef index) const { return getFlattenedIndex(*this, index); } - static uint64_t getFlattenedIndex(Attribute elementsAttr, + static uint64_t getFlattenedIndex(Type type, ArrayRef index); + static uint64_t getFlattenedIndex(Attribute elementsAttr, + ArrayRef index) { + return getFlattenedIndex(elementsAttr.getType(), index); + } /// Returns the number of elements held by this attribute. int64_t getNumElements() const { return getNumElements(*this); } @@ -350,13 +329,6 @@ !std::is_base_of::value, ResultT>; - /// Return the element of this attribute at the given index as a value of - /// type 'T'. - template - T getFlatValue(uint64_t index) const { - return *std::next(value_begin(), index); - } - /// Return the splat value for this attribute. This asserts that the /// attribute corresponds to a splat. template @@ -368,7 +340,7 @@ /// Return the elements of this attribute as a value of type 'T'. template DefaultValueCheckT> getValues() const { - return iterator_range(value_begin(), value_end()); + return {Attribute::getType(), value_begin(), value_end()}; } template DefaultValueCheckT> value_begin() const; @@ -384,12 +356,12 @@ llvm::mapped_iterator, T (*)(Attribute)>; template using DerivedAttrValueIteratorRange = - llvm::iterator_range>; + detail::ElementsAttrRange>; template > DerivedAttrValueIteratorRange getValues() const { auto castFn = [](Attribute attr) { return attr.template cast(); }; - return llvm::map_range(getValues(), - static_cast(castFn)); + return {Attribute::getType(), llvm::map_range(getValues(), + static_cast(castFn))}; } template > DerivedAttrValueIterator value_begin() const { @@ -407,8 +379,10 @@ /// return the iterable range. Otherwise, return llvm::None. template DefaultValueCheckT>> tryGetValues() const { - if (Optional> beginIt = try_value_begin()) - return iterator_range(*beginIt, value_end()); + if (Optional> beginIt = try_value_begin()) { + return iterator_range(Attribute::getType(), *beginIt, + value_end()); + } return llvm::None; } template @@ -418,10 +392,15 @@ /// return the iterable range. Otherwise, return llvm::None. template > Optional> tryGetValues() const { + auto values = tryGetValues(); + if (!values) + return llvm::None; + auto castFn = [](Attribute attr) { return attr.template cast(); }; - if (auto values = tryGetValues()) - return llvm::map_range(*values, static_cast(castFn)); - return llvm::None; + return DerivedAttrValueIteratorRange( + Attribute::getType(), + llvm::map_range(*values, static_cast(castFn)) + ); } template > Optional> try_value_begin() const { diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -219,6 +219,18 @@ // Iterators //===--------------------------------------------------------------------===// + /// The iterator range over the given iterator type T. + template + using iterator_range_impl = detail::ElementsAttrRange; + + /// The iterator for the given element type T. + template + using iterator = decltype(std::declval().template value_begin()); + /// The iterator range over the given element T. + template + using iterator_range = + decltype(std::declval().template getValues()); + /// A utility iterator that allows walking over the internal Attribute values /// of a DenseElementsAttr. class AttributeElementIterator @@ -343,7 +355,6 @@ /// Return the splat value for this attribute. This asserts that the attribute /// corresponds to a splat. - Attribute getSplatValue() const { return getSplatValue(); } template typename std::enable_if::value || std::is_same::value, @@ -358,22 +369,7 @@ !std::is_same::value, T>::type getSplatValue() const { - return getSplatValue().template cast(); - } - - /// Return the value at the given index. The 'index' is expected to refer to a - /// valid element. - Attribute getValue(ArrayRef index) const { - return getValue(index); - } - template - T getValue(ArrayRef index) const { - // Skip to the element corresponding to the flattened index. - return getFlatValue(ElementsAttr::getFlattenedIndex(*this, index)); - } - /// Return the value at the given flattened index. - template T getFlatValue(uint64_t index) const { - return *std::next(value_begin(), index); + return getSplatValue().template cast(); } /// Return the held element values as a range of integer or floating-point @@ -384,12 +380,12 @@ std::numeric_limits::is_integer) || is_valid_cpp_fp_type::value>::type; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); const char *rawData = getRawData().data(); bool splat = isSplat(); - return {ElementIterator(rawData, splat, 0), + return {Attribute::getType(), ElementIterator(rawData, splat, 0), ElementIterator(rawData, splat, getNumElements())}; } template > @@ -413,12 +409,12 @@ is_valid_cpp_fp_type::value)>::type; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); const char *rawData = getRawData().data(); bool splat = isSplat(); - return {ElementIterator(rawData, splat, 0), + return {Attribute::getType(), ElementIterator(rawData, splat, 0), ElementIterator(rawData, splat, getNumElements())}; } template ::value>::type; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { auto stringRefs = getRawStringData(); const char *ptr = reinterpret_cast(stringRefs.data()); bool splat = isSplat(); - return {ElementIterator(ptr, splat, 0), + return {Attribute::getType(), ElementIterator(ptr, splat, 0), ElementIterator(ptr, splat, getNumElements())}; } template > @@ -464,8 +460,9 @@ using AttributeValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { - return {value_begin(), value_end()}; + iterator_range_impl getValues() const { + return {Attribute::getType(), value_begin(), + value_end()}; } template > AttributeElementIterator value_begin() const { @@ -486,10 +483,11 @@ using DerivedAttributeElementIterator = llvm::mapped_iterator; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { auto castFn = [](Attribute attr) { return attr.template cast(); }; - return llvm::map_range(getValues(), - static_cast(castFn)); + return {Attribute::getType(), + llvm::map_range(getValues(), + static_cast(castFn))}; } template > DerivedAttributeElementIterator value_begin() const { @@ -508,9 +506,9 @@ using BoolValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { assert(isValidBool() && "bool is not the value of this elements attribute"); - return {BoolElementIterator(*this, 0), + return {Attribute::getType(), BoolElementIterator(*this, 0), BoolElementIterator(*this, getNumElements())}; } template > @@ -530,9 +528,9 @@ using APIntValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { assert(getElementType().isIntOrIndex() && "expected integral type"); - return {raw_int_begin(), raw_int_end()}; + return {Attribute::getType(), raw_int_begin(), raw_int_end()}; } template > IntElementIterator value_begin() const { @@ -551,7 +549,7 @@ using ComplexAPIntValueTemplateCheckT = typename std::enable_if< std::is_same>::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { return getComplexIntValues(); } template > @@ -569,7 +567,7 @@ using APFloatValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { return getFloatValues(); } template > @@ -587,7 +585,7 @@ using ComplexAPFloatValueTemplateCheckT = typename std::enable_if< std::is_same>::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { return getComplexFloatValues(); } template > @@ -660,13 +658,13 @@ IntElementIterator raw_int_end() const { return IntElementIterator(*this, getNumElements()); } - llvm::iterator_range getComplexIntValues() const; + iterator_range_impl getComplexIntValues() const; ComplexIntElementIterator complex_value_begin() const; ComplexIntElementIterator complex_value_end() const; - llvm::iterator_range getFloatValues() const; + iterator_range_impl getFloatValues() const; FloatElementIterator float_value_begin() const; FloatElementIterator float_value_end() const; - llvm::iterator_range + iterator_range_impl getComplexFloatValues() const; ComplexFloatElementIterator complex_float_value_begin() const; ComplexFloatElementIterator complex_float_value_end() const; @@ -872,8 +870,7 @@ //===----------------------------------------------------------------------===// template -auto SparseElementsAttr::getValues() const - -> llvm::iterator_range> { +auto SparseElementsAttr::value_begin() const -> iterator { auto zeroValue = getZeroValue(); auto valueIt = getValues().value_begin(); const std::vector flatSparseIndices(getFlattenedSparseIndices()); @@ -888,15 +885,7 @@ // Otherwise, return the zero value. return zeroValue; }; - return llvm::map_range(llvm::seq(0, getNumElements()), mapFn); -} -template -auto SparseElementsAttr::value_begin() const -> iterator { - return getValues().begin(); -} -template -auto SparseElementsAttr::value_end() const -> iterator { - return getValues().end(); + return iterator(llvm::seq(0, getNumElements()).begin(), mapFn); } } // end namespace mlir. diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -173,9 +173,7 @@ "ArrayRef":$rawData); let extraClassDeclaration = [{ using DenseElementsAttr::empty; - using DenseElementsAttr::getFlatValue; using DenseElementsAttr::getNumElements; - using DenseElementsAttr::getValue; using DenseElementsAttr::getValues; using DenseElementsAttr::isSplat; using DenseElementsAttr::size; @@ -312,9 +310,7 @@ ]; let extraClassDeclaration = [{ using DenseElementsAttr::empty; - using DenseElementsAttr::getFlatValue; using DenseElementsAttr::getNumElements; - using DenseElementsAttr::getValue; using DenseElementsAttr::getValues; using DenseElementsAttr::isSplat; using DenseElementsAttr::size; @@ -706,10 +702,6 @@ let extraClassDeclaration = [{ using ValueType = StringRef; - /// Return the value at the given index. The 'index' is expected to refer to - /// a valid element. - Attribute getValue(ArrayRef index) const; - /// Decodes the attribute value using dialect-specific decoding hook. /// Returns false if decoding is successful. If not, returns true and leaves /// 'result' argument unspecified. @@ -811,13 +803,7 @@ /// Return the values of this attribute in the form of the given type 'T'. /// 'T' may be any of Attribute, APInt, APFloat, c++ integer/float types, /// etc. - template llvm::iterator_range> getValues() const; template iterator value_begin() const; - template iterator value_end() const; - - /// Return the value of the element at the given index. The 'index' is - /// expected to refer to a valid element. - Attribute getValue(ArrayRef index) const; private: /// Get a zero APFloat for the given sparse attribute. diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -110,7 +110,7 @@ if (type.isa()) { if (auto splatAttr = attr.dyn_cast()) { return attr_value_binder(bind_value) - .match(splatAttr.getSplatValue()); + .match(splatAttr.getSplatValue()); } } return false; diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -288,8 +288,9 @@ MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return wrap(unwrap(attr).cast().getValue( - llvm::makeArrayRef(idxs, rank))); + return wrap(unwrap(attr) + .cast() + .getValues()[llvm::makeArrayRef(idxs, rank)]); } bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, @@ -482,7 +483,8 @@ } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getSplatValue()); + return wrap( + unwrap(attr).cast().getSplatValue()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); @@ -520,36 +522,36 @@ // Indexed accessors. bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( - unwrap(attr).cast().getFlatValue(pos)); + unwrap(attr).cast().getValues()[pos]); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -169,7 +169,7 @@ return failure(); auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op); - auto val = workGroupSizeAttr.getValue(index.getValue()); + auto val = workGroupSizeAttr.getValues()[index.getValue()]; auto convertedType = getTypeConverter()->convertType(op.getResult().getType()); if (!convertedType) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -451,7 +451,7 @@ // For scalar memrefs, the global variable created is of the element type, // so unpack the elements attribute to extract the value. if (type.getRank() == 0) - initialValue = elementsAttr.getValue({}); + initialValue = elementsAttr.getSplatValue(); } uint64_t alignment = global.alignment().getValueOr(0); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -329,7 +329,8 @@ llvm::DenseMap &valueMapping) { assert(constantSupportsMMAMatrixType(op)); OpBuilder b(op); - Attribute splat = op.getValue().cast().getSplatValue(); + Attribute splat = + op.getValue().cast().getSplatValue(); auto scalarConstant = b.create(op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2890,19 +2890,19 @@ } AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) { + auto values = lowerBoundsGroups().getValues(); unsigned start = 0; for (unsigned i = 0; i < pos; ++i) - start += lowerBoundsGroups().getValue(i); - return lowerBoundsMap().getSliceMap( - start, lowerBoundsGroups().getValue(pos)); + start += values[i]; + return lowerBoundsMap().getSliceMap(start, values[pos]); } AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) { + auto values = upperBoundsGroups().getValues(); unsigned start = 0; for (unsigned i = 0; i < pos; ++i) - start += upperBoundsGroups().getValue(i); - return upperBoundsMap().getSliceMap( - start, upperBoundsGroups().getValue(pos)); + start += values[i]; + return upperBoundsMap().getSliceMap(start, values[pos]); } AffineValueMap AffineParallelOp::getLowerBoundsValueMap() { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1573,7 +1573,7 @@ if (auto fpValue = value.dyn_cast()) return fpValue.getValue().isZero(); if (auto splatValue = value.dyn_cast()) - return isZeroAttribute(splatValue.getSplatValue()); + return isZeroAttribute(splatValue.getSplatValue()); if (auto elementsValue = value.dyn_cast()) return llvm::all_of(elementsValue.getValues(), isZeroAttribute); if (auto arrayValue = value.dyn_cast()) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -163,8 +163,8 @@ /// Returns the value that corresponds to named position `pos` from the /// attribute `attr` assuming it's a dense integer elements attribute. static unsigned extractPointerSpecValue(Attribute attr, DLEntryPos pos) { - return attr.cast().getValue( - static_cast(pos)); + return attr.cast() + .getValues()[static_cast(pos)]; } /// Returns the part of the data layout entry that corresponds to `pos` for the diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1184,7 +1184,7 @@ if (matchPattern(def, m_Constant(&splatAttr)) && splatAttr.isSplat() && splatAttr.getType().getElementType().isIntOrFloat()) { - constantAttr = splatAttr.getSplatValue(); + constantAttr = splatAttr.getSplatValue(); return true; } } @@ -1455,10 +1455,9 @@ bool isFloat = elementType.isa(); if (isFloat) { - SmallVector> - inputFpIterators; + SmallVector> inFpRanges; for (int i = 0; i < numInputs; ++i) - inputFpIterators.push_back(inputValues[i].getValues()); + inFpRanges.push_back(inputValues[i].getValues()); computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); @@ -1469,22 +1468,17 @@ computeRemappedLinearIndex(linearIndex); // Collect constant elements for all inputs at this loop iteration. - for (int i = 0; i < numInputs; ++i) { - computeFnInputs.apFloats[i] = - *(inputFpIterators[i].begin() + srcLinearIndices[i]); - } + for (int i = 0; i < numInputs; ++i) + computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]]; // Invoke the computation to get the corresponding constant output // element. - APIntOrFloat outputs = computeFn(computeFnInputs); - - fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue(); + fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat; } } else { - SmallVector> - inputIntIterators; + SmallVector> inIntRanges; for (int i = 0; i < numInputs; ++i) - inputIntIterators.push_back(inputValues[i].getValues()); + inIntRanges.push_back(inputValues[i].getValues()); computeFnInputs.apInts.resize(numInputs); @@ -1495,25 +1489,19 @@ computeRemappedLinearIndex(linearIndex); // Collect constant elements for all inputs at this loop iteration. - for (int i = 0; i < numInputs; ++i) { - computeFnInputs.apInts[i] = - *(inputIntIterators[i].begin() + srcLinearIndices[i]); - } + for (int i = 0; i < numInputs; ++i) + computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]]; // Invoke the computation to get the corresponding constant output // element. - APIntOrFloat outputs = computeFn(computeFnInputs); - - intOutputValues[dstLinearIndex] = outputs.apInt.getValue(); + intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt; } } - DenseIntOrFPElementsAttr outputAttr; - if (isFloat) { - outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues); - } else { - outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues); - } + DenseElementsAttr outputAttr = + isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) + : DenseElementsAttr::get(outputType, intOutputValues); + rewriter.replaceOpWithNewOp(genericOp, outputAttr); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -56,7 +56,7 @@ if (auto vector = composite.dyn_cast()) { assert(indices.size() == 1 && "must have exactly one index for a vector"); - return vector.getValue({indices[0]}); + return vector.getValues()[indices[0]]; } if (auto array = composite.dyn_cast()) { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1138,7 +1138,7 @@ return nullptr; if (dim.getValue() >= elements.getNumElements()) return nullptr; - return elements.getValue({(uint64_t)dim.getValue()}); + return elements.getValues()[(uint64_t)dim.getValue()]; } void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1304,13 +1304,14 @@ if (!caseValues) return; - for (int64_t i = 0, size = caseValues.size(); i < size; ++i) { + for (const auto &it : llvm::enumerate(caseValues.getValues())) { p << ','; p.printNewline(); p << " "; - p << caseValues.getValue(i).getLimitedValue(); + p << it.value().getLimitedValue(); p << ": "; - p.printSuccessorAndUseList(caseDestinations[i], caseOperands[i]); + p.printSuccessorAndUseList(caseDestinations[it.index()], + caseOperands[it.index()]); } p.printNewline(); } @@ -1353,9 +1354,9 @@ SuccessorRange caseDests = getCaseDestinations(); if (auto value = operands.front().dyn_cast_or_null()) { - for (int64_t i = 0, size = getCaseValues()->size(); i < size; ++i) - if (value == caseValues->getValue(i)) - return caseDests[i]; + for (const auto &it : llvm::enumerate(caseValues->getValues())) + if (it.value() == value.getValue()) + return caseDests[it.index()]; return getDefaultDestination(); } return nullptr; @@ -1394,15 +1395,15 @@ auto caseValues = op.getCaseValues(); auto caseDests = op.getCaseDestinations(); - for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { - if (caseDests[i] == op.getDefaultDestination() && - op.getCaseOperands(i) == op.getDefaultOperands()) { + for (const auto &it : llvm::enumerate(caseValues->getValues())) { + if (caseDests[it.index()] == op.getDefaultDestination() && + op.getCaseOperands(it.index()) == op.getDefaultOperands()) { requiresChange = true; continue; } - newCaseDestinations.push_back(caseDests[i]); - newCaseOperands.push_back(op.getCaseOperands(i)); - newCaseValues.push_back(caseValues->getValue(i)); + newCaseDestinations.push_back(caseDests[it.index()]); + newCaseOperands.push_back(op.getCaseOperands(it.index())); + newCaseValues.push_back(it.value()); } if (!requiresChange) @@ -1424,10 +1425,11 @@ static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, APInt caseValue) { auto caseValues = op.getCaseValues(); - for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { - if (caseValues->getValue(i) == caseValue) { - rewriter.replaceOpWithNewOp(op, op.getCaseDestinations()[i], - op.getCaseOperands(i)); + for (const auto &it : llvm::enumerate(caseValues->getValues())) { + if (it.value() == caseValue) { + rewriter.replaceOpWithNewOp( + op, op.getCaseDestinations()[it.index()], + op.getCaseOperands(it.index())); return; } } @@ -1551,22 +1553,16 @@ return failure(); // Fold this switch to an unconditional branch. - APInt caseValue; - bool isDefault = true; SuccessorRange predDests = predSwitch.getCaseDestinations(); - Optional predCaseValues = predSwitch.getCaseValues(); - for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) { - if (currentBlock == predDests[i]) { - caseValue = predCaseValues->getValue(i); - isDefault = false; - break; - } - } - if (isDefault) + auto it = llvm::find(predDests, currentBlock); + if (it != predDests.end()) { + Optional predCaseValues = predSwitch.getCaseValues(); + foldSwitch(op, rewriter, + predCaseValues->getValues()[it - predDests.begin()]); + } else { rewriter.replaceOpWithNewOp(op, op.getDefaultDestination(), op.getDefaultOperands()); - else - foldSwitch(op, rewriter, caseValue); + } return success(); } @@ -1613,7 +1609,7 @@ auto predCaseValues = predSwitch.getCaseValues(); for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) if (currentBlock != predDests[i]) - caseValuesToRemove.insert(predCaseValues->getValue(i)); + caseValuesToRemove.insert(predCaseValues->getValues()[i]); SmallVector newCaseDestinations; SmallVector newCaseOperands; @@ -1622,14 +1618,14 @@ auto caseValues = op.getCaseValues(); auto caseDests = op.getCaseDestinations(); - for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { - if (caseValuesToRemove.contains(caseValues->getValue(i))) { + for (const auto &it : llvm::enumerate(caseValues->getValues())) { + if (caseValuesToRemove.contains(it.value())) { requiresChange = true; continue; } - newCaseDestinations.push_back(caseDests[i]); - newCaseOperands.push_back(op.getCaseOperands(i)); - newCaseValues.push_back(caseValues->getValue(i)); + newCaseDestinations.push_back(caseDests[it.index()]); + newCaseOperands.push_back(op.getCaseOperands(it.index())); + newCaseValues.push_back(it.value()); } if (!requiresChange) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -340,7 +340,7 @@ // If this is a splat elements attribute, simply return the value. All of the // elements of a splat attribute are the same. if (auto splatTensor = tensor.dyn_cast()) - return splatTensor.getSplatValue(); + return splatTensor.getSplatValue(); // Otherwise, collect the constant indices into the tensor. SmallVector indices; @@ -353,7 +353,7 @@ // If this is an elements attribute, query the value at the given indices. auto elementsAttr = tensor.dyn_cast(); if (elementsAttr && elementsAttr.isValidIndex(indices)) - return elementsAttr.getValue(indices); + return elementsAttr.getValues()[indices]; return {}; } @@ -440,7 +440,7 @@ Attribute dest = operands[1]; if (scalar && dest) if (auto splatDest = dest.dyn_cast()) - if (scalar == splatDest.getSplatValue()) + if (scalar == splatDest.getSplatValue()) return dest; return {}; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -230,6 +230,7 @@ // Transpose the input constant. Because we don't know its rank in advance, // we need to loop over the range [0, element count) and delinearize the // index. + auto attrValues = inputValues.getValues(); for (int srcLinearIndex = 0; srcLinearIndex < numElements; ++srcLinearIndex) { SmallVector srcIndices(inputType.getRank(), 0); @@ -247,7 +248,7 @@ for (int dim = 1; dim < outputType.getRank(); ++dim) dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; - outputValues[dstLinearIndex] = inputValues.getValue(srcIndices); + outputValues[dstLinearIndex] = attrValues[srcIndices]; } rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1395,7 +1395,7 @@ if (operands[0].getType().isIntOrIndexOrFloat()) return DenseElementsAttr::get(vectorType, operands[0]); if (auto attr = operands[0].dyn_cast()) - return DenseElementsAttr::get(vectorType, attr.getSplatValue()); + return DenseElementsAttr::get(vectorType, attr.getSplatValue()); return {}; } @@ -2212,7 +2212,7 @@ if (!dense) return failure(); auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(), - dense.getSplatValue()); + dense.getSplatValue()); rewriter.replaceOpWithNewOp(extractStridedSliceOp, newAttr); return success(); @@ -3670,8 +3670,9 @@ auto dense = constantOp.getValue().dyn_cast(); if (!dense) return failure(); - auto newAttr = DenseElementsAttr::get( - shapeCastOp.getType().cast(), dense.getSplatValue()); + auto newAttr = + DenseElementsAttr::get(shapeCastOp.getType().cast(), + dense.getSplatValue()); rewriter.replaceOpWithNewOp(shapeCastOp, newAttr); return success(); } diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -56,15 +56,15 @@ return isValidIndex(elementsAttr.getType().cast(), index); } -uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr, - ArrayRef index) { - ShapedType type = elementsAttr.getType().cast(); - assert(isValidIndex(type, index) && "expected valid multi-dimensional index"); +uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef index) { + ShapedType shapeType = type.cast(); + assert(isValidIndex(shapeType, index) && + "expected valid multi-dimensional index"); // Reduce the provided multidimensional index into a flattended 1D row-major // index. - auto rank = type.getRank(); - auto shape = type.getShape(); + auto rank = shapeType.getRank(); + ArrayRef shape = shapeType.getShape(); uint64_t valueIndex = 0; uint64_t dimMultiplier = 1; for (int i = rank - 1; i >= 0; --i) { diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -878,10 +878,10 @@ } auto DenseElementsAttr::getComplexIntValues() const - -> llvm::iterator_range { + -> iterator_range_impl { assert(isComplexOfIntType(getElementType()) && "expected complex integral type"); - return {ComplexIntElementIterator(*this, 0), + return {getType(), ComplexIntElementIterator(*this, 0), ComplexIntElementIterator(*this, getNumElements())}; } auto DenseElementsAttr::complex_value_begin() const @@ -899,10 +899,10 @@ /// Return the held element values as a range of APFloat. The element type of /// this attribute must be of float type. auto DenseElementsAttr::getFloatValues() const - -> llvm::iterator_range { + -> iterator_range_impl { auto elementType = getElementType().cast(); const auto &elementSemantics = elementType.getFloatSemantics(); - return {FloatElementIterator(elementSemantics, raw_int_begin()), + return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()), FloatElementIterator(elementSemantics, raw_int_end())}; } auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { @@ -915,11 +915,12 @@ } auto DenseElementsAttr::getComplexFloatValues() const - -> llvm::iterator_range { + -> iterator_range_impl { Type eltTy = getElementType().cast().getElementType(); assert(eltTy.isa() && "expected complex float type"); const auto &semantics = eltTy.cast().getFloatSemantics(); - return {{semantics, {*this, 0}}, + return {getType(), + {semantics, {*this, 0}}, {semantics, {*this, static_cast(getNumElements())}}}; } auto DenseElementsAttr::complex_float_value_begin() const @@ -1224,13 +1225,6 @@ // OpaqueElementsAttr //===----------------------------------------------------------------------===// -/// Return the value at the given index. If index does not refer to a valid -/// element, then a null attribute is returned. -Attribute OpaqueElementsAttr::getValue(ArrayRef index) const { - assert(isValidIndex(index) && "expected valid multi-dimensional index"); - return Attribute(); -} - bool OpaqueElementsAttr::decode(ElementsAttr &result) { Dialect *dialect = getDialect().getDialect(); if (!dialect) @@ -1255,47 +1249,6 @@ // SparseElementsAttr //===----------------------------------------------------------------------===// -/// Return the value of the element at the given index. -Attribute SparseElementsAttr::getValue(ArrayRef index) const { - assert(isValidIndex(index) && "expected valid multi-dimensional index"); - auto type = getType(); - - // The sparse indices are 64-bit integers, so we can reinterpret the raw data - // as a 1-D index array. - auto sparseIndices = getIndices(); - auto sparseIndexValues = sparseIndices.getValues(); - - // Check to see if the indices are a splat. - if (sparseIndices.isSplat()) { - // If the index is also not a splat of the index value, we know that the - // value is zero. - auto splatIndex = *sparseIndexValues.begin(); - if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) - return getZeroAttr(); - - // If the indices are a splat, we also expect the values to be a splat. - assert(getValues().isSplat() && "expected splat values"); - return getValues().getSplatValue(); - } - - // Build a mapping between known indices and the offset of the stored element. - llvm::SmallDenseMap, size_t> mappedIndices; - auto numSparseIndices = sparseIndices.getType().getDimSize(0); - size_t rank = type.getRank(); - for (size_t i = 0, e = numSparseIndices; i != e; ++i) - mappedIndices.try_emplace( - {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); - - // Look for the provided index key within the mapped indices. If the provided - // index is not found, then return a zero attribute. - auto it = mappedIndices.find(index); - if (it == mappedIndices.end()) - return getZeroAttr(); - - // Otherwise, return the held sparse value element. - return getValues().getValue(it->second); -} - /// Get a zero APFloat for the given sparse attribute. APFloat SparseElementsAttr::getZeroAPFloat() const { auto eltType = getElementType().cast(); diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -71,7 +71,7 @@ return t.cast().getDimSize(index); if (auto attr = val.dyn_cast()) return attr.cast() - .getFlatValue(index) + .getValues()[index] .getSExtValue(); auto *stc = val.get(); return stc->getDims()[index]; diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -384,14 +384,12 @@ return success(); } if (auto condbrOp = dyn_cast(opInst)) { - auto weights = condbrOp.getBranchWeights(); llvm::MDNode *branchWeights = nullptr; - if (weights) { + if (auto weights = condbrOp.getBranchWeights()) { // Map weight attributes to LLVM metadata. - auto trueWeight = - weights.getValue().getValue(0).cast().getInt(); - auto falseWeight = - weights.getValue().getValue(1).cast().getInt(); + auto weightValues = weights->getValues(); + auto trueWeight = weightValues[0].getSExtValue(); + auto falseWeight = weightValues[1].getSExtValue(); branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext()) .createBranchWeights(static_cast(trueWeight), diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -139,7 +139,7 @@ if (denseElementsAttr.isSplat() && (type.isa() || hasVectorElementType)) { llvm::Constant *splatValue = LLVM::detail::getLLVMConstant( - innermostLLVMType, denseElementsAttr.getSplatValue(), loc, + innermostLLVMType, denseElementsAttr.getSplatValue(), loc, moduleTranslation, /*isTopLevel=*/false); llvm::Constant *splatVector = llvm::ConstantDataVector::getSplat(0, splatValue); @@ -254,8 +254,9 @@ isa(elementType); llvm::Constant *child = getLLVMConstant( elementType, - elementTypeSequential ? splatAttr : splatAttr.getSplatValue(), loc, - moduleTranslation, false); + elementTypeSequential ? splatAttr + : splatAttr.getSplatValue(), + loc, moduleTranslation, false); if (!child) return nullptr; if (llvmType->isVectorTy()) diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -706,11 +706,12 @@ if (shapedType.getRank() == dim) { if (auto attr = valueAttr.dyn_cast()) { return attr.getType().getElementType().isInteger(1) - ? prepareConstantBool(loc, attr.getValue(index)) - : prepareConstantInt(loc, attr.getValue(index)); + ? prepareConstantBool(loc, attr.getValues()[index]) + : prepareConstantInt(loc, + attr.getValues()[index]); } if (auto attr = valueAttr.dyn_cast()) { - return prepareConstantFp(loc, attr.getValue(index)); + return prepareConstantFp(loc, attr.getValues()[index]); } return 0; } diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -154,9 +154,9 @@ # ODS-NEXT: LogicalResult verifyIndexingMapRequiredAttributes(); # IMPL: getSymbolBindings(Test2Op self) -# IMPL: cst2 = self.strides().getValue({ 0 }); +# IMPL: cst2 = self.strides().getValues()[{ 0 }]; # IMPL-NEXT: getAffineConstantExpr(cst2, context) -# IMPL: cst3 = self.strides().getValue({ 1 }); +# IMPL: cst3 = self.strides().getValues()[{ 1 }]; # IMPL-NEXT: getAffineConstantExpr(cst3, context) # IMPL: Test2Op::indexing_maps() diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -726,7 +726,7 @@ // {1}: Symbol position // {2}: Attribute index static const char structuredOpAccessAttrFormat[] = R"FMT( -int64_t cst{1} = self.{0}().getValue({ {2} }); +int64_t cst{1} = self.{0}().getValues()[{ {2} }]; exprs.push_back(getAffineConstantExpr(cst{1}, context)); )FMT"; // Update all symbol bindings mapped to an attribute. diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp --- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp +++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp @@ -113,7 +113,8 @@ EXPECT_TRUE(returnedValue.isa()); // Check Elements attribute element value is expected. - auto firstValue = returnedValue.cast().getValue({0, 0}); + auto firstValue = + returnedValue.cast().getValues()[{0, 0}]; EXPECT_EQ(firstValue.cast().getInt(), 5); } @@ -138,7 +139,8 @@ EXPECT_TRUE(returnedValue.isa()); // Check Elements attribute element value is expected. - auto firstValue = returnedValue.cast().getValue({0, 0}); + auto firstValue = + returnedValue.cast().getValues()[{0, 0}]; EXPECT_EQ(firstValue.cast().getInt(), 5); } @@ -162,7 +164,8 @@ EXPECT_TRUE(returnedValue.isa()); // Check Elements attribute element value is expected. - auto firstValue = returnedValue.cast().getValue({0, 0}); + auto firstValue = + returnedValue.cast().getValues()[{0, 0}]; EXPECT_EQ(firstValue.cast().getInt(), 5); } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -202,7 +202,7 @@ RankedTensorType shape = RankedTensorType::get({}, intTy); auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})); - EXPECT_TRUE(attr.getValue({0}) == value); + EXPECT_TRUE(attr.getValues()[0] == value); } } // end namespace