[mlir] Erase ops/blocks in simplifyRegions through a RewriterBase

simplifyRegions() now takes a RewriterBase and performs its operation and
block erasures with RewriterBase::eraseOp / eraseBlock instead of calling
Operation::erase / Block::erase directly. GreedyPatternRewriteDriver passes
itself as the rewriter and no longer clears its folder after
simplifyRegions() succeeds; its operand worklist update now also skips null
operands. The SCF canonicalization test now expects `constant 42 : i32` first.

diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -15,6 +15,7 @@
 #include "llvm/ADT/SetVector.h"
 
 namespace mlir {
+class RewriterBase;
 
 /// Check if all values in the provided range are defined above the `limit`
 /// region. That is, if they are defined in a region that is a proper ancestor
@@ -53,8 +54,10 @@
 /// Run a set of structural simplifications over the given regions. This
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any
-/// of the regions were simplified, failure otherwise.
-LogicalResult simplifyRegions(MutableArrayRef<Region> regions);
+/// of the regions were simplified, failure otherwise. Operations and blocks
+/// are erased through the provided rewriter so that callers can observe them.
+LogicalResult simplifyRegions(RewriterBase &rewriter,
+                              MutableArrayRef<Region> regions);
 
 } // namespace mlir
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -114,7 +114,8 @@
       // TODO: This is based on the fact that zero use operations
       // may be deleted, and that single use values often have more
       // canonicalization opportunities.
-      if (!operand.use_empty() && !operand.hasOneUse())
+      // Operands may be null if their uses were dropped before op erasure.
+      if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
         continue;
       if (auto *defInst = operand.getDefiningOp())
         addToWorklist(defInst);
@@ -202,10 +203,7 @@
 
     // After applying patterns, make sure that the CFG of each of the regions is
     // kept up to date.
-    if (succeeded(simplifyRegions(regions))) {
-      folder.clear();
-      changed = true;
-    }
+    changed |= succeeded(simplifyRegions(*this, regions));
   } while (changed && ++i < maxIterations);
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return !changed;
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -75,7 +76,8 @@
 /// Erase the unreachable blocks within the provided regions. Returns success
 /// if any blocks were erased, failure otherwise.
 // TODO: We could likely merge this with the DCE algorithm below.
-static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) {
+static LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
+                                            MutableArrayRef<Region> regions) {
   // Set of blocks found to be reachable within a given region.
   llvm::df_iterator_default_set<Block *, 16> reachable;
   // If any blocks were found to be dead.
@@ -108,7 +110,7 @@
     for (Block &block : llvm::make_early_inc_range(*region)) {
       if (!reachable.count(&block)) {
         block.dropAllDefinedValueUses();
-        block.erase();
+        rewriter.eraseBlock(&block);
         erasedDeadBlocks = true;
         continue;
       }
@@ -305,7 +307,8 @@
   }
 }
 
-static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
+static LogicalResult deleteDeadness(RewriterBase &rewriter,
+                                    MutableArrayRef<Region> regions,
                                     LiveMap &liveMap) {
   bool erasedAnything = false;
   for (Region &region : regions) {
@@ -324,10 +327,10 @@
         if (!liveMap.wasProvenLive(&childOp)) {
           erasedAnything = true;
           childOp.dropAllUses();
-          childOp.erase();
+          rewriter.eraseOp(&childOp);
         } else {
-          erasedAnything |=
-              succeeded(deleteDeadness(childOp.getRegions(), liveMap));
+          erasedAnything |= succeeded(
+              deleteDeadness(rewriter, childOp.getRegions(), liveMap));
         }
       }
     }
@@ -359,7 +362,8 @@
 //
 // This function returns success if any operations or arguments were deleted,
 // failure otherwise.
-static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
+static LogicalResult runRegionDCE(RewriterBase &rewriter,
+                                  MutableArrayRef<Region> regions) {
   LiveMap liveMap;
   do {
     liveMap.resetChanged();
@@ -368,7 +372,7 @@
       propagateLiveness(region, liveMap);
   } while (liveMap.hasChanged());
 
-  return deleteDeadness(regions, liveMap);
+  return deleteDeadness(rewriter, regions, liveMap);
 }
 
 //===----------------------------------------------------------------------===//
@@ -456,7 +460,7 @@
   LogicalResult addToCluster(BlockEquivalenceData &blockData);
 
   /// Try to merge all of the blocks within this cluster into the leader block.
-  LogicalResult merge();
+  LogicalResult merge(RewriterBase &rewriter);
 
 private:
   /// The equivalence data for the leader of the cluster.
@@ -550,7 +554,7 @@
   return true;
 }
 
-LogicalResult BlockMergeCluster::merge() {
+LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
   // Don't consider clusters that don't have blocks to merge.
   if (blocksToMerge.empty())
     return failure();
@@ -613,7 +617,7 @@
   // Replace all uses of the merged blocks with the leader and erase them.
   for (Block *block : blocksToMerge) {
     block->replaceAllUsesWith(leaderBlock);
-    block->erase();
+    rewriter.eraseBlock(block);
   }
   return success();
 }
@@ -621,7 +625,8 @@
 /// Identify identical blocks within the given region and merge them, inserting
 /// new block arguments as necessary. Returns success if any blocks were merged,
 /// failure otherwise.
-static LogicalResult mergeIdenticalBlocks(Region &region) {
+static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
+                                          Region &region) {
   if (region.empty() || llvm::hasSingleElement(region))
     return failure();
 
@@ -659,7 +664,7 @@
       clusters.emplace_back(std::move(data));
     }
     for (auto &cluster : clusters)
-      mergedAnyBlocks |= succeeded(cluster.merge());
+      mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
   }
 
   return success(mergedAnyBlocks);
@@ -667,14 +672,15 @@
 
 /// Identify identical blocks within the given regions and merge them, inserting
 /// new block arguments as necessary.
-static LogicalResult mergeIdenticalBlocks(MutableArrayRef<Region> regions) {
+static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
+                                          MutableArrayRef<Region> regions) {
   llvm::SmallSetVector<Region *, 1> worklist;
   for (auto &region : regions)
     worklist.insert(&region);
   bool anyChanged = false;
   while (!worklist.empty()) {
     Region *region = worklist.pop_back_val();
-    if (succeeded(mergeIdenticalBlocks(*region))) {
+    if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
       worklist.insert(region);
       anyChanged = true;
     }
@@ -697,10 +703,12 @@
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any
 /// of the regions were simplified, failure otherwise.
-LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) {
-  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(regions));
-  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(regions));
-  bool mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(regions));
+LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
+                                    MutableArrayRef<Region> regions) {
+  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
+  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
+  bool mergedIdenticalBlocks =
+      succeeded(mergeIdenticalBlocks(rewriter, regions));
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks);
 }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -21,12 +21,12 @@
 
 // CHECK-LABEL: func @single_iteration(
 // CHECK-SAME: [[ARG0:%.*]]: memref<?x?x?xi32>) {
+// CHECK: [[C42:%.*]] = constant 42 : i32
 // CHECK: [[C0:%.*]] = constant 0 : index
 // CHECK: [[C2:%.*]] = constant 2 : index
 // CHECK: [[C3:%.*]] = constant 3 : index
 // CHECK: [[C6:%.*]] = constant 6 : index
 // CHECK: [[C7:%.*]] = constant 7 : index
-// CHECK: [[C42:%.*]] = constant 42 : i32
 // CHECK: scf.parallel ([[V0:%.*]]) = ([[C3]]) to ([[C6]]) step ([[C2]]) {
 // CHECK: memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V0]], [[C7]]] : memref<?x?x?xi32>
 // CHECK: scf.yield