diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -138,6 +138,13 @@
   /// there are zero uses of 'this'.
   void replaceAllUsesWith(Value newValue) const;
 
+  /// Replace all uses of 'this' value with 'newValue', updating anything in the
+  /// IR that uses 'this' to use the other value instead except if the user is
+  /// listed in 'exceptions'.
+  void
+  replaceAllUsesExcept(Value newValue,
+                       const SmallPtrSetImpl<Operation *> &exceptions) const;
+
   //===--------------------------------------------------------------------===//
   // Uses
 
diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -293,11 +293,6 @@
 separateFullTiles(MutableArrayRef<AffineForOp> nest,
                   SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
 
-/// Replaces all uses of `orig` with `replacement` except if the user is listed
-/// in `exceptions`.
-void replaceAllUsesExcept(Value orig, Value replacement,
-                          const SmallPtrSetImpl<Operation *> &exceptions);
-
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOP_UTILS_H
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -24,7 +24,6 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/FoldUtils.h"
-#include "mlir/Transforms/LoopUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -110,11 +109,10 @@
     b.setInsertionPointToStart(&block);
     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
       Value oldIndex = block.getArgument(i);
-      Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
-                                        loopRanges[i].offset);
-      replaceAllUsesExcept(
-          oldIndex, newIndex,
-          SmallPtrSet<Operation *, 1>{newIndex.getDefiningOp()});
+      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
+                                         loopRanges[i].offset);
+      oldIndex.replaceAllUsesExcept(newIndex,
+                                    SmallPtrSet<Operation *, 1>{newIndex});
     }
   }
   return clonedOp;
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/SmallPtrSet.h"
 using namespace mlir;
 
 /// Construct a value.
@@ -121,6 +122,17 @@
   useList->replaceAllUsesWith(*this, newValue);
 }
 
+/// Replace all uses of 'this' value with the new value, updating anything in
+/// the IR that uses 'this' to use the other value instead except if the user is
+/// listed in 'exceptions'.
+void Value::replaceAllUsesExcept(
+    Value newValue, const SmallPtrSetImpl<Operation *> &exceptions) const {
+  for (auto &use : llvm::make_early_inc_range(getUses())) {
+    if (exceptions.count(use.getOwner()) == 0)
+      use.set(newValue);
+  }
+}
+
 //===--------------------------------------------------------------------===//
 // Uses
 
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1212,7 +1212,7 @@
 
   SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
                                        shifted.getDefiningOp()};
-  replaceAllUsesExcept(inductionVar, shifted, preserve);
+  inductionVar.replaceAllUsesExcept(shifted, preserve);
   return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
           /*step=*/newStep};
 }
@@ -2379,12 +2379,3 @@
   return success();
 }
 
-
-void mlir::replaceAllUsesExcept(
-    Value orig, Value replacement,
-    const SmallPtrSetImpl<Operation *> &exceptions) {
-  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
-    if (exceptions.count(use.getOwner()) == 0)
-      use.set(replacement);
-  }
-}