diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1214,13 +1214,17 @@ // values at the constant index dropped. SmallVector fusedIndexMaps; SmallVector fusedOperands; + SmallVector fusedLocs{genericOp.getLoc()}; fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs()); fusedOperands.reserve(genericOp.getNumInputs()); + fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs()); for (OpOperand *inputOperand : genericOp.getInputOperands()) { if (inputOperand == opOperand) continue; + Value inputValue = inputOperand->get(); fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand)); - fusedOperands.push_back(inputOperand->get()); + fusedOperands.push_back(inputValue); + fusedLocs.push_back(inputValue.getLoc()); } for (OpOperand *outputOperand : genericOp.getOutputOperands()) fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand)); @@ -1237,7 +1241,7 @@ SmallVector outputOperands = genericOp.getOutputOperands(); auto fusedOp = rewriter.create( - rewriter.getUnknownLoc(), genericOp->getResultTypes(), + rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(), /*inputs=*/fusedOperands, /*outputs=*/outputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),