diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -496,17 +496,20 @@ return rewriter.notifyMatchFailure( padOp, "failed to reify tensor.pad op result shape"); - auto oldResultType = padOp.getResultType(); - SmallVector staticShape(oldResultType.getRank(), - ShapedType::kDynamicSize); + SmallVector newShape = + getAsOpFoldResult(reifiedShape.front()); auto emptyTensor = rewriter.create( - padOp.getLoc(), staticShape, oldResultType.getElementType(), - reifiedShape.front()); - auto newFillOp = rewriter.create( - fillOp.getLoc(), ValueRange{padValue}, ValueRange{emptyTensor}); - rewriter.replaceOpWithNewOp(padOp, oldResultType, - newFillOp.result()); - + padOp.getLoc(), newShape, padOp.getResultType().getElementType()); + Value replacement = + rewriter + .create(fillOp.getLoc(), ValueRange{padValue}, + ValueRange{emptyTensor}) + .getResult(0); + if (replacement.getType() != padOp.getResultType()) { + replacement = rewriter.create( + fillOp.getLoc(), padOp.getResultType(), replacement); + } + rewriter.replaceOp(padOp, replacement); return success(); } };