diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -755,26 +755,25 @@ using type = std::tuple())...>; }; -template +template using zip_traits = iterator_facade_base< ZipType, std::common_type_t< std::bidirectional_iterator_tag, typename std::iterator_traits::iterator_category...>, // ^ TODO: Implement random access methods. - typename ZipTupleType::type, + ReferenceTupleType, typename std::iterator_traits< std::tuple_element_t<0, std::tuple>>::difference_type, // ^ FIXME: This follows boost::make_zip_iterator's assumption that all // inner iterators have the same difference_type. It would fail if, for // instance, the second field's difference_type were non-numeric while the // first is. - typename ZipTupleType::type *, - typename ZipTupleType::type>; + ReferenceTupleType *, ReferenceTupleType>; -template -struct zip_common : public zip_traits { - using Base = zip_traits; +template +struct zip_common : public zip_traits { + using Base = zip_traits; using IndexSequence = std::index_sequence_for; using value_type = typename Base::value_type; @@ -824,8 +823,10 @@ }; template -struct zip_first : zip_common, Iters...> { - using zip_common::zip_common; +struct zip_first : zip_common, + typename ZipTupleType::type, Iters...> { + using zip_common::type, + Iters...>::zip_common; bool operator==(const zip_first &other) const { return std::get<0>(this->iterators) == std::get<0>(other.iterators); @@ -833,8 +834,11 @@ }; template -struct zip_shortest : zip_common, Iters...> { - using zip_common::zip_common; +struct zip_shortest + : zip_common, typename ZipTupleType::type, + Iters...> { + using zip_common::type, + Iters...>::zip_common; bool operator==(const zip_shortest &other) const { return any_iterator_equals(other, std::index_sequence_for{}); @@ -2213,113 +2217,182 @@ namespace detail { -template class enumerator_iter; +/// Tuple-like type for `zip_enumerator` dereference. +template struct enumerator_result; -template struct result_pair { - using value_reference = - typename std::iterator_traits>::reference; - - friend class enumerator_iter; - - result_pair(std::size_t Index, IterOfRange Iter) - : Index(Index), Iter(Iter) {} - - std::size_t index() const { return Index; } - value_reference value() const { return *Iter; } - -private: - std::size_t Index = std::numeric_limits::max(); - IterOfRange Iter; -}; - -template -decltype(auto) get(const result_pair &Pair) { - static_assert(i < 2); - if constexpr (i == 0) { - return Pair.index(); - } else { - return Pair.value(); - } -} - -template -class enumerator_iter - : public iterator_facade_base, std::forward_iterator_tag, - const result_pair> { - using result_type = result_pair; - -public: - explicit enumerator_iter(IterOfRange EndIter) - : Result(std::numeric_limits::max(), EndIter) {} - - enumerator_iter(std::size_t Index, IterOfRange Iter) - : Result(Index, Iter) {} - - const result_type &operator*() const { return Result; } +template +using EnumeratorTupleType = enumerator_result())...>; + +/// Zippy iterator that uses the second iterator for comparisons. For the +/// increment to be safe, the second range has to be the shortest. +/// Returns `enumerator_result` on dereference to provide `.index()` and +/// `.value()` member functions. +/// Note: Because the dereference operator returns `enumerator_result` as a +/// value instead of a reference and does not strictly conform to the C++17's +/// definition of forward iterator. However, it satisfies all the +/// forward_iterator requirements that the `zip_common` and `zippy` depend on +/// and fully conforms to the C++20 definition of forward iterator. +/// This is similar to `std::vector::iterator` that returns bit reference +/// wrappers on dereference. +template +struct zip_enumerator : zip_common, + EnumeratorTupleType, Iters...> { + static_assert(sizeof...(Iters) >= 2, "Expected at least two iteratees"); + using zip_common, EnumeratorTupleType, + Iters...>::zip_common; - enumerator_iter &operator++() { - assert(Result.Index != std::numeric_limits::max()); - ++Result.Iter; - ++Result.Index; - return *this; + bool operator==(const zip_enumerator &Other) const { + return std::get<1>(this->iterators) == std::get<1>(Other.iterators); } +}; - bool operator==(const enumerator_iter &RHS) const { - // Don't compare indices here, only iterators. It's possible for an end - // iterator to have different indices depending on whether it was created - // by calling std::end() versus incrementing a valid iterator. - return Result.Iter == RHS.Result.Iter; +template struct enumerator_result { + static constexpr std::size_t NumRefs = sizeof...(Refs); + static_assert(NumRefs != 0); + // `NumValues` includes the index. + static constexpr std::size_t NumValues = NumRefs + 1; + + // Tuple type whose element types are references for each `Ref`. + using range_reference_tuple = std::tuple; + // Tuple type who elements are references to all values, including both + // the index and `Refs` reference types. + using value_reference_tuple = std::tuple; + + enumerator_result(std::size_t Index, Refs &&...Rs) + : Idx(Index), Storage(std::forward(Rs)...) {} + + /// Returns the 0-based index of the current position within the original + /// input range(s). + std::size_t index() const { return Idx; } + + /// Returns the value(s) for the current iterator. This does not include the + /// index. + decltype(auto) value() const { + if constexpr (NumRefs == 1) + return std::get<0>(Storage); + else + return Storage; + } + + /// Returns the value at index `I`. This includes the index. + template + friend decltype(auto) get(const enumerator_result &Result) { + static_assert(I < NumValues, "Index out of bounds"); + if constexpr (I == 0) + return Result.Idx; + else + return std::get(Result.Storage); + } + + template + friend bool operator==(const enumerator_result &Result, + const std::tuple &Other) { + static_assert(NumRefs == sizeof...(Ts), "Size mismatch"); + if (Result.Idx != std::get<0>(Other)) + return false; + return Result.is_value_equal(Other, std::make_index_sequence{}); } private: - result_type Result; + template + bool is_value_equal(const Tuple &Other, std::index_sequence) const { + return ((std::get(Storage) == std::get(Other)) && ...); + } + + std::size_t Idx; + // Make this tuple mutable to avoid casts that obfuscate const-correctness + // issues. Const-correctness of references is taken care of by `zippy` that + // defines const-non and const iterator types that will propagate down to + // `enumerator_result`'s `Refs`. + // Note that unlike the results of `zip*` functions, `enumerate`'s result are + // supposed to be modifiable even when defined as + // `const`. + mutable range_reference_tuple Storage; }; -template class enumerator { -public: - explicit enumerator(R &&Range) : TheRange(std::forward(Range)) {} +/// Infinite stream of increasing 0-based `size_t` indices. +struct index_stream { + struct iterator : iterator_facade_base { + iterator &operator++() { + assert(Index != std::numeric_limits::max() && + "Attempting to increment end iterator"); + ++Index; + return *this; + } - enumerator_iter begin() { - return enumerator_iter(0, adl_begin(TheRange)); - } - enumerator_iter begin() const { - return enumerator_iter(0, adl_begin(TheRange)); - } + // Note: This dereference operator returns a value instead of a reference + // and does not strictly conform to the C++17's definition of forward + // iterator. However, it satisfies all the forward_iterator requirements + // that the `zip_common` depends on and fully conforms to the C++20 + // definition of forward iterator. + std::size_t operator*() const { return Index; } - enumerator_iter end() { return enumerator_iter(adl_end(TheRange)); } - enumerator_iter end() const { - return enumerator_iter(adl_end(TheRange)); - } + friend bool operator==(const iterator &Lhs, const iterator &Rhs) { + return Lhs.Index == Rhs.Index; + } -private: - R TheRange; + std::size_t Index = 0; + }; + + iterator begin() const { return {}; } + iterator end() const { + // We approximate 'infinity' with the max size_t value, which should be good + // enough to index over any container. + iterator It; + It.Index = std::numeric_limits::max(); + return It; + } }; } // end namespace detail -/// Given an input range, returns a new range whose values are are pair (A,B) -/// such that A is the 0-based index of the item in the sequence, and B is -/// the value from the original sequence. Example: +/// Given two or more input ranges, returns a new range whose values are are +/// tuples (A, B, C, ...), such that A is the 0-based index of the item in the +/// sequence, and B, C, ..., are the values from the original input ranges. All +/// input ranges are required to have equal lengths. Note that the returned +/// iterator allows for the values (B, C, ...) to be modified. Example: +/// +/// ```c++ +/// std::vector Letters = {'A', 'B', 'C', 'D'}; +/// std::vector Vals = {10, 11, 12, 13}; /// -/// std::vector Items = {'A', 'B', 'C', 'D'}; -/// for (auto X : enumerate(Items)) { -/// printf("Item %zu - %c\n", X.index(), X.value()); +/// for (auto [Index, Letter, Value] : enumerate(Letters, Vals)) { +/// printf("Item %zu - %c: %d\n", Index, Letter, Value); +/// Value -= 10; /// } +/// ``` /// -/// or using structured bindings: +/// Output: +/// Item 0 - A: 10 +/// Item 1 - B: 11 +/// Item 2 - C: 12 +/// Item 3 - D: 13 /// -/// for (auto [Index, Value] : enumerate(Items)) { -/// printf("Item %zu - %c\n", Index, Value); +/// or using an iterator: +/// ```c++ +/// for (auto it : enumerate(Vals)) { +/// it.value() += 10; +/// printf("Item %zu: %d\n", it.index(), it.value()); /// } +/// ``` /// /// Output: -/// Item 0 - A -/// Item 1 - B -/// Item 2 - C -/// Item 3 - D +/// Item 0: 20 +/// Item 1: 21 +/// Item 2: 22 +/// Item 3: 23 /// -template detail::enumerator enumerate(R &&TheRange) { - return detail::enumerator(std::forward(TheRange)); +template +auto enumerate(FirstRange &&First, RestRanges &&...Rest) { + assert((sizeof...(Rest) == 0 || + all_equal({std::distance(adl_begin(First), adl_end(First)), + std::distance(adl_begin(Rest), adl_end(Rest))...})) && + "Ranges have different length"); + using enumerator = detail::zippy; + return enumerator(detail::index_stream{}, std::forward(First), + std::forward(Rest)...); } namespace detail { @@ -2451,15 +2524,17 @@ } // end namespace llvm namespace std { -template -struct tuple_size> - : std::integral_constant {}; +template +struct tuple_size> + : std::integral_constant {}; -template -struct tuple_element> - : std::conditional::value_reference> { -}; +template +struct tuple_element> + : std::tuple_element> {}; + +template +struct tuple_element> + : std::tuple_element> {}; } // namespace std diff --git a/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h b/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h --- a/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h +++ b/llvm/lib/Target/AArch64/AArch64PerfectShuffle.h @@ -6590,11 +6590,11 @@ assert(M.size() == 4 && "Expected a 4 entry perfect shuffle"); // Special case zero-cost nop copies, from either LHS or RHS. - if (llvm::all_of(llvm::enumerate(M), [](auto &E) { + if (llvm::all_of(llvm::enumerate(M), [](const auto &E) { return E.value() < 0 || E.value() == (int)E.index(); })) return 0; - if (llvm::all_of(llvm::enumerate(M), [](auto &E) { + if (llvm::all_of(llvm::enumerate(M), [](const auto &E) { return E.value() < 0 || E.value() == (int)E.index() + 4; })) return 0; diff --git a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp --- a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp +++ b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp @@ -1249,7 +1249,7 @@ const MCInstrDesc &MCID = MI->getDesc(); bool IsUse = false; unsigned LastOpIdx = MI->getNumOperands() - 1; - for (auto &Op : enumerate(reverse(MCID.operands()))) { + for (const auto &Op : enumerate(reverse(MCID.operands()))) { const MachineOperand &MO = MI->getOperand(LastOpIdx - Op.index()); if (!MO.isReg() || !MO.isUse() || MO.getReg() != ARM::VPR) continue; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -1651,11 +1651,11 @@ StringRef &ErrInfo) const { MCInstrDesc const &Desc = MI.getDesc(); - for (auto &OI : enumerate(Desc.operands())) { - unsigned OpType = OI.value().OperandType; + for (const auto &[Index, Operand] : enumerate(Desc.operands())) { + unsigned OpType = Operand.OperandType; if (OpType >= RISCVOp::OPERAND_FIRST_RISCV_IMM && OpType <= RISCVOp::OPERAND_LAST_RISCV_IMM) { - const MachineOperand &MO = MI.getOperand(OI.index()); + const MachineOperand &MO = MI.getOperand(Index); if (MO.isImm()) { int64_t Imm = MO.getImm(); bool Ok; diff --git a/llvm/tools/llvm-mca/Views/InstructionInfoView.cpp b/llvm/tools/llvm-mca/Views/InstructionInfoView.cpp --- a/llvm/tools/llvm-mca/Views/InstructionInfoView.cpp +++ b/llvm/tools/llvm-mca/Views/InstructionInfoView.cpp @@ -55,10 +55,7 @@ } } - int Index = 0; - for (const auto &I : enumerate(zip(IIVD, Source))) { - const InstructionInfoViewData &IIVDEntry = std::get<0>(I.value()); - + for (const auto &[Index, IIVDEntry, Inst] : enumerate(IIVD, Source)) { TempStream << ' ' << IIVDEntry.NumMicroOpcodes << " "; if (IIVDEntry.NumMicroOpcodes < 10) TempStream << " "; @@ -92,7 +89,7 @@ } if (PrintEncodings) { - StringRef Encoding(CE.getEncoding(I.index())); + StringRef Encoding(CE.getEncoding(Index)); unsigned EncodingSize = Encoding.size(); TempStream << " " << EncodingSize << (EncodingSize < 10 ? " " : " "); @@ -104,9 +101,7 @@ FOS.flush(); } - const MCInst &Inst = std::get<1>(I.value()); TempStream << printInstructionString(Inst) << '\n'; - ++Index; } TempStream.flush(); diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp --- a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -13,8 +13,11 @@ #include #include +#include +#include #include #include +#include #include #include @@ -153,6 +156,131 @@ PairType(2u, '4'))); } +TEST(STLExtrasTest, EnumerateTwoRanges) { + using Tuple = std::tuple; + + std::vector Ints = {1, 2}; + std::vector Bools = {true, false}; + EXPECT_THAT(llvm::enumerate(Ints, Bools), + ElementsAre(Tuple(0, 1, true), Tuple(1, 2, false))); + + // Check that we can modify the values when the temporary is a const + // reference. + for (const auto &[Idx, Int, Bool] : llvm::enumerate(Ints, Bools)) { + (void)Idx; + Bool = false; + Int = -1; + } + + EXPECT_THAT(Ints, ElementsAre(-1, -1)); + EXPECT_THAT(Bools, ElementsAre(false, false)); + + // Check that we can modify the values when the result gets copied. + for (auto [Idx, Bool, Int] : llvm::enumerate(Bools, Ints)) { + (void)Idx; + Int = 3; + Bool = true; + } + + EXPECT_THAT(Ints, ElementsAre(3, 3)); + EXPECT_THAT(Bools, ElementsAre(true, true)); + + // Check that we can modify the values through `.value()`. + size_t Iters = 0; + for (auto It : llvm::enumerate(Bools, Ints)) { + EXPECT_EQ(It.index(), Iters); + ++Iters; + + std::get<0>(It.value()) = false; + std::get<1>(It.value()) = 4; + } + + EXPECT_THAT(Ints, ElementsAre(4, 4)); + EXPECT_THAT(Bools, ElementsAre(false, false)); +} + +TEST(STLExtrasTest, EnumerateThreeRanges) { + using Tuple = std::tuple; + + std::vector Ints = {1, 2}; + std::vector Bools = {true, false}; + char Chars[] = {'X', 'D'}; + EXPECT_THAT(llvm::enumerate(Ints, Bools, Chars), + ElementsAre(Tuple(0, 1, true, 'X'), Tuple(1, 2, false, 'D'))); + + for (auto [Idx, Int, Bool, Char] : llvm::enumerate(Ints, Bools, Chars)) { + (void)Idx; + Int = 0; + Bool = true; + Char = '!'; + } + + EXPECT_THAT(Ints, ElementsAre(0, 0)); + EXPECT_THAT(Bools, ElementsAre(true, true)); + EXPECT_THAT(Chars, ElementsAre('!', '!')); + + // Check that we can modify the values through `.values()`. + size_t Iters = 0; + for (auto It : llvm::enumerate(Ints, Bools, Chars)) { + EXPECT_EQ(It.index(), Iters); + ++Iters; + auto [Int, Bool, Char] = It.value(); + Int = 42; + Bool = false; + Char = '$'; + } + + EXPECT_THAT(Ints, ElementsAre(42, 42)); + EXPECT_THAT(Bools, ElementsAre(false, false)); + EXPECT_THAT(Chars, ElementsAre('$', '$')); +} + +TEST(STLExtrasTest, EnumerateTemporaries) { + using Tuple = std::tuple; + + EXPECT_THAT( + llvm::enumerate(llvm::SmallVector({1, 2, 3}), + std::vector({true, false, true})), + ElementsAre(Tuple(0, 1, true), Tuple(1, 2, false), Tuple(2, 3, true))); + + size_t Iters = 0; + // This is fine from the point of view of range lifetimes because `zippy` will + // move all temporaries into its storage. No lifetime extension is necessary. + for (auto [Idx, Int, Bool] : + llvm::enumerate(llvm::SmallVector({1, 2, 3}), + std::vector({true, false, true}))) { + EXPECT_EQ(Idx, Iters); + ++Iters; + Int = 0; + Bool = true; + } + + Iters = 0; + // The same thing but with the result as a const reference. + for (const auto &[Idx, Int, Bool] : + llvm::enumerate(llvm::SmallVector({1, 2, 3}), + std::vector({true, false, true}))) { + EXPECT_EQ(Idx, Iters); + ++Iters; + Int = 0; + Bool = true; + } +} + +#if defined(GTEST_HAS_DEATH_TEST) && !defined(NDEBUG) +TEST(STLExtrasTest, EnumerateDifferentLengths) { + std::vector Ints = {0, 1}; + bool Bools[] = {true, false, true}; + std::string Chars = "abc"; + EXPECT_DEATH(llvm::enumerate(Ints, Bools, Chars), + "Ranges have different length"); + EXPECT_DEATH(llvm::enumerate(Bools, Ints, Chars), + "Ranges have different length"); + EXPECT_DEATH(llvm::enumerate(Bools, Chars, Ints), + "Ranges have different length"); +} +#endif + template struct CanMove {}; template <> struct CanMove { CanMove(CanMove &&) = delete; @@ -190,8 +318,8 @@ template struct Range : Counted { using Counted::Counted; - int *begin() { return nullptr; } - int *end() { return nullptr; } + int *begin() const { return nullptr; } + int *end() const { return nullptr; } }; TEST(STLExtrasTest, EnumerateLifetimeSemanticsPRValue) { diff --git a/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp b/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp --- a/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp +++ b/llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp @@ -338,9 +338,9 @@ "Must always partition into at least one partition"); TreeNode->setNumChildren(Partitioner->getNumPartitions()); - for (auto &C : enumerate(TreeNode->children())) { - SubtreeBuilders.emplace_back(&C.value(), NextInstrID); - Partitioner->applyForPartition(C.index(), *this, SubtreeBuilders.back()); + for (const auto &[Idx, Child] : enumerate(TreeNode->children())) { + SubtreeBuilders.emplace_back(&Child, NextInstrID); + Partitioner->applyForPartition(Idx, *this, SubtreeBuilders.back()); } TreeNode->setPartitioner(std::move(Partitioner)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1394,7 +1394,7 @@ // i1 = (i_{folded} / d2) % d1 // i0 = i_{folded} / (d1 * d2) llvm::DenseMap indexReplacementVals; - for (auto &foldedDims : + for (auto foldedDims : enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) { ReassociationIndicesRef foldedDimsRef(foldedDims.value()); Value newIndexVal = diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1871,15 +1871,13 @@ << srcType << " and result memref type " << resultType; // Match sizes in result memref type and in static_sizes attribute. - for (auto &en : - llvm::enumerate(llvm::zip(resultType.getShape(), getStaticSizes()))) { - int64_t resultSize = std::get<0>(en.value()); - int64_t expectedSize = std::get<1>(en.value()); + for (auto [idx, resultSize, expectedSize] : + llvm::enumerate(resultType.getShape(), getStaticSizes())) { if (!ShapedType::isDynamic(resultSize) && !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize) return emitError("expected result type with size = ") << expectedSize << " instead of " << resultSize - << " in dim = " << en.index(); + << " in dim = " << idx; } // Match offset and strides in static_offset and static_strides attributes. If @@ -1900,16 +1898,14 @@ << resultOffset << " instead of " << expectedOffset; // Match strides in result memref type and in static_strides attribute. - for (auto &en : - llvm::enumerate(llvm::zip(resultStrides, getStaticStrides()))) { - int64_t resultStride = std::get<0>(en.value()); - int64_t expectedStride = std::get<1>(en.value()); + for (auto [idx, resultStride, expectedStride] : + llvm::enumerate(resultStrides, getStaticStrides())) { if (!ShapedType::isDynamic(resultStride) && !ShapedType::isDynamic(expectedStride) && resultStride != expectedStride) return emitError("expected result type with stride = ") << expectedStride << " instead of " << resultStride - << " in dim = " << en.index(); + << " in dim = " << idx; } return success(); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1891,10 +1891,7 @@ auto elseYieldArgs = op.elseYield().getOperands(); SmallVector nonHoistable; - for (const auto &it : - llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) { - Value trueVal = std::get<0>(it.value()); - Value falseVal = std::get<1>(it.value()); + for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) { if (&op.getThenRegion() == trueVal.getParentRegion() || &op.getElseRegion() == falseVal.getParentRegion()) nonHoistable.push_back(trueVal.getType()); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -423,7 +423,7 @@ return b.notifyMatchFailure( op, "only support ops with one reduction dimension."); int reductionDim; - for (auto &[idx, iteratorType] : + for (auto [idx, iteratorType] : llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { if (iteratorType == utils::IteratorType::reduction) { reductionDim = idx; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1148,10 +1148,9 @@ desc.setSpecifier(newSpec); // Fills in slice information. - for (const auto &it : llvm::enumerate(llvm::zip( - op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()))) { - Dimension dim = it.index(); - auto [offset, size, stride] = it.value(); + for (auto [idx, offset, size, stride] : llvm::enumerate( + op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) { + Dimension dim = idx; Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset); Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -93,7 +93,7 @@ .cast() .getPosition(); int64_t linearizedStaticDim = 1; - for (auto &d : + for (auto d : llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) { if (d.index() + startPos == static_cast(dimIndex)) continue; diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -762,7 +762,7 @@ return isConstantIntValue(ofr, 0); }); bool sizesMatchDestSizes = llvm::all_of( - llvm::enumerate(insertSliceOp.getMixedSizes()), [&](auto &it) { + llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) { return getConstantIntValue(it.value()) == destType.getDimSize(it.index()); }); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -869,14 +869,14 @@ if (arg.kind != LinalgOperandDefKind::IndexAttr) continue; assert(arg.indexAttrMap); - for (auto &en : + for (auto [idx, result] : llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) { - if (auto symbol = en.value().dyn_cast()) { + if (auto symbol = result.dyn_cast()) { std::string argName = arg.name; argName[0] = toupper(argName[0]); symbolBindings[symbol.getPosition()] = llvm::formatv(structuredOpAccessAttrFormat, argName, - symbol.getPosition(), en.index()); + symbol.getPosition(), idx); } } }