diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -444,7 +444,7 @@ auto denseTp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); auto convert = rewriter.create(loc, denseTp, op.getSrc()); - op->setOperand(0, convert); + rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); }); return success(); } if (encDst) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -546,7 +546,7 @@ forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName())); rewriter.setInsertionPointToStart(forOpNew.getBody()); } else { - forOp.setStep(step); + rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); }); rewriter.setInsertionPoint(yield); } vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(), @@ -575,10 +575,11 @@ // Now do some relinking (last one is not completely type safe // but all bad ones are removed right away). This also folds away // nop broadcast operations. - forOp.getResult(0).replaceAllUsesWith(vres); - forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar()); - forOp.getRegionIterArg(0).replaceAllUsesWith( - forOpNew.getRegionIterArg(0)); + rewriter.replaceAllUsesWith(forOp.getResult(0), vres); + rewriter.replaceAllUsesWith(forOp.getInductionVar(), + forOpNew.getInductionVar()); + rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0), + forOpNew.getRegionIterArg(0)); rewriter.eraseOp(forOp); } return true; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -838,9 +838,12 @@ if (auto indexOp = dyn_cast(def)) return genIndexValue(env, indexOp.getDim()); if (def->getBlock() == block) { - for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) - def->setOperand( - i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx)); + for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) { + rewriter.updateRootInPlace(def, [&]() { + def->setOperand( + i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx)); + }); + } } } return e; @@ -1615,7 +1618,8 @@ auto dstTp = RankedTensorType::get(srcTp.getShape(), srcTp.getElementType(), dstEnc); auto convert = rewriter.create(tval.getLoc(), dstTp, tval); - env.op()->setOperand(tensor, convert); + rewriter.updateRootInPlace( + env.op(), [&]() { env.op()->setOperand(tensor, convert); }); rewriter.setInsertionPointAfter(env.op()); rewriter.create(tval.getLoc(), convert); return success();