diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4201,7 +4201,9 @@ writeOp.getSource().getDefiningOp(); while (defWrite) { if (checkSameValueWAW(writeOp, defWrite)) { - writeToModify.getSourceMutable().assign(defWrite.getSource()); + rewriter.updateRootInPlace(writeToModify, [&]() { + writeToModify.getSourceMutable().assign(defWrite.getSource()); + }); return success(); } if (!isDisjointTransferIndices( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -657,7 +657,8 @@ Operation *newOp = cloneOpWithOperandsAndTypes( rewriter, loc, elementWise, newOperands, {newWarpOp.getResult(operandIndex).getType()}); - newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0)); + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), + newOp->getResult(0)); return success(); } }; @@ -695,7 +696,7 @@ Location loc = warpOp.getLoc(); rewriter.setInsertionPointAfter(warpOp); Value distConstant = rewriter.create(loc, newAttr); - warpOp.getResult(operandIndex).replaceAllUsesWith(distConstant); + rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant); return success(); } }; @@ -759,7 +760,7 @@ read.getLoc(), distributedVal.getType(), read.getSource(), indices, read.getPermutationMapAttr(), read.getPadding(), read.getMask(), read.getInBoundsAttr()); - distributedVal.replaceAllUsesWith(newRead); + rewriter.replaceAllUsesWith(distributedVal, newRead); return success(); } }; @@ -855,7 +856,7 @@ } if (!valForwarded) return failure(); - warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded); + rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded); return success(); } }; @@ -880,7 +881,8 @@ rewriter.setInsertionPointAfter(newWarpOp); Value broadcasted = rewriter.create( loc, destVecType, newWarpOp->getResult(newRetIndices[0])); - newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted); + rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), + broadcasted); return success(); } }; @@ -936,7 +938,8 @@ // Extract from distributed vector. Value newExtract = rewriter.create( loc, distributedVec, extractOp.getPosition()); - newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract); + rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), + newExtract); return success(); } @@ -973,7 +976,8 @@ // Extract from distributed vector. Value newExtract = rewriter.create( loc, distributedVec, extractOp.getPosition()); - newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract); + rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), + newExtract); return success(); } }; @@ -1031,7 +1035,8 @@ newExtract = rewriter.create(loc, distributedVec); } - newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract); + rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), + newExtract); return success(); } @@ -1056,7 +1061,7 @@ // Shuffle the extracted value to all lanes. Value shuffled = warpShuffleFromIdxFn( loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize()); - newWarpOp->getResult(operandNumber).replaceAllUsesWith(shuffled); + rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled); return success(); } @@ -1104,7 +1109,8 @@ // Broadcast: Simply move the vector.inserelement op out. Value newInsert = rewriter.create( loc, newSource, distributedVec, newPos); - newWarpOp->getResult(operandNumber).replaceAllUsesWith(newInsert); + rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), + newInsert); return success(); } @@ -1138,7 +1144,7 @@ builder.create(loc, distributedVec); }) .getResult(0); - newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult); + rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); return success(); } }; @@ -1184,7 +1190,8 @@ Value distributedDest = newWarpOp->getResult(newRetIndices[1]); Value newResult = rewriter.create( loc, distributedSrc, distributedDest, insertOp.getPosition()); - newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult); + rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), + newResult); return success(); } @@ -1263,7 +1270,7 @@ .getResult(0); } - newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult); + rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); return success(); } }; @@ -1400,8 +1407,8 @@ rewriter.eraseOp(forOp); // Replace the warpOp result coming from the original ForOp. for (const auto &res : llvm::enumerate(resultIdx)) { - newWarpOp.getResult(res.value()) - .replaceAllUsesWith(newForOp.getResult(res.index())); + rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()), + newForOp.getResult(res.index())); newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value())); } newForOp.walk([&](Operation *op) { @@ -1494,7 +1501,7 @@ rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce, newWarpOp.getResult(newRetIndices[1])); } - newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce); + rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce); return success(); }