Add a single-operation overload of Value::replaceAllUsesExcept.

Declares (Value.h) and defines (Value.cpp)
`void Value::replaceAllUsesExcept(Value newValue, Operation *exceptedUser) const`,
which walks the value's uses and redirects each one to `newValue` unless the
use's owner is `exceptedUser`.

Call sites that built a one-element SmallPtrSet only to exclude the op they
had just created now pass that op directly. These are in
AffineLoopNormalize.cpp (two sites), Linalg Fusion.cpp, Linalg Tiling.cpp,
and SCF ParallelLoopTiling.cpp.

The existing set-based overload in Value.cpp also changes its loop variable
type from `auto &` to `OpOperand &`, matching the new overload.

NOTE(review): this patch text has lost its line breaks. It also appears to
have lost its template argument lists (e.g. `SmallPtrSetImpl &exceptions`,
`builder.create(`, `SmallPtrSet{apply}`). It will therefore not apply or
compile as-is; regenerate it from the original commit before use.

NOTE(review): the Tiling.cpp hunk now passes `applyOp` rather than
`applyOp.getResult()` as the replacement value, matching the Fusion.cpp call
site. This relies on the op converting implicitly to Value, presumably
through its single-result interface; confirm.

diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -166,6 +166,11 @@ replaceAllUsesExcept(Value newValue, const SmallPtrSetImpl &exceptions) const; + /// Replace all uses of 'this' value with 'newValue', updating anything in the + /// IR that uses 'this' to use the other value instead except if the user is + /// 'exceptedUser'. + void replaceAllUsesExcept(Value newValue, Operation *exceptedUser) const; + /// Replace all uses of 'this' value with 'newValue' if the given callback /// returns true. void replaceUsesWithIf(Value newValue, diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp @@ -72,7 +72,7 @@ applyOperands.push_back(iv); applyOperands.append(symbolOperands.begin(), symbolOperands.end()); auto apply = builder.create(op.getLoc(), map, applyOperands); - iv.replaceAllUsesExcept(apply, SmallPtrSet{apply}); + iv.replaceAllUsesExcept(apply, apply); } SmallVector newSteps(op.getNumDims(), 1); @@ -181,8 +181,7 @@ AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1, origLbMap.getNumSymbols(), newIVExpr); Operation *newIV = opBuilder.create(loc, ivMap, lbOperands); - op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), - SmallPtrSet{newIV}); + op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -191,8 +191,7 @@ AffineApplyOp applyOp = builder.create( indexOp.getLoc(), index + offset, ValueRange{indexOp.getResult(), loopRanges[indexOp.dim()].offset}); - indexOp.getResult().replaceAllUsesExcept( - applyOp, 
SmallPtrSet{applyOp}); + indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -155,8 +155,7 @@ AffineApplyOp applyOp = b.create( indexOp.getLoc(), index + iv, ValueRange{indexOp.getResult(), ivs[rangeIndex->second]}); - indexOp.getResult().replaceAllUsesExcept( - applyOp.getResult(), SmallPtrSet{applyOp}); + indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp); } } diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp @@ -121,8 +121,7 @@ Value inner_index = std::get<0>(ivs); AddIOp newIndex = b.create(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs)); - inner_index.replaceAllUsesExcept( - newIndex, SmallPtrSet{newIndex.getOperation()}); + inner_index.replaceAllUsesExcept(newIndex, newIndex); } op.erase(); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -63,12 +63,23 @@ /// listed in 'exceptions' . void Value::replaceAllUsesExcept( Value newValue, const SmallPtrSetImpl &exceptions) const { - for (auto &use : llvm::make_early_inc_range(getUses())) { + for (OpOperand &use : llvm::make_early_inc_range(getUses())) { if (exceptions.count(use.getOwner()) == 0) use.set(newValue); } } +/// Replace all uses of 'this' value with 'newValue', updating anything in the +/// IR that uses 'this' to use the other value instead except if the user is +/// 'exceptedUser'. 
+void Value::replaceAllUsesExcept(Value newValue, + Operation *exceptedUser) const { + for (OpOperand &use : llvm::make_early_inc_range(getUses())) { + if (use.getOwner() != exceptedUser) + use.set(newValue); + } +} + /// Replace all uses of 'this' value with 'newValue' if the given callback /// returns true. void Value::replaceUsesWithIf(Value newValue,