diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 
 namespace mlir {
@@ -116,6 +117,32 @@
   MemorySlotPromotionInfo info;
 };
 
+/// Pattern applying mem2reg to the regions of the operations on which it
+/// matches.
+///
+/// The pattern matches any operation with at least one region and tries to
+/// promote the memory slots of the allocators found directly in the blocks of
+/// those regions. It succeeds if at least one memory slot was promoted.
+///
+/// NOTE: promotion is performed through a plain OpBuilder rather than the
+/// provided PatternRewriter, which is only told that the matched operation
+/// was updated in place. Listeners tracking the nested operations
+/// individually are not notified of the changes made to them.
+class Mem2RegPattern : public RewritePattern {
+public:
+  using RewritePattern::RewritePattern;
+
+  Mem2RegPattern(MLIRContext *ctx, PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, ctx) {
+    // The matched operation is updated in place and may match again once
+    // earlier promotions enable further ones, so mark the recursion bounded.
+    setHasBoundedRewriteRecursion();
+  }
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+};
+
 /// Attempts to promote the memory slots of the provided allocators. Succeeds if
 /// at least one memory slot was promoted.
 LogicalResult
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
@@ -22,6 +23,8 @@
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
+#define DEBUG_TYPE "mem2reg"
+
 using namespace mlir;
 
 /// mem2reg
@@ -421,6 +424,9 @@
     }
   }
 
+  LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
+                          << "\n");
+
   allocator.handlePromotionComplete(slot, defaultValue);
 }
 
@@ -449,39 +455,47 @@
   return success(!toPromote.empty());
 }
 
-namespace {
+LogicalResult Mem2RegPattern::matchAndRewrite(Operation *op,
+                                              PatternRewriter &rewriter) const {
+  if (op->getNumRegions() == 0)
+    return failure();
 
-struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
-  void runOnOperation() override {
-    Operation *scopeOp = getOperation();
-    bool changed = false;
-
-    for (Region &region : scopeOp->getRegions()) {
-      if (region.getBlocks().empty())
-        continue;
+  DominanceInfo dominance;
 
-      OpBuilder builder(&region.front(), region.front().begin());
+  SmallVector<PromotableAllocationOpInterface> allocators;
+  // Build a list of allocators to attempt to promote the slots of.
+  for (Region &region : op->getRegions())
+    for (auto allocator : region.getOps<PromotableAllocationOpInterface>())
+      allocators.emplace_back(allocator);
 
-      // Promoting a slot can allow for further promotion of other slots,
-      // promotion is tried until no promotion succeeds.
-      while (true) {
-        DominanceInfo &dominance = getAnalysis<DominanceInfo>();
+  // Because pattern rewriters are normally not expressive enough to support a
+  // transformation like mem2reg, this uses an escape hatch to mark modified
+  // operations manually and operate outside of its context.
+  rewriter.startRootUpdate(op);
 
-        SmallVector<PromotableAllocationOpInterface> allocators;
-        // Build a list of allocators to attempt to promote the slots of.
-        for (Block &block : region)
-          for (Operation &op : block.getOperations())
-            if (auto allocator = dyn_cast<PromotableAllocationOpInterface>(op))
-              allocators.emplace_back(allocator);
+  OpBuilder builder(rewriter.getContext());
 
-        // Attempt promoting until no promotion succeeds.
-        if (failed(tryToPromoteMemorySlots(allocators, builder, dominance)))
-          break;
+  if (failed(tryToPromoteMemorySlots(allocators, builder, dominance))) {
+    rewriter.cancelRootUpdate(op);
+    return failure();
+  }
 
-        changed = true;
-        getAnalysisManager().invalidate({});
-      }
-    }
+  rewriter.finalizeRootUpdate(op);
+  return success();
+}
+
+namespace {
+
+struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
+  void runOnOperation() override {
+    Operation *scopeOp = getOperation();
+    bool changed = false;
+
+    RewritePatternSet rewritePatterns(&getContext());
+    rewritePatterns.add<Mem2RegPattern>(&getContext());
+    FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+    (void)applyOpPatternsAndFold({scopeOp}, frozen, GreedyRewriteConfig(),
+                                 &changed);
 
     if (!changed)
       markAllAnalysesPreserved();