diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1318,17 +1318,17 @@ while (counter < indexingMap.getNumResults()) { unsigned dim = indexingMap.getResult(counter).cast().getPosition(); + // This is the start of a collapsed dimensions of the iteration that + // is gauranteed to be preserved in the indexing map. The number of folded + // dims is obtained from the collapsed op to original op mapping. + unsigned numFoldedDims = + collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first] + .size(); if (origOpToCollapsedOpMapping[dim].second == 0) { - // This is the start of a collapsed dimensions of the iteration that - // is gauranteed to be preserved in the indexing map. The number of folded - // dims is obtained from the collapsed op to original op mapping. - unsigned numFoldedDims = - collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first] - .size(); auto range = llvm::seq(counter, counter + numFoldedDims); operandReassociation.emplace_back(range.begin(), range.end()); - counter += numFoldedDims; } + counter += numFoldedDims; } return operandReassociation; }