diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -196,6 +196,8 @@ // IMPL: auto map0 = AffineMap::get(2, 2, {d0, d1}, context); // IMPL: auto map1 = AffineMap::get(2, 2, {d1}, context); // IMPL: auto map2 = AffineMap::get(2, 2, {d0}, context); +// IMPL-LABEL: void Test9Op::regionBuilder(Block &block, ValueRange captures) { +// IMPL: Value [[a:.*]](args[0]), [[c:.*]](args[2]); ods_def: def test9(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -35,6 +35,7 @@ #include "llvm/Support/ToolOutputFile.h" #include +#include #define DEBUG_TYPE "linalg-ods-gen" @@ -2342,14 +2343,14 @@ (linalg_yield(ValueRange{ {3} })); })FMT"; - unsigned idx = 0; std::string valueHandleStr; llvm::raw_string_ostream valueHandleStringStream(valueHandleStr); - llvm::interleaveComma( - llvm::seq(0, state.numArgs), valueHandleStringStream, [&](auto) { - valueHandleStringStream << "_" << idx << "(args[" << idx << "])"; - idx++; - }); + std::set usedTensorId; + for (const auto &iter : state.orderedTensorArgs) + usedTensorId.insert(iter.second); + llvm::interleaveComma(usedTensorId, valueHandleStringStream, [&](auto idx) { + valueHandleStringStream << "_" << idx << "(args[" << idx << "])"; + }); std::string expressionsStr; llvm::raw_string_ostream expressionStringStream(expressionsStr);