diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -385,8 +385,19 @@
         makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
     resultExprs.push_back(linearizedExpr);
   }
-  return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
-                        resultExprs, context);
+  // Number of symbols may change as compared to what is present in the
+  // sourceMap as stride length may be constant or symbolic. Symbol positions
+  // are 0-based, so the map needs one more symbol than the highest position
+  // referenced by any linearized result expression (none if unreferenced).
+  unsigned numSymbols = 0;
+  for (AffineExpr expr : resultExprs) {
+    expr.walk([&](AffineExpr subExpr) {
+      if (auto symbolExpr = subExpr.dyn_cast<AffineSymbolExpr>())
+        numSymbols = std::max(numSymbols, symbolExpr.getPosition() + 1);
+    });
+  }
+  return AffineMap::get(sourceMap.getNumDims(), numSymbols, resultExprs,
+                        context);
 }
 
 // TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a