diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1634,20 +1634,13 @@ resultTypes[resultNumber] = resultType; Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands); - if (!resultValue.hasOneUse()) { - SmallVector results(newOp->result_begin(), newOp->result_end()); - // Create a tensor.cast operation back to the original type. - Value castBack = rewriter.create( - loc, resultValue.getType(), newOp->getResult(resultNumber)); - results[resultNumber] = castBack; - // Replace all uses except the use in the cast op that is matched by the - // pattern. Note that this cast is from a more static shape to a more - // dynamic shape. These are expected to be pulled into their consumers. - rewriter.replaceOpWithIf(linalgOp, results, - [&castOp](OpOperand &use) -> bool { - return use.getOwner() != castOp.getOperation(); - }); - } + // Create a tensor.cast operation back to the original type. + Value castBack = rewriter.create( + loc, resultValue.getType(), newOp->getResult(resultNumber)); + + SmallVector results(newOp->result_begin(), newOp->result_end()); + results[resultNumber] = castBack; + rewriter.replaceOp(linalgOp, results); rewriter.replaceOp(castOp, newOp->getResult(resultNumber)); return success(); }