diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h @@ -227,6 +227,33 @@ ElementsAttrIndexer indexer; ptrdiff_t index; }; + +/// This class provides iterator utilities for an ElementsAttr range. +template +class ElementsAttrRange : public llvm::iterator_range { +public: + using reference = typename IteratorT::reference; + + ElementsAttrRange(Type shapeType, + const llvm::iterator_range &range) + : llvm::iterator_range(range), shapeType(shapeType) {} + ElementsAttrRange(Type shapeType, IteratorT beginIt, IteratorT endIt) + : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {} + + /// Return the value at the given index. + reference operator[](ArrayRef index) const; + reference operator[](uint64_t index) const { + return *std::next(this->begin(), index); + } + + /// Return the size of this range. + size_t size() const { return llvm::size(*this); } + +private: + /// The shaped type of the parent ElementsAttr. + Type shapeType; +}; + } // namespace detail //===----------------------------------------------------------------------===// @@ -256,6 +283,16 @@ //===----------------------------------------------------------------------===// namespace mlir { +namespace detail { +/// Return the value at the given index. +template +auto ElementsAttrRange::operator[](ArrayRef index) const + -> reference { + // Skip to the element corresponding to the flattened index. + return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)]; +} +} // namespace detail + /// Return the elements of this attribute as a value of type 'T'. template auto ElementsAttr::value_begin() const -> DefaultValueCheckT> { diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -158,27 +158,6 @@ ]; string ElementsAttrInterfaceAccessors = [{ - /// Return the attribute value at the given index. The index is expected to - /// refer to a valid element. - Attribute getValue(ArrayRef index) const { - return getValue(index); - } - - /// Return the value of type 'T' at the given index, where 'T' corresponds - /// to an Attribute type. - template - std::enable_if_t::value && - std::is_base_of::value> - getValue(ArrayRef index) const { - return getValue(index).template dyn_cast_or_null(); - } - - /// Return the value of type 'T' at the given index. - template - T getValue(ArrayRef index) const { - return getFlatValue(getFlattenedIndex(index)); - } - /// Return the number of elements held by this attribute. int64_t size() const { return getNumElements(); } @@ -281,6 +260,14 @@ // Value Iteration //===------------------------------------------------------------------===// + /// The iterator for the given element type T. + template + using iterator = decltype(std::declval().template value_begin()); + /// The iterator range over the given element T. + template + using iterator_range = + decltype(std::declval().template getValues()); + /// Return an iterator to the first element of this attribute as a value of /// type `T`. template @@ -292,11 +279,8 @@ template auto getValues() const { auto beginIt = $_attr.template value_begin(); - return llvm::make_range(beginIt, std::next(beginIt, size())); - } - /// Return the value at the given flattened index. - template T getFlatValue(uint64_t index) const { - return *std::next($_attr.template value_begin(), index); + return detail::ElementsAttrRange( + Attribute($_attr).getType(), beginIt, std::next(beginIt, size())); } }] # ElementsAttrInterfaceAccessors; @@ -304,7 +288,7 @@ template using iterator = detail::ElementsAttrIterator; template - using iterator_range = llvm::iterator_range>; + using iterator_range = detail::ElementsAttrRange>; //===------------------------------------------------------------------===// // Accessors @@ -329,8 +313,12 @@ uint64_t getFlattenedIndex(ArrayRef index) const { return getFlattenedIndex(*this, index); } - static uint64_t getFlattenedIndex(Attribute elementsAttr, + static uint64_t getFlattenedIndex(Type type, ArrayRef index); + static uint64_t getFlattenedIndex(Attribute elementsAttr, + ArrayRef index) { + return getFlattenedIndex(elementsAttr.getType(), index); + } /// Returns the number of elements held by this attribute. int64_t getNumElements() const { return getNumElements(*this); } @@ -350,13 +338,6 @@ !std::is_base_of::value, ResultT>; - /// Return the element of this attribute at the given index as a value of - /// type 'T'. - template - T getFlatValue(uint64_t index) const { - return *std::next(value_begin(), index); - } - /// Return the splat value for this attribute. This asserts that the /// attribute corresponds to a splat. template @@ -368,7 +349,7 @@ /// Return the elements of this attribute as a value of type 'T'. template DefaultValueCheckT> getValues() const { - return iterator_range(value_begin(), value_end()); + return {Attribute::getType(), value_begin(), value_end()}; } template DefaultValueCheckT> value_begin() const; @@ -384,12 +365,12 @@ llvm::mapped_iterator, T (*)(Attribute)>; template using DerivedAttrValueIteratorRange = - llvm::iterator_range>; + detail::ElementsAttrRange>; template > DerivedAttrValueIteratorRange getValues() const { auto castFn = [](Attribute attr) { return attr.template cast(); }; - return llvm::map_range(getValues(), - static_cast(castFn)); + return {Attribute::getType(), llvm::map_range(getValues(), + static_cast(castFn))}; } template > DerivedAttrValueIterator value_begin() const { @@ -407,8 +388,10 @@ /// return the iterable range. Otherwise, return llvm::None. template DefaultValueCheckT>> tryGetValues() const { - if (Optional> beginIt = try_value_begin()) - return iterator_range(*beginIt, value_end()); + if (Optional> beginIt = try_value_begin()) { + return iterator_range(Attribute::getType(), *beginIt, + value_end()); + } return llvm::None; } template @@ -418,10 +401,15 @@ /// return the iterable range. Otherwise, return llvm::None. template > Optional> tryGetValues() const { + auto values = tryGetValues(); + if (!values) + return llvm::None; + auto castFn = [](Attribute attr) { return attr.template cast(); }; - if (auto values = tryGetValues()) - return llvm::map_range(*values, static_cast(castFn)); - return llvm::None; + return DerivedAttrValueIteratorRange( + Attribute::getType(), + llvm::map_range(*values, static_cast(castFn)) + ); } template > Optional> try_value_begin() const { diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -61,8 +61,7 @@ }; /// Type trait detector that checks if a given type T is a complex type. -template -struct is_complex_t : public std::false_type {}; +template struct is_complex_t : public std::false_type {}; template struct is_complex_t> : public std::true_type {}; } // namespace detail @@ -82,8 +81,7 @@ /// floating point type that can be used to access the underlying element /// types of a DenseElementsAttr. // TODO: Use std::disjunction when C++17 is supported. - template - struct is_valid_cpp_fp_type { + template struct is_valid_cpp_fp_type { /// The type is a valid floating point type if it is a builtin floating /// point type, or is a potentially user defined floating point type. The /// latter allows for supporting users that have custom types defined for @@ -219,6 +217,18 @@ // Iterators //===--------------------------------------------------------------------===// + /// The iterator range over the given iterator type T. + template + using iterator_range_impl = detail::ElementsAttrRange; + + /// The iterator for the given element type T. + template + using iterator = decltype(std::declval().template value_begin()); + /// The iterator range over the given element T. + template + using iterator_range = + decltype(std::declval().template getValues()); + /// A utility iterator that allows walking over the internal Attribute values /// of a DenseElementsAttr. class AttributeElementIterator @@ -358,22 +368,7 @@ !std::is_same::value, T>::type getSplatValue() const { - return getSplatValue().template cast(); - } - - /// Return the value at the given index. The 'index' is expected to refer to a - /// valid element. - Attribute getValue(ArrayRef index) const { - return getValue(index); - } - template - T getValue(ArrayRef index) const { - // Skip to the element corresponding to the flattened index. - return getFlatValue(ElementsAttr::getFlattenedIndex(*this, index)); - } - /// Return the value at the given flattened index. - template T getFlatValue(uint64_t index) const { - return *std::next(value_begin(), index); + return getSplatValue().template cast(); } /// Return the held element values as a range of integer or floating-point @@ -384,12 +379,12 @@ std::numeric_limits::is_integer) || is_valid_cpp_fp_type::value>::type; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); const char *rawData = getRawData().data(); bool splat = isSplat(); - return {ElementIterator(rawData, splat, 0), + return {Attribute::getType(), ElementIterator(rawData, splat, 0), ElementIterator(rawData, splat, getNumElements())}; } template > @@ -413,12 +408,12 @@ is_valid_cpp_fp_type::value)>::type; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); const char *rawData = getRawData().data(); bool splat = isSplat(); - return {ElementIterator(rawData, splat, 0), + return {Attribute::getType(), ElementIterator(rawData, splat, 0), ElementIterator(rawData, splat, getNumElements())}; } template ::value>::type; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { auto stringRefs = getRawStringData(); const char *ptr = reinterpret_cast(stringRefs.data()); bool splat = isSplat(); - return {ElementIterator(ptr, splat, 0), + return {Attribute::getType(), ElementIterator(ptr, splat, 0), ElementIterator(ptr, splat, getNumElements())}; } template > @@ -464,8 +459,9 @@ using AttributeValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { - return {value_begin(), value_end()}; + iterator_range_impl getValues() const { + return {Attribute::getType(), value_begin(), + value_end()}; } template > AttributeElementIterator value_begin() const { @@ -486,10 +482,11 @@ using DerivedAttributeElementIterator = llvm::mapped_iterator; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { auto castFn = [](Attribute attr) { return attr.template cast(); }; - return llvm::map_range(getValues(), - static_cast(castFn)); + return {Attribute::getType(), + llvm::map_range(getValues(), + static_cast(castFn))}; } template > DerivedAttributeElementIterator value_begin() const { @@ -508,9 +505,9 @@ using BoolValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { assert(isValidBool() && "bool is not the value of this elements attribute"); - return {BoolElementIterator(*this, 0), + return {Attribute::getType(), BoolElementIterator(*this, 0), BoolElementIterator(*this, getNumElements())}; } template > @@ -530,9 +527,9 @@ using APIntValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { assert(getElementType().isIntOrIndex() && "expected integral type"); - return {raw_int_begin(), raw_int_end()}; + return {Attribute::getType(), raw_int_begin(), raw_int_end()}; } template > IntElementIterator value_begin() const { @@ -551,7 +548,7 @@ using ComplexAPIntValueTemplateCheckT = typename std::enable_if< std::is_same>::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { return getComplexIntValues(); } template > @@ -569,7 +566,7 @@ using APFloatValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { return getFloatValues(); } template > @@ -587,7 +584,7 @@ using ComplexAPFloatValueTemplateCheckT = typename std::enable_if< std::is_same>::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { return getComplexFloatValues(); } template > @@ -660,13 +657,13 @@ IntElementIterator raw_int_end() const { return IntElementIterator(*this, getNumElements()); } - llvm::iterator_range getComplexIntValues() const; + iterator_range_impl getComplexIntValues() const; ComplexIntElementIterator complex_value_begin() const; ComplexIntElementIterator complex_value_end() const; - llvm::iterator_range getFloatValues() const; + iterator_range_impl getFloatValues() const; FloatElementIterator float_value_begin() const; FloatElementIterator float_value_end() const; - llvm::iterator_range + iterator_range_impl getComplexFloatValues() const; ComplexFloatElementIterator complex_float_value_begin() const; ComplexFloatElementIterator complex_float_value_end() const; @@ -872,8 +869,7 @@ //===----------------------------------------------------------------------===// template -auto SparseElementsAttr::getValues() const - -> llvm::iterator_range> { +auto SparseElementsAttr::value_begin() const -> iterator { auto zeroValue = getZeroValue(); auto valueIt = getValues().value_begin(); const std::vector flatSparseIndices(getFlattenedSparseIndices()); @@ -888,15 +884,7 @@ // Otherwise, return the zero value. return zeroValue; }; - return llvm::map_range(llvm::seq(0, getNumElements()), mapFn); -} -template -auto SparseElementsAttr::value_begin() const -> iterator { - return getValues().begin(); -} -template -auto SparseElementsAttr::value_end() const -> iterator { - return getValues().end(); + return iterator(llvm::seq(0, getNumElements()).begin(), mapFn); } } // end namespace mlir. diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -174,9 +174,7 @@ "ArrayRef":$rawData); let extraClassDeclaration = [{ using DenseElementsAttr::empty; - using DenseElementsAttr::getFlatValue; using DenseElementsAttr::getNumElements; - using DenseElementsAttr::getValue; using DenseElementsAttr::getValues; using DenseElementsAttr::isSplat; using DenseElementsAttr::size; @@ -313,9 +311,7 @@ ]; let extraClassDeclaration = [{ using DenseElementsAttr::empty; - using DenseElementsAttr::getFlatValue; using DenseElementsAttr::getNumElements; - using DenseElementsAttr::getValue; using DenseElementsAttr::getValues; using DenseElementsAttr::isSplat; using DenseElementsAttr::size; @@ -712,10 +708,6 @@ let extraClassDeclaration = [{ using ValueType = StringRef; - /// Return the value at the given index. The 'index' is expected to refer to - /// a valid element. - Attribute getValue(ArrayRef index) const; - /// Decodes the attribute value using dialect-specific decoding hook. /// Returns false if decoding is successful. If not, returns true and leaves /// 'result' argument unspecified. @@ -802,6 +794,7 @@ // String types. StringRef >; + using ElementsAttr::Trait::getValues; /// Provide a `value_begin_impl` to enable iteration within ElementsAttr. template @@ -817,13 +810,7 @@ /// Return the values of this attribute in the form of the given type 'T'. /// 'T' may be any of Attribute, APInt, APFloat, c++ integer/float types, /// etc. - template llvm::iterator_range> getValues() const; template iterator value_begin() const; - template iterator value_end() const; - - /// Return the value of the element at the given index. The 'index' is - /// expected to refer to a valid element. - Attribute getValue(ArrayRef index) const; private: /// Get a zero APFloat for the given sparse attribute. diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -288,8 +288,9 @@ MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return wrap(unwrap(attr).cast().getValue( - llvm::makeArrayRef(idxs, rank))); + return wrap(unwrap(attr) + .cast() + .getValues()[llvm::makeArrayRef(idxs, rank)]); } bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, @@ -482,7 +483,8 @@ } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getSplatValue()); + return wrap( + unwrap(attr).cast().getSplatValue()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); @@ -520,36 +522,36 @@ // Indexed accessors. bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( - unwrap(attr).cast().getFlatValue(pos)); + unwrap(attr).cast().getValues()[pos]); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -169,7 +169,7 @@ return failure(); auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op); - auto val = workGroupSizeAttr.getValue(index.getValue()); + auto val = workGroupSizeAttr.getValues()[index.getValue()]; auto convertedType = getTypeConverter()->convertType(op.getResult().getType()); if (!convertedType) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -451,7 +451,7 @@ // For scalar memrefs, the global variable created is of the element type, // so unpack the elements attribute to extract the value. if (type.getRank() == 0) - initialValue = elementsAttr.getValue({}); + initialValue = elementsAttr.getValues()[0]; } uint64_t alignment = global.alignment().getValueOr(0); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2415,8 +2415,7 @@ // AffineMinMaxOpBase //===----------------------------------------------------------------------===// -template -static LogicalResult verifyAffineMinMaxOp(T op) { +template static LogicalResult verifyAffineMinMaxOp(T op) { // Verify that operand count matches affine map dimension and symbol count. if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols()) return op.emitOpError( @@ -2424,8 +2423,7 @@ return success(); } -template -static void printAffineMinMaxOp(OpAsmPrinter &p, T op) { +template static void printAffineMinMaxOp(OpAsmPrinter &p, T op) { p << ' ' << op->getAttr(T::getMapAttrName()); auto operands = op.getOperands(); unsigned numDims = op.map().getNumDims(); @@ -2532,8 +2530,7 @@ /// /// %1 = affine.min affine_map< /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1] -template -struct MergeAffineMinMaxOp : public OpRewritePattern { +template struct MergeAffineMinMaxOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(T affineOp, @@ -2890,19 +2887,19 @@ } AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) { + auto values = lowerBoundsGroups().getValues(); unsigned start = 0; for (unsigned i = 0; i < pos; ++i) - start += lowerBoundsGroups().getValue(i); - return lowerBoundsMap().getSliceMap( - start, lowerBoundsGroups().getValue(pos)); + start += values[i]; + return lowerBoundsMap().getSliceMap(start, values[pos]); } AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) { + auto values = upperBoundsGroups().getValues(); unsigned start = 0; for (unsigned i = 0; i < pos; ++i) - start += upperBoundsGroups().getValue(i); - return upperBoundsMap().getSliceMap( - start, upperBoundsGroups().getValue(pos)); + start += values[i]; + return upperBoundsMap().getSliceMap(start, values[pos]); } AffineValueMap AffineParallelOp::getLowerBoundsValueMap() { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -163,8 +163,8 @@ /// Returns the value that corresponds to named position `pos` from the /// attribute `attr` assuming it's a dense integer elements attribute. static unsigned extractPointerSpecValue(Attribute attr, DLEntryPos pos) { - return attr.cast().getValue( - static_cast(pos)); + return attr.cast() + .getValues()[static_cast(pos)]; } /// Returns the part of the data layout entry that corresponds to `pos` for the diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1184,7 +1184,7 @@ if (matchPattern(def, m_Constant(&splatAttr)) && splatAttr.isSplat() && splatAttr.getType().getElementType().isIntOrFloat()) { - constantAttr = splatAttr.getSplatValue(); + constantAttr = splatAttr.getSplatValue(); return true; } } @@ -1455,10 +1455,9 @@ bool isFloat = elementType.isa(); if (isFloat) { - SmallVector> - inputFpIterators; + SmallVector> inFpRanges; for (int i = 0; i < numInputs; ++i) - inputFpIterators.push_back(inputValues[i].getValues()); + inFpRanges.push_back(inputValues[i].getValues()); computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); @@ -1469,22 +1468,17 @@ computeRemappedLinearIndex(linearIndex); // Collect constant elements for all inputs at this loop iteration. - for (int i = 0; i < numInputs; ++i) { - computeFnInputs.apFloats[i] = - *(inputFpIterators[i].begin() + srcLinearIndices[i]); - } + for (int i = 0; i < numInputs; ++i) + computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]]; // Invoke the computation to get the corresponding constant output // element. - APIntOrFloat outputs = computeFn(computeFnInputs); - - fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue(); + fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat; } } else { - SmallVector> - inputIntIterators; + SmallVector> inIntRanges; for (int i = 0; i < numInputs; ++i) - inputIntIterators.push_back(inputValues[i].getValues()); + inIntRanges.push_back(inputValues[i].getValues()); computeFnInputs.apInts.resize(numInputs); @@ -1495,25 +1489,19 @@ computeRemappedLinearIndex(linearIndex); // Collect constant elements for all inputs at this loop iteration. - for (int i = 0; i < numInputs; ++i) { - computeFnInputs.apInts[i] = - *(inputIntIterators[i].begin() + srcLinearIndices[i]); - } + for (int i = 0; i < numInputs; ++i) + computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]]; // Invoke the computation to get the corresponding constant output // element. - APIntOrFloat outputs = computeFn(computeFnInputs); - - intOutputValues[dstLinearIndex] = outputs.apInt.getValue(); + intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt; } } - DenseIntOrFPElementsAttr outputAttr; - if (isFloat) { - outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues); - } else { - outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues); - } + DenseElementsAttr outputAttr = + isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) + : DenseElementsAttr::get(outputType, intOutputValues); + rewriter.replaceOpWithNewOp(genericOp, outputAttr); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -913,9 +913,9 @@ loc, newOutputType, output, ioReshapeIndices); // We need to shrink the strides and dilations too. - auto stride = convOp.strides().getFlatValue(removeH ? 1 : 0); + auto stride = convOp.strides().getValues()[removeH ? 1 : 0]; auto stridesAttr = rewriter.getI64VectorAttr(stride); - auto dilation = convOp.dilations().getFlatValue(removeH ? 1 : 0); + auto dilation = convOp.dilations().getValues()[removeH ? 1 : 0]; auto dilationsAttr = rewriter.getI64VectorAttr(dilation); auto conv1DOp = rewriter.create( diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -56,7 +56,7 @@ if (auto vector = composite.dyn_cast()) { assert(indices.size() == 1 && "must have exactly one index for a vector"); - return vector.getValue({indices[0]}); + return vector.getValues()[indices[0]]; } if (auto array = composite.dyn_cast()) { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1138,7 +1138,7 @@ return nullptr; if (dim.getValue() >= elements.getNumElements()) return nullptr; - return elements.getValue({(uint64_t)dim.getValue()}); + return elements.getValues()[(uint64_t)dim.getValue()]; } void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1304,13 +1304,14 @@ if (!caseValues) return; - for (int64_t i = 0, size = caseValues.size(); i < size; ++i) { + for (const auto &it : llvm::enumerate(caseValues.getValues())) { p << ','; p.printNewline(); p << " "; - p << caseValues.getValue(i).getLimitedValue(); + p << it.value().getLimitedValue(); p << ": "; - p.printSuccessorAndUseList(caseDestinations[i], caseOperands[i]); + p.printSuccessorAndUseList(caseDestinations[it.index()], + caseOperands[it.index()]); } p.printNewline(); } @@ -1353,9 +1354,9 @@ SuccessorRange caseDests = getCaseDestinations(); if (auto value = operands.front().dyn_cast_or_null()) { - for (int64_t i = 0, size = getCaseValues()->size(); i < size; ++i) - if (value == caseValues->getValue(i)) - return caseDests[i]; + for (const auto &it : llvm::enumerate(caseValues->getValues())) + if (it.value() == value.getValue()) + return caseDests[it.index()]; return getDefaultDestination(); } return nullptr; @@ -1394,15 +1395,15 @@ auto caseValues = op.getCaseValues(); auto caseDests = op.getCaseDestinations(); - for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { - if (caseDests[i] == op.getDefaultDestination() && - op.getCaseOperands(i) == op.getDefaultOperands()) { + for (const auto &it : llvm::enumerate(caseValues->getValues())) { + if (caseDests[it.index()] == op.getDefaultDestination() && + op.getCaseOperands(it.index()) == op.getDefaultOperands()) { requiresChange = true; continue; } - newCaseDestinations.push_back(caseDests[i]); - newCaseOperands.push_back(op.getCaseOperands(i)); - newCaseValues.push_back(caseValues->getValue(i)); + newCaseDestinations.push_back(caseDests[it.index()]); + newCaseOperands.push_back(op.getCaseOperands(it.index())); + newCaseValues.push_back(it.value()); } if (!requiresChange) @@ -1424,10 +1425,11 @@ static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, APInt caseValue) { auto caseValues = op.getCaseValues(); - for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { - if (caseValues->getValue(i) == caseValue) { - rewriter.replaceOpWithNewOp(op, op.getCaseDestinations()[i], - op.getCaseOperands(i)); + for (const auto &it : llvm::enumerate(caseValues->getValues())) { + if (it.value() == caseValue) { + rewriter.replaceOpWithNewOp( + op, op.getCaseDestinations()[it.index()], + op.getCaseOperands(it.index())); return; } } @@ -1551,22 +1553,16 @@ return failure(); // Fold this switch to an unconditional branch. - APInt caseValue; - bool isDefault = true; SuccessorRange predDests = predSwitch.getCaseDestinations(); - Optional predCaseValues = predSwitch.getCaseValues(); - for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) { - if (currentBlock == predDests[i]) { - caseValue = predCaseValues->getValue(i); - isDefault = false; - break; - } - } - if (isDefault) + auto it = llvm::find(predDests, currentBlock); + if (it != predDests.end()) { + Optional predCaseValues = predSwitch.getCaseValues(); + foldSwitch(op, rewriter, + predCaseValues->getValues()[it - predDests.begin()]); + } else { rewriter.replaceOpWithNewOp(op, op.getDefaultDestination(), op.getDefaultOperands()); - else - foldSwitch(op, rewriter, caseValue); + } return success(); } @@ -1613,7 +1609,7 @@ auto predCaseValues = predSwitch.getCaseValues(); for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) if (currentBlock != predDests[i]) - caseValuesToRemove.insert(predCaseValues->getValue(i)); + caseValuesToRemove.insert(predCaseValues->getValues()[i]); SmallVector newCaseDestinations; SmallVector newCaseOperands; @@ -1622,14 +1618,14 @@ auto caseValues = op.getCaseValues(); auto caseDests = op.getCaseDestinations(); - for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { - if (caseValuesToRemove.contains(caseValues->getValue(i))) { + for (const auto &it : llvm::enumerate(caseValues->getValues())) { + if (caseValuesToRemove.contains(it.value())) { requiresChange = true; continue; } - newCaseDestinations.push_back(caseDests[i]); - newCaseOperands.push_back(op.getCaseOperands(i)); - newCaseValues.push_back(caseValues->getValue(i)); + newCaseDestinations.push_back(caseDests[it.index()]); + newCaseOperands.push_back(op.getCaseOperands(it.index())); + newCaseValues.push_back(it.value()); } if (!requiresChange) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -340,7 +340,7 @@ // If this is a splat elements attribute, simply return the value. All of the // elements of a splat attribute are the same. if (auto splatTensor = tensor.dyn_cast()) - return splatTensor.getSplatValue(); + return splatTensor.getSplatValue(); // Otherwise, collect the constant indices into the tensor. SmallVector indices; @@ -353,7 +353,7 @@ // If this is an elements attribute, query the value at the given indices. auto elementsAttr = tensor.dyn_cast(); if (elementsAttr && elementsAttr.isValidIndex(indices)) - return elementsAttr.getValue(indices); + return elementsAttr.getValues()[indices]; return {}; } @@ -440,7 +440,7 @@ Attribute dest = operands[1]; if (scalar && dest) if (auto splatDest = dest.dyn_cast()) - if (scalar == splatDest.getSplatValue()) + if (scalar == splatDest.getSplatValue()) return dest; return {}; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -230,6 +230,7 @@ // Transpose the input constant. Because we don't know its rank in advance, // we need to loop over the range [0, element count) and delinearize the // index. + auto attrValues = inputValues.getValues(); for (int srcLinearIndex = 0; srcLinearIndex < numElements; ++srcLinearIndex) { SmallVector srcIndices(inputType.getRank(), 0); @@ -247,7 +248,7 @@ for (int dim = 1; dim < outputType.getRank(); ++dim) dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; - outputValues[dstLinearIndex] = inputValues.getValue(srcIndices); + outputValues[dstLinearIndex] = attrValues[srcIndices]; } rewriter.replaceOpWithNewOp( @@ -424,8 +425,7 @@ // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// -template -static LogicalResult verifyConvOp(T op) { +template static LogicalResult verifyConvOp(T op) { // All TOSA conv ops have an input() and weight(). auto inputType = op.input().getType().template dyn_cast(); auto weightType = op.weight().getType().template dyn_cast(); diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -56,15 +56,15 @@ return isValidIndex(elementsAttr.getType().cast(), index); } -uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr, - ArrayRef index) { - ShapedType type = elementsAttr.getType().cast(); - assert(isValidIndex(type, index) && "expected valid multi-dimensional index"); +uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef index) { + ShapedType shapeType = type.cast(); + assert(isValidIndex(shapeType, index) && + "expected valid multi-dimensional index"); // Reduce the provided multidimensional index into a flattended 1D row-major // index. - auto rank = type.getRank(); - auto shape = type.getShape(); + auto rank = shapeType.getRank(); + ArrayRef shape = shapeType.getShape(); uint64_t valueIndex = 0; uint64_t dimMultiplier = 1; for (int i = rank - 1; i >= 0; --i) { diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -902,10 +902,10 @@ } auto DenseElementsAttr::getComplexIntValues() const - -> llvm::iterator_range { + -> iterator_range_impl { assert(isComplexOfIntType(getElementType()) && "expected complex integral type"); - return {ComplexIntElementIterator(*this, 0), + return {getType(), ComplexIntElementIterator(*this, 0), ComplexIntElementIterator(*this, getNumElements())}; } auto DenseElementsAttr::complex_value_begin() const @@ -923,10 +923,10 @@ /// Return the held element values as a range of APFloat. The element type of /// this attribute must be of float type. auto DenseElementsAttr::getFloatValues() const - -> llvm::iterator_range { + -> iterator_range_impl { auto elementType = getElementType().cast(); const auto &elementSemantics = elementType.getFloatSemantics(); - return {FloatElementIterator(elementSemantics, raw_int_begin()), + return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()), FloatElementIterator(elementSemantics, raw_int_end())}; } auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { @@ -939,11 +939,12 @@ } auto DenseElementsAttr::getComplexFloatValues() const - -> llvm::iterator_range { + -> iterator_range_impl { Type eltTy = getElementType().cast().getElementType(); assert(eltTy.isa() && "expected complex float type"); const auto &semantics = eltTy.cast().getFloatSemantics(); - return {{semantics, {*this, 0}}, + return {getType(), + {semantics, {*this, 0}}, {semantics, {*this, static_cast(getNumElements())}}}; } auto DenseElementsAttr::complex_float_value_begin() const @@ -1248,13 +1249,6 @@ // OpaqueElementsAttr //===----------------------------------------------------------------------===// -/// Return the value at the given index. If index does not refer to a valid -/// element, then a null attribute is returned. -Attribute OpaqueElementsAttr::getValue(ArrayRef index) const { - assert(isValidIndex(index) && "expected valid multi-dimensional index"); - return Attribute(); -} - bool OpaqueElementsAttr::decode(ElementsAttr &result) { Dialect *dialect = getDialect().getDialect(); if (!dialect) @@ -1279,47 +1273,6 @@ // SparseElementsAttr //===----------------------------------------------------------------------===// -/// Return the value of the element at the given index. -Attribute SparseElementsAttr::getValue(ArrayRef index) const { - assert(isValidIndex(index) && "expected valid multi-dimensional index"); - auto type = getType(); - - // The sparse indices are 64-bit integers, so we can reinterpret the raw data - // as a 1-D index array. - auto sparseIndices = getIndices(); - auto sparseIndexValues = sparseIndices.getValues(); - - // Check to see if the indices are a splat. - if (sparseIndices.isSplat()) { - // If the index is also not a splat of the index value, we know that the - // value is zero. - auto splatIndex = *sparseIndexValues.begin(); - if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) - return getZeroAttr(); - - // If the indices are a splat, we also expect the values to be a splat. - assert(getValues().isSplat() && "expected splat values"); - return getValues().getSplatValue(); - } - - // Build a mapping between known indices and the offset of the stored element. - llvm::SmallDenseMap, size_t> mappedIndices; - auto numSparseIndices = sparseIndices.getType().getDimSize(0); - size_t rank = type.getRank(); - for (size_t i = 0, e = numSparseIndices; i != e; ++i) - mappedIndices.try_emplace( - {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); - - // Look for the provided index key within the mapped indices. If the provided - // index is not found, then return a zero attribute. - auto it = mappedIndices.find(index); - if (it == mappedIndices.end()) - return getZeroAttr(); - - // Otherwise, return the held sparse value element. - return getValues().getValue(it->second); -} - /// Get a zero APFloat for the given sparse attribute. APFloat SparseElementsAttr::getZeroAPFloat() const { auto eltType = getElementType().cast(); diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -71,7 +71,7 @@ return t.cast().getDimSize(index); if (auto attr = val.dyn_cast()) return attr.cast() - .getFlatValue(index) + .getValues()[index] .getSExtValue(); auto *stc = val.get(); return stc->getDims()[index]; diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -386,14 +386,12 @@ return success(); } if (auto condbrOp = dyn_cast(opInst)) { - auto weights = condbrOp.getBranchWeights(); llvm::MDNode *branchWeights = nullptr; - if (weights) { + if (auto weights = condbrOp.getBranchWeights()) { // Map weight attributes to LLVM metadata. - auto trueWeight = - weights.getValue().getValue(0).cast().getInt(); - auto falseWeight = - weights.getValue().getValue(1).cast().getInt(); + auto weightValues = weights->getValues(); + auto trueWeight = weightValues[0].getSExtValue(); + auto falseWeight = weightValues[1].getSExtValue(); branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext()) .createBranchWeights(static_cast(trueWeight), diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -706,11 +706,12 @@ if (shapedType.getRank() == dim) { if (auto attr = valueAttr.dyn_cast()) { return attr.getType().getElementType().isInteger(1) - ? prepareConstantBool(loc, attr.getValue(index)) - : prepareConstantInt(loc, attr.getValue(index)); + ? prepareConstantBool(loc, attr.getValues()[index]) + : prepareConstantInt(loc, + attr.getValues()[index]); } if (auto attr = valueAttr.dyn_cast()) { - return prepareConstantFp(loc, attr.getValue(index)); + return prepareConstantFp(loc, attr.getValues()[index]); } return 0; } diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -154,9 +154,9 @@ # ODS-NEXT: LogicalResult verifyIndexingMapRequiredAttributes(); # IMPL: getSymbolBindings(Test2Op self) -# IMPL: cst2 = self.strides().getValue({ 0 }); +# IMPL: cst2 = self.strides().getValues()[0]; # IMPL-NEXT: getAffineConstantExpr(cst2, context) -# IMPL: cst3 = self.strides().getValue({ 1 }); +# IMPL: cst3 = self.strides().getValues()[1]; # IMPL-NEXT: getAffineConstantExpr(cst3, context) # IMPL: Test2Op::indexing_maps() diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -142,8 +142,7 @@ /// Top-level type containing op metadata and one of a concrete op type. /// Currently, the only defined op type is `structured_op` (maps to /// `LinalgStructuredOpConfig`). -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, LinalgOpConfig &info) { io.mapOptional("metadata", info.metadata); io.mapOptional("structured_op", info.structuredOp); @@ -156,8 +155,7 @@ /// - List of indexing maps (see `LinalgIndexingMaps`). /// - Iterator types (see `LinalgIteratorTypeDef`). /// - List of scalar level assignment (see `ScalarAssign`). -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, LinalgStructuredOpConfig &info) { io.mapRequired("args", info.args); io.mapRequired("indexing_maps", info.indexingMaps); @@ -180,8 +178,7 @@ /// attribute symbols. During op creation these symbols are replaced by the /// corresponding `name` attribute values. Only attribute arguments have /// an `attribute_map`. -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, LinalgOperandDef &info) { io.mapRequired("name", info.name); io.mapRequired("usage", info.usage); @@ -192,8 +189,7 @@ }; /// Usage enum for a named argument. -template <> -struct ScalarEnumerationTraits { +template <> struct ScalarEnumerationTraits { static void enumeration(IO &io, LinalgOperandDefUsage &value) { io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input); io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output); @@ -202,8 +198,7 @@ }; /// Iterator type enum. -template <> -struct ScalarEnumerationTraits { +template <> struct ScalarEnumerationTraits { static void enumeration(IO &io, LinalgIteratorTypeDef &value) { io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel); io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction); @@ -211,8 +206,7 @@ }; /// Metadata about the op (name, C++ name, and documentation). -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, LinalgOpMetadata &info) { io.mapRequired("name", info.name); io.mapRequired("cpp_class_name", info.cppClassName); @@ -226,8 +220,7 @@ /// some symbols that bind to attributes of the op. Each indexing map must /// be normalized over the same list of dimensions, and its symbols must /// match the symbols for argument shapes. -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, LinalgIndexingMapsConfig &info) { io.mapOptional("static_indexing_maps", info.staticIndexingMaps); } @@ -237,8 +230,7 @@ /// - The `arg` name must match a named output. /// - The `value` is a scalar expression for computing the value to /// assign (see `ScalarExpression`). -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, ScalarAssign &info) { io.mapRequired("arg", info.arg); io.mapRequired("value", info.value); @@ -250,8 +242,7 @@ /// - `scalar_apply`: Result of evaluating a named function (see /// `ScalarApply`). /// - `symbolic_cast`: Cast to a symbolic TypeVar bound elsewhere. -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, ScalarExpression &info) { io.mapOptional("scalar_arg", info.arg); io.mapOptional("scalar_const", info.constant); @@ -266,16 +257,14 @@ /// functions include: /// - `add(lhs, rhs)` /// - `mul(lhs, rhs)` -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, ScalarApply &info) { io.mapRequired("fn_name", info.fnName); io.mapRequired("operands", info.operands); } }; -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, ScalarSymbolicCast &info) { io.mapRequired("type_var", info.typeVar); io.mapRequired("operands", info.operands); @@ -285,8 +274,7 @@ /// Helper mapping which accesses an AffineMapAttr as a serialized string of /// the same. -template <> -struct ScalarTraits { +template <> struct ScalarTraits { static void output(const SerializedAffineMap &value, void *rawYamlContext, raw_ostream &out) { assert(value.affineMapAttr); @@ -726,7 +714,7 @@ // {1}: Symbol position // {2}: Attribute index static const char structuredOpAccessAttrFormat[] = R"FMT( -int64_t cst{1} = self.{0}().getValue({ {2} }); +int64_t cst{1} = self.{0}().getValues()[{2}]; exprs.push_back(getAffineConstantExpr(cst{1}, context)); )FMT"; // Update all symbol bindings mapped to an attribute. diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp --- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp +++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp @@ -113,7 +113,8 @@ EXPECT_TRUE(returnedValue.isa()); // Check Elements attribute element value is expected. - auto firstValue = returnedValue.cast().getValue({0, 0}); + auto firstValue = + returnedValue.cast().getValues()[{0, 0}]; EXPECT_EQ(firstValue.cast().getInt(), 5); } @@ -138,7 +139,8 @@ EXPECT_TRUE(returnedValue.isa()); // Check Elements attribute element value is expected. - auto firstValue = returnedValue.cast().getValue({0, 0}); + auto firstValue = + returnedValue.cast().getValues()[{0, 0}]; EXPECT_EQ(firstValue.cast().getInt(), 5); } @@ -162,7 +164,8 @@ EXPECT_TRUE(returnedValue.isa()); // Check Elements attribute element value is expected. - auto firstValue = returnedValue.cast().getValue({0, 0}); + auto firstValue = + returnedValue.cast().getValues()[{0, 0}]; EXPECT_EQ(firstValue.cast().getInt(), 5); } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -202,7 +202,7 @@ RankedTensorType shape = RankedTensorType::get({}, intTy); auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})); - EXPECT_TRUE(attr.getValue({0}) == value); + EXPECT_TRUE(attr.getValues()[0] == value); } TEST(SparseElementsAttrTest, GetZero) { @@ -238,15 +238,15 @@ // Only index (0, 0) contains an element, others are supposed to return // the zero/empty value. - auto zeroIntValue = sparseInt.getValue({1, 1}); + auto zeroIntValue = sparseInt.getValues()[{1, 1}]; EXPECT_EQ(zeroIntValue.cast().getInt(), 0); EXPECT_TRUE(zeroIntValue.getType() == intTy); - auto zeroFloatValue = sparseFloat.getValue({1, 1}); + auto zeroFloatValue = sparseFloat.getValues()[{1, 1}]; EXPECT_EQ(zeroFloatValue.cast().getValueAsDouble(), 0.0f); EXPECT_TRUE(zeroFloatValue.getType() == floatTy); - auto zeroStringValue = sparseString.getValue({1, 1}); + auto zeroStringValue = sparseString.getValues()[{1, 1}]; EXPECT_TRUE(zeroStringValue.cast().getValue().empty()); EXPECT_TRUE(zeroStringValue.getType() == stringTy); }