diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -385,8 +385,27 @@ makeCanonicalStridedLayoutExpr(sizes, dimExprs, context); resultExprs.push_back(linearizedExpr); } - return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(), - resultExprs, context); + // Number of symbols may change as compared to what is present in the + // sourceMap as stride length may be constant or symbolic. + unsigned numSymbols = 0; + std::function updateMaxSymbolPosition = + [&](AffineExpr expr) { + auto binOpExpr = expr.dyn_cast(); + if (binOpExpr) { + updateMaxSymbolPosition(binOpExpr.getLHS()); + updateMaxSymbolPosition(binOpExpr.getRHS()); + return; + } + if (auto symbolExpr = expr.dyn_cast()) { + numSymbols = std::max(numSymbols, 1 + symbolExpr.getPosition()); + } + }; + + for (AffineExpr expr : resultExprs) + updateMaxSymbolPosition(expr); + + return AffineMap::get(sourceMap.getNumDims(), numSymbols, resultExprs, + context); } // TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a