diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h @@ -11,6 +11,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" namespace mlir { diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td --- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -31,6 +31,8 @@ Promotion of the slot will lead to the slot pointer no longer being used, leaving the content of the memory slot unreachable. + + No IR mutation is allowed in this method. }], "::llvm::SmallVector<::mlir::MemorySlot>", "getPromotableSlots", (ins) >, @@ -38,34 +40,42 @@ Provides the default Value of this memory slot. The provided Value will be used as the reaching definition of loads done before any store. This Value must outlive the promotion and dominate all the uses of this - slot's pointer. The provided builder can be used to create the default + slot's pointer. The provided rewriter can be used to create the default value on the fly. - The builder is located at the beginning of the block where the slot - pointer is defined. + The rewriter is located at the beginning of the block where the slot + pointer is defined. All IR mutations must happen through the rewriter. }], "::mlir::Value", "getDefaultValue", - (ins "const ::mlir::MemorySlot &":$slot, "::mlir::OpBuilder &":$builder) + (ins + "const ::mlir::MemorySlot &":$slot, + "::mlir::RewriterBase &":$rewriter) >, InterfaceMethod<[{ Hook triggered for every new block argument added to a block. This will only be called for slots declared by this operation. - The builder is located at the beginning of the block on call. + The rewriter is located at the beginning of the block on call.
All IR + mutations must happen through the rewriter. }], "void", "handleBlockArgument", (ins "const ::mlir::MemorySlot &":$slot, "::mlir::BlockArgument":$argument, - "::mlir::OpBuilder &":$builder + "::mlir::RewriterBase &":$rewriter ) >, InterfaceMethod<[{ Hook triggered once the promotion of a slot is complete. This can also clean up the created default value if necessary. This will only be called for slots declared by this operation. + + All IR mutations must happen through the rewriter. }], "void", "handlePromotionComplete", - (ins "const ::mlir::MemorySlot &":$slot, "::mlir::Value":$defaultValue) + (ins + "const ::mlir::MemorySlot &":$slot, + "::mlir::Value":$defaultValue, + "::mlir::RewriterBase &":$rewriter) >, ]; } @@ -87,6 +97,8 @@ let methods = [ InterfaceMethod<[{ Gets whether this operation loads from the specified slot. + + No IR mutation is allowed in this method. }], "bool", "loadsFrom", (ins "const ::mlir::MemorySlot &":$slot) @@ -96,6 +108,8 @@ value if this operation does not store to this slot. An operation storing a value to a slot must always be able to provide the value it stores. This method is only called on operations that use the slot. + + No IR mutation is allowed in this method. }], "::mlir::Value", "getStored", (ins "const ::mlir::MemorySlot &":$slot) @@ -107,6 +121,8 @@ If the removal procedure of the use will require that other uses get removed, that dependency should be added to the `newBlockingUses` argument. Dependent uses must only be uses of results of this operation. + + No IR mutation is allowed in this method. }], "bool", "canUsesBeRemoved", (ins "const ::mlir::MemorySlot &":$slot, "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses, @@ -132,13 +148,14 @@ have been done at the point of calling this method, but it will be done eventually. - The builder is located after the promotable operation on call. + The rewriter is located after the promotable operation on call.
All IR + mutations must happen through the rewriter. }], "::mlir::DeletionKind", "removeBlockingUses", (ins "const ::mlir::MemorySlot &":$slot, "const ::llvm::SmallPtrSetImpl &":$blockingUses, - "::mlir::OpBuilder &":$builder, + "::mlir::RewriterBase &":$rewriter, "::mlir::Value":$reachingDefinition) >, ]; @@ -160,6 +177,8 @@ If the removal procedure of the use will require that other uses get removed, that dependency should be added to the `newBlockingUses` argument. Dependent uses must only be uses of results of this operation. + + No IR mutation is allowed in this method. }], "bool", "canUsesBeRemoved", (ins "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses, "::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses) @@ -185,12 +204,13 @@ have been done at the point of calling this method, but it will be done eventually. - The builder is located after the promotable operation on call. + The rewriter is located after the promotable operation on call. All IR + mutations must happen through the rewriter. }], "::mlir::DeletionKind", "removeBlockingUses", (ins "const ::llvm::SmallPtrSetImpl &":$blockingUses, - "::mlir::OpBuilder &":$builder) + "::mlir::RewriterBase &":$rewriter) >, ]; } diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h --- a/mlir/include/mlir/Transforms/Mem2Reg.h +++ b/mlir/include/mlir/Transforms/Mem2Reg.h @@ -13,129 +13,39 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" +#include "llvm/ADT/Statistic.h" namespace mlir { -/// Information computed during promotion analysis used to perform actual -/// promotion. -struct MemorySlotPromotionInfo { - /// Blocks for which at least two definitions of the slot values clash. - SmallPtrSet mergePoints; - /// Contains, for each operation, which uses must be eliminated by promotion.
- /// This is a DAG structure because if an operation must eliminate some of - /// its uses, it is because the defining ops of the blocking uses requested - /// it. The defining ops therefore must also have blocking uses or be the - /// starting point of the bloccking uses. - DenseMap> userToBlockingUses; -}; - -/// Computes information for basic slot promotion. This will check that direct -/// slot promotion can be performed, and provide the information to execute the -/// promotion. This does not mutate IR. -class MemorySlotPromotionAnalyzer { -public: - MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance) - : slot(slot), dominance(dominance) {} - - /// Computes the information for slot promotion if promotion is possible, - /// returns nothing otherwise. - std::optional computeInfo(); - -private: - /// Computes the transitive uses of the slot that block promotion. This finds - /// uses that would block the promotion, checks that the operation has a - /// solution to remove the blocking use, and potentially forwards the analysis - /// if the operation needs further blocking uses resolved to resolve its own - /// uses (typically, removing its users because it will delete itself to - /// resolve its own blocking uses). This will fail if one of the transitive - /// users cannot remove a requested use, and should prevent promotion. - LogicalResult computeBlockingUses( - DenseMap> &userToBlockingUses); - - /// Computes in which blocks the value stored in the slot is actually used, - /// meaning blocks leading to a load. This method uses `definingBlocks`, the - /// set of blocks containing a store to the slot (defining the value of the - /// slot). - SmallPtrSet - computeSlotLiveIn(SmallPtrSetImpl &definingBlocks); - - /// Computes the points in which multiple re-definitions of the slot's value - /// (stores) may conflict.
- void computeMergePoints(SmallPtrSetImpl &mergePoints); - - /// Ensures predecessors of merge points can properly provide their current - /// definition of the value stored in the slot to the merge point. This can - /// notably be an issue if the terminator used does not have the ability to - /// forward values through block operands. - bool areMergePointsUsable(SmallPtrSetImpl &mergePoints); - - MemorySlot slot; - DominanceInfo &dominance; -}; - -/// The MemorySlotPromoter handles the state of promoting a memory slot. It -/// wraps a slot and its associated allocator. This will perform the mutation of -/// IR. -class MemorySlotPromoter { -public: - MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, - OpBuilder &builder, DominanceInfo &dominance, - MemorySlotPromotionInfo info); - - /// Actually promotes the slot by mutating IR. Promoting a slot does not - /// invalidate the MemorySlotPromotionInfo of other slots. - void promoteSlot(); - -private: - /// Computes the reaching definition for all the operations that require - /// promotion. `reachingDef` is the value the slot should contain at the - /// beginning of the block. This method returns the reached definition at the - /// end of the block. - Value computeReachingDefInBlock(Block *block, Value reachingDef); - - /// Computes the reaching definition for all the operations that require - /// promotion. `reachingDef` corresponds to the initial value the - /// slot will contain before any write, typically a poison value. - void computeReachingDefInRegion(Region *region, Value reachingDef); - - /// Removes the blocking uses of the slot, in topological order. - void removeBlockingUses(); - - /// Lazily-constructed default value representing the content of the slot when - /// no store has been executed. This function may mutate IR.
- Value getLazyDefaultValue(); - - MemorySlot slot; - PromotableAllocationOpInterface allocator; - OpBuilder &builder; - /// Potentially non-initialized default value. Use `getLazyDefaultValue` to - /// initialize it on demand. - Value defaultValue; - /// Contains the reaching definition at this operation. Reaching definitions - /// are only computed for promotable memory operations with blocking uses. - DenseMap reachingDefs; - DominanceInfo &dominance; - MemorySlotPromotionInfo info; +struct Mem2RegStatistics { + llvm::Statistic *promotedAmount = nullptr; + llvm::Statistic *newBlockArgumentAmount = nullptr; }; /// Pattern applying mem2reg to the regions of the operations on which it /// matches. -class Mem2RegPattern : public RewritePattern { +class Mem2RegPattern + : public OpInterfaceRewritePattern { public: - using RewritePattern::RewritePattern; + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - Mem2RegPattern(MLIRContext *ctx, PatternBenefit benefit = 1) - : RewritePattern(MatchAnyOpTypeTag(), benefit, ctx) {} + Mem2RegPattern(MLIRContext *context, Mem2RegStatistics statistics = {}, + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), statistics(statistics) {} - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(PromotableAllocationOpInterface allocator, PatternRewriter &rewriter) const override; + +private: + Mem2RegStatistics statistics; }; /// Attempts to promote the memory slots of the provided allocators. Succeeds if /// at least one memory slot was promoted.
LogicalResult tryToPromoteMemorySlots(ArrayRef allocators, - OpBuilder &builder, DominanceInfo &dominance); + RewriterBase &rewriter, + Mem2RegStatistics statistics = {}); } // namespace mlir diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -189,6 +189,21 @@ This pass only supports unstructured control-flow. Promotion of operations within subregions will not happen. }]; + + let options = [ + Option<"enableRegionSimplification", "region-simplify", "bool", + /*default=*/"true", + "Perform control flow optimizations to the region tree">, + ]; + + let statistics = [ + Statistic<"promotedAmount", + "promoted slots", + "Number of promoted memory slots">, + Statistic<"newBlockArgumentAmount", + "new block args", + "Total number of block arguments added">, + ]; } def PrintOpStats : Pass<"print-op-stats"> { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -35,24 +35,25 @@ } Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot, - OpBuilder &builder) { - return builder.create(getLoc(), slot.elemType); + RewriterBase &rewriter) { + return rewriter.create(getLoc(), slot.elemType); } void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot, BlockArgument argument, - OpBuilder &builder) { + RewriterBase &rewriter) { for (Operation *user : getOperation()->getUsers()) if (auto declareOp = llvm::dyn_cast(user)) - builder.create(declareOp.getLoc(), argument, - declareOp.getVarInfo()); + rewriter.create(declareOp.getLoc(), argument, + declareOp.getVarInfo()); } void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot, - Value defaultValue) { + Value defaultValue, + RewriterBase &rewriter) { if (defaultValue && defaultValue.use_empty()) -
defaultValue.getDefiningOp()->erase(); - erase(); + rewriter.eraseOp(defaultValue.getDefiningOp()); + rewriter.eraseOp(*this); } //===----------------------------------------------------------------------===// @@ -87,10 +88,10 @@ DeletionKind LLVM::LoadOp::removeBlockingUses( const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder, Value reachingDefinition) { + RewriterBase &rewriter, Value reachingDefinition) { // `canUsesBeRemoved` checked this blocking use must be the loaded slot // pointer. - getResult().replaceAllUsesWith(reachingDefinition); + rewriter.replaceAllUsesWith(getResult(), reachingDefinition); return DeletionKind::Delete; } @@ -110,13 +111,13 @@ DeletionKind LLVM::StoreOp::removeBlockingUses( const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder, Value reachingDefinition) { + RewriterBase &rewriter, Value reachingDefinition) { // `canUsesBeRemoved` checked this blocking use must be the stored slot // pointer. for (Operation *user : slot.ptr.getUsers()) if (auto declareOp = llvm::dyn_cast(user)) - builder.create(declareOp->getLoc(), getValue(), - declareOp.getVarInfo()); + rewriter.create(declareOp->getLoc(), getValue(), + declareOp.getVarInfo()); return DeletionKind::Delete; } @@ -140,7 +141,7 @@ } DeletionKind LLVM::BitcastOp::removeBlockingUses( - const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, RewriterBase &rewriter) { return DeletionKind::Delete; } @@ -151,7 +152,7 @@ } DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses( - const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, RewriterBase &rewriter) { return DeletionKind::Delete; } @@ -162,7 +163,7 @@ } DeletionKind LLVM::LifetimeStartOp::removeBlockingUses( - const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, RewriterBase &rewriter) { return DeletionKind::Delete; } @@ -173,7 +174,7 @@ }
DeletionKind LLVM::LifetimeEndOp::removeBlockingUses( - const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, RewriterBase &rewriter) { return DeletionKind::Delete; } @@ -184,7 +185,7 @@ } DeletionKind LLVM::DbgDeclareOp::removeBlockingUses( - const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, RewriterBase &rewriter) { return DeletionKind::Delete; } @@ -209,6 +210,6 @@ } DeletionKind LLVM::GEPOp::removeBlockingUses( - const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, RewriterBase &rewriter) { return DeletionKind::Delete; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp @@ -40,29 +40,30 @@ } Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, - OpBuilder &builder) { + RewriterBase &rewriter) { assert(isSupportedElementType(slot.elemType)); // TODO: support more types.
return TypeSwitch(slot.elemType) .Case([&](MemRefType t) { - return builder.create(getLoc(), t); + return rewriter.create(getLoc(), t); }) .Default([&](Type t) { - return builder.create(getLoc(), t, - builder.getZeroAttr(t)); + return rewriter.create(getLoc(), t, + rewriter.getZeroAttr(t)); }); } void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot, - Value defaultValue) { + Value defaultValue, + RewriterBase &rewriter) { - if (defaultValue.use_empty()) - defaultValue.getDefiningOp()->erase(); - erase(); + if (defaultValue && defaultValue.use_empty()) + rewriter.eraseOp(defaultValue.getDefiningOp()); + rewriter.eraseOp(*this); } void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot, BlockArgument argument, - OpBuilder &builder) {} + RewriterBase &rewriter) {} //===----------------------------------------------------------------------===// // LoadOp/StoreOp interfaces @@ -86,10 +87,10 @@ DeletionKind memref::LoadOp::removeBlockingUses( const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder, Value reachingDefinition) { + RewriterBase &rewriter, Value reachingDefinition) { // `canUsesBeRemoved` checked this blocking use must be the loaded slot // pointer.
- getResult().replaceAllUsesWith(reachingDefinition); + rewriter.replaceAllUsesWith(getResult(), reachingDefinition); return DeletionKind::Delete; } @@ -113,6 +114,6 @@ DeletionKind memref::StoreOp::removeBlockingUses( const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder, Value reachingDefinition) { + RewriterBase &rewriter, Value reachingDefinition) { return DeletionKind::Delete; } diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -10,6 +10,8 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -92,11 +94,121 @@ /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022), /// Springer. +namespace { + +/// Information computed during promotion analysis used to perform actual +/// promotion. +struct MemorySlotPromotionInfo { + /// Blocks for which at least two definitions of the slot values clash. + SmallPtrSet mergePoints; + /// Contains, for each operation, which uses must be eliminated by promotion. + /// This is a DAG structure because if an operation must eliminate some of + /// its uses, it is because the defining ops of the blocking uses requested + /// it. The defining ops therefore must also have blocking uses or be the + /// starting point of the blocking uses. + DenseMap> userToBlockingUses; +}; + +/// Computes information for basic slot promotion. This will check that direct +/// slot promotion can be performed, and provide the information to execute the +/// promotion. This does not mutate IR.
+class MemorySlotPromotionAnalyzer { +public: + MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance) + : slot(slot), dominance(dominance) {} + + /// Computes the information for slot promotion if promotion is possible, + /// returns nothing otherwise. + std::optional computeInfo(); + +private: + /// Computes the transitive uses of the slot that block promotion. This finds + /// uses that would block the promotion, checks that the operation has a + /// solution to remove the blocking use, and potentially forwards the analysis + /// if the operation needs further blocking uses resolved to resolve its own + /// uses (typically, removing its users because it will delete itself to + /// resolve its own blocking uses). This will fail if one of the transitive + /// users cannot remove a requested use, and should prevent promotion. + LogicalResult computeBlockingUses( + DenseMap> &userToBlockingUses); + + /// Computes in which blocks the value stored in the slot is actually used, + /// meaning blocks leading to a load. This method uses `definingBlocks`, the + /// set of blocks containing a store to the slot (defining the value of the + /// slot). + SmallPtrSet + computeSlotLiveIn(SmallPtrSetImpl &definingBlocks); + + /// Computes the points in which multiple re-definitions of the slot's value + /// (stores) may conflict. + void computeMergePoints(SmallPtrSetImpl &mergePoints); + + /// Ensures predecessors of merge points can properly provide their current + /// definition of the value stored in the slot to the merge point. This can + /// notably be an issue if the terminator used does not have the ability to + /// forward values through block operands. + bool areMergePointsUsable(SmallPtrSetImpl &mergePoints); + + MemorySlot slot; + DominanceInfo &dominance; +}; + +/// The MemorySlotPromoter handles the state of promoting a memory slot. It +/// wraps a slot and its associated allocator. This will perform the mutation of +/// IR.
+class MemorySlotPromoter { +public: + MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, + RewriterBase &rewriter, DominanceInfo &dominance, + MemorySlotPromotionInfo info, + const Mem2RegStatistics &statistics); + + /// Actually promotes the slot by mutating IR. Promoting a slot DOES + /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of + /// promotion info should NOT be performed in batches. + void promoteSlot(); + +private: + /// Computes the reaching definition for all the operations that require + /// promotion. `reachingDef` is the value the slot should contain at the + /// beginning of the block. This method returns the reached definition at the + /// end of the block. + Value computeReachingDefInBlock(Block *block, Value reachingDef); + + /// Computes the reaching definition for all the operations that require + /// promotion. `reachingDef` corresponds to the initial value the + /// slot will contain before any write, typically a poison value. + void computeReachingDefInRegion(Region *region, Value reachingDef); + + /// Removes the blocking uses of the slot, in topological order. + void removeBlockingUses(); + + /// Lazily-constructed default value representing the content of the slot when + /// no store has been executed. This function may mutate IR. + Value getLazyDefaultValue(); + + MemorySlot slot; + PromotableAllocationOpInterface allocator; + RewriterBase &rewriter; + /// Potentially non-initialized default value. Use `getLazyDefaultValue` to + /// initialize it on demand. + Value defaultValue; + /// Contains the reaching definition at this operation. Reaching definitions + /// are only computed for promotable memory operations with blocking uses.
+ DenseMap reachingDefs; + DominanceInfo &dominance; + MemorySlotPromotionInfo info; + const Mem2RegStatistics &statistics; +}; + +} // namespace + MemorySlotPromoter::MemorySlotPromoter( MemorySlot slot, PromotableAllocationOpInterface allocator, - OpBuilder &builder, DominanceInfo &dominance, MemorySlotPromotionInfo info) - : slot(slot), allocator(allocator), builder(builder), dominance(dominance), - info(std::move(info)) { + RewriterBase &rewriter, DominanceInfo &dominance, + MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics) + : slot(slot), allocator(allocator), rewriter(rewriter), + dominance(dominance), info(std::move(info)), statistics(statistics) { #ifndef NDEBUG auto isResultOrNewBlockArgument = [&]() { if (BlockArgument arg = dyn_cast(slot.ptr)) @@ -114,9 +226,9 @@ if (defaultValue) return defaultValue; - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPointToStart(slot.ptr.getParentBlock()); - return defaultValue = allocator.getDefaultValue(slot, builder); + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(slot.ptr.getParentBlock()); + return defaultValue = allocator.getDefaultValue(slot, rewriter); } LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( @@ -341,11 +453,37 @@ Block *block = job.block->getBlock(); if (info.mergePoints.contains(block)) { - BlockArgument blockArgument = - block->addArgument(slot.elemType, slot.ptr.getLoc()); - builder.setInsertionPointToStart(block); - allocator.handleBlockArgument(slot, blockArgument, builder); + // If the block is a merge point, we need to add a block argument to hold + // the selected reaching definition. This has to be a bit complicated + // because of RewriterBase limitations: we need to create a new block with + // the extra block argument, move the content of the block to the new + // block, and replace the block with the new block in the merge point set.
+ SmallVector argTypes; + SmallVector argLocs; + for (BlockArgument arg : block->getArguments()) { + argTypes.push_back(arg.getType()); + argLocs.push_back(arg.getLoc()); + } + argTypes.push_back(slot.elemType); + argLocs.push_back(slot.ptr.getLoc()); + Block *newBlock = rewriter.createBlock(block, argTypes, argLocs); + + info.mergePoints.erase(block); + info.mergePoints.insert(newBlock); + + rewriter.replaceAllUsesWith(block, newBlock); + rewriter.mergeBlocks(block, newBlock, + newBlock->getArguments().drop_back()); + + block = newBlock; + + BlockArgument blockArgument = block->getArguments().back(); + rewriter.setInsertionPointToStart(block); + allocator.handleBlockArgument(slot, blockArgument, rewriter); job.reachingDef = blockArgument; + + if (statistics.newBlockArgumentAmount) + (*statistics.newBlockArgumentAmount)++; } job.reachingDef = computeReachingDefInBlock(block, job.reachingDef); @@ -355,8 +493,10 @@ if (info.mergePoints.contains(blockOperand.get())) { if (!job.reachingDef) job.reachingDef = getLazyDefaultValue(); - terminator.getSuccessorOperands(blockOperand.getOperandNumber()) - .append(job.reachingDef); + rewriter.updateRootInPlace(terminator, [&]() { + terminator.getSuccessorOperands(blockOperand.getOperandNumber()) + .append(job.reachingDef); + }); } } } @@ -382,24 +522,24 @@ if (!reachingDef) reachingDef = getLazyDefaultValue(); - builder.setInsertionPointAfter(toPromote); + rewriter.setInsertionPointAfter(toPromote); if (toPromoteMemOp.removeBlockingUses( - slot, info.userToBlockingUses[toPromote], builder, reachingDef) == - DeletionKind::Delete) + slot, info.userToBlockingUses[toPromote], rewriter, + reachingDef) == DeletionKind::Delete) toErase.push_back(toPromote); continue; } auto toPromoteBasic = cast(toPromote); - builder.setInsertionPointAfter(toPromote); + rewriter.setInsertionPointAfter(toPromote); if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote], - builder) == DeletionKind::Delete) + rewriter) ==
DeletionKind::Delete) toErase.push_back(toPromote); } for (Operation *toEraseOp : toErase) - toEraseOp->erase(); + rewriter.eraseOp(toEraseOp); assert(slot.ptr.use_empty() && "after promotion, the slot pointer should not be used anymore"); @@ -421,87 +561,73 @@ assert(succOperands.size() == mergePoint->getNumArguments() || succOperands.size() + 1 == mergePoint->getNumArguments()); if (succOperands.size() + 1 == mergePoint->getNumArguments()) - succOperands.append(getLazyDefaultValue()); + rewriter.updateRootInPlace( + user, [&]() { succOperands.append(getLazyDefaultValue()); }); } } LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr << "\n"); - allocator.handlePromotionComplete(slot, defaultValue); + if (statistics.promotedAmount) + (*statistics.promotedAmount)++; + + allocator.handlePromotionComplete(slot, defaultValue, rewriter); } LogicalResult mlir::tryToPromoteMemorySlots( - ArrayRef allocators, OpBuilder &builder, - DominanceInfo &dominance) { - // Actual promotion may invalidate the dominance analysis, so slot promotion - // is prepated in batches.
- SmallVector toPromote; + ArrayRef allocators, + RewriterBase &rewriter, Mem2RegStatistics statistics) { + bool promotedAny = false; + for (PromotableAllocationOpInterface allocator : allocators) { for (MemorySlot slot : allocator.getPromotableSlots()) { if (slot.ptr.use_empty()) continue; + // Promotion mutates the IR and invalidates dominance information, so + // it must be recomputed for every slot rather than shared across slots. + DominanceInfo dominance; MemorySlotPromotionAnalyzer analyzer(slot, dominance); std::optional info = analyzer.computeInfo(); - if (info) - toPromote.emplace_back(slot, allocator, builder, dominance, - std::move(*info)); + if (info) { + MemorySlotPromoter(slot, allocator, rewriter, dominance, + std::move(*info), statistics) + .promoteSlot(); + promotedAny = true; + } } } - for (MemorySlotPromoter &promoter : toPromote) - promoter.promoteSlot(); - - return success(!toPromote.empty()); + return success(promotedAny); } -LogicalResult Mem2RegPattern::matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const { +LogicalResult +Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator, + PatternRewriter &rewriter) const { hasBoundedRewriteRecursion(); - - if (op->getNumRegions() == 0) - return failure(); - - DominanceInfo dominance; - - SmallVector allocators; - // Build a list of allocators to attempt to promote the slots of. - for (Region ®ion : op->getRegions()) - for (auto allocator : region.getOps()) - allocators.emplace_back(allocator); - - // Because pattern rewriters are normally not expressive enough to support a - // transformation like mem2reg, this uses an escape hatch to mark modified - // operations manually and operate outside of its context.
- rewriter.startRootUpdate(op); - - OpBuilder builder(rewriter.getContext()); - - if (failed(tryToPromoteMemorySlots(allocators, builder, dominance))) { - rewriter.cancelRootUpdate(op); - return failure(); - } - - rewriter.finalizeRootUpdate(op); - return success(); + return tryToPromoteMemorySlots({allocator}, rewriter, statistics); } namespace { struct Mem2Reg : impl::Mem2RegBase { + using impl::Mem2RegBase::Mem2RegBase; + void runOnOperation() override { Operation *scopeOp = getOperation(); - bool changed = false; + + Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount}; + + GreedyRewriteConfig config; + config.enableRegionSimplification = enableRegionSimplification; RewritePatternSet rewritePatterns(&getContext()); - rewritePatterns.add(&getContext()); + rewritePatterns.add(&getContext(), statistics); FrozenRewritePatternSet frozen(std::move(rewritePatterns)); - (void)applyOpPatternsAndFold({scopeOp}, frozen, GreedyRewriteConfig(), - &changed); - if (!changed) - markAllAnalysesPreserved(); + if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config))) + signalPassFailure(); } }; diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir --- a/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir +++ b/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --pass-pipeline='builtin.module(llvm.func(mem2reg))' | FileCheck %s +// RUN: mlir-opt %s --pass-pipeline='builtin.module(llvm.func(mem2reg{region-simplify=false}))' | FileCheck %s llvm.func @use(i64) llvm.func @use_ptr(!llvm.ptr) diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir --- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir +++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg))" --split-input-file | FileCheck %s +// RUN: mlir-opt %s
--pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s // CHECK-LABEL: llvm.func @default_value llvm.func @default_value() -> i32 { diff --git a/mlir/test/Dialect/MemRef/mem2reg-statistics.mlir b/mlir/test/Dialect/MemRef/mem2reg-statistics.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/MemRef/mem2reg-statistics.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file --mlir-pass-statistics 2>&1 >/dev/null | FileCheck %s + +// CHECK: Mem2Reg +// CHECK-NEXT: (S) 0 new block args +// CHECK-NEXT: (S) 1 promoted slots +func.func @basic() -> i32 { + %0 = arith.constant 5 : i32 + %1 = memref.alloca() : memref + memref.store %0, %1[] : memref + %2 = memref.load %1[] : memref + return %2 : i32 +} + +// ----- + +// CHECK: Mem2Reg +// CHECK-NEXT: (S) 0 new block args +// CHECK-NEXT: (S) 0 promoted slots +func.func @no_alloca() -> i32 { + %0 = arith.constant 5 : i32 + return %0 : i32 +} + +// ----- + +// CHECK: Mem2Reg +// CHECK-NEXT: (S) 2 new block args +// CHECK-NEXT: (S) 1 promoted slots +func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) { + %alloca = memref.alloca() : memref + memref.store %arg2, %alloca[] : memref + cf.cond_br %arg1, ^bb1, ^bb2 +^bb1: + %use = memref.load %alloca[] : memref + call @use(%use) : (i64) -> () + memref.store %arg0, %alloca[] : memref + cf.br ^bb2 +^bb2: + cf.br ^bb1 +} + +func.func @use(%arg: i64) { return } + +// ----- + +// CHECK: Mem2Reg +// CHECK-NEXT: (S) 0 new block args +// CHECK-NEXT: (S) 3 promoted slots +func.func @recursive(%arg: i64) -> i64 { + %alloca0 = memref.alloca() : memref>> + %alloca1 = memref.alloca() : memref> + %alloca2 = memref.alloca() : memref + memref.store %arg, %alloca2[] : memref + memref.store %alloca2, %alloca1[] : memref> + memref.store %alloca1, %alloca0[] : memref>> + %load0 = memref.load %alloca0[] : memref>> + %load1 = memref.load %load0[] : memref> + %load2 = memref.load
%load1[] : memref + return %load2 : i64 +} diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir --- a/mlir/test/Dialect/MemRef/mem2reg.mlir +++ b/mlir/test/Dialect/MemRef/mem2reg.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s +// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg{region-simplify=false}))' --split-input-file | FileCheck %s // CHECK-LABEL: func.func @basic func.func @basic() -> i32 { @@ -148,20 +148,18 @@ // CHECK-LABEL: func.func @promotable_nonpromotable_intertwined func.func @promotable_nonpromotable_intertwined() -> i32 { - // CHECK: %[[VAL:.*]] = arith.constant 5 : i32 - %0 = arith.constant 5 : i32 // CHECK: %[[NON_PROMOTED:.*]] = memref.alloca() : memref - %1 = memref.alloca() : memref + %0 = memref.alloca() : memref // CHECK-NOT: = memref.alloca() : memref> - %2 = memref.alloca() : memref> - memref.store %1, %2[] : memref> - %3 = memref.load %2[] : memref> + %1 = memref.alloca() : memref> + memref.store %0, %1[] : memref> + %2 = memref.load %1[] : memref> // CHECK: call @use(%[[NON_PROMOTED]]) - call @use(%1) : (memref) -> () + call @use(%0) : (memref) -> () // CHECK: %[[RES:.*]] = memref.load %[[NON_PROMOTED]][] - %4 = memref.load %1[] : memref + %3 = memref.load %0[] : memref // CHECK: return %[[RES]] : i32 - return %4 : i32 + return %3 : i32 } func.func @use(%arg: memref) { return }