diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -751,6 +751,15 @@
         padOp.getLoc(), vecType, padOp.source(), readIndices, padValue,
         readInBounds);
 
+    // If `dest` is defined by a FillOp and the TransferWriteOp would overwrite
+    // the entire tensor (the vector shape equals the result shape and every
+    // write dimension is in-bounds), bypass the fill and write directly to the
+    // FillOp's output operand.
+    if (llvm::equal(vecShape, resultType.getShape())
+        && llvm::all_of(writeInBounds, [](bool b) { return b; }))
+      if (auto fill = dest.getDefiningOp<FillOp>())
+        dest = fill.output();
+
     // Generate TransferWriteOp.
     auto writeIndices = ofrToIndexValues(
         rewriter, padOp.getLoc(), padOp.getMixedLowPad());