diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -6,7 +6,7 @@ include "mlir/Dialect/LLVMIR/LLVMEnums.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/Mem2RegInterfaces.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" // Operations that correspond to LLVM intrinsics. With MLIR operation set being // extendable, there is no reason to introduce a hard boundary between "core" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -22,7 +22,7 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/Mem2RegInterfaces.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" class LLVM_Builder { diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -16,7 +16,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/Mem2RegInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/ShapedOpInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ 
b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -15,7 +15,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/Mem2RegInterfaces.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/ShapedOpInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -7,7 +7,6 @@ add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) -add_mlir_interface(Mem2RegInterfaces) add_mlir_interface(ParallelCombiningOpInterface) add_mlir_interface(RuntimeVerifiableOpInterface) add_mlir_interface(ShapedOpInterfaces) @@ -17,6 +16,12 @@ add_mlir_interface(VectorInterfaces) add_mlir_interface(ViewLikeInterface) +set(LLVM_TARGET_DEFINITIONS MemorySlotInterfaces.td) +mlir_tablegen(MemorySlotOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(MemorySlotOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRMemorySlotInterfacesIncGen) +add_dependencies(mlir-generic-headers MLIRMemorySlotInterfacesIncGen) + set(LLVM_TARGET_DEFINITIONS DataLayoutInterfaces.td) mlir_tablegen(DataLayoutAttrInterface.h.inc -gen-attr-interface-decls) mlir_tablegen(DataLayoutAttrInterface.cpp.inc -gen-attr-interface-defs) diff --git a/mlir/include/mlir/Interfaces/Mem2RegInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h rename from mlir/include/mlir/Interfaces/Mem2RegInterfaces.h rename to mlir/include/mlir/Interfaces/MemorySlotInterfaces.h --- a/mlir/include/mlir/Interfaces/Mem2RegInterfaces.h +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h @@ -6,8 +6,8 @@ // 
//===----------------------------------------------------------------------===// -#ifndef MLIR_INTERFACES_MEM2REGINTERFACES_H -#define MLIR_INTERFACES_MEM2REGINTERFACES_H +#ifndef MLIR_INTERFACES_MEMORYSLOTINTERFACES_H +#define MLIR_INTERFACES_MEMORYSLOTINTERFACES_H #include "mlir/IR/Dominance.h" #include "mlir/IR/OpDefinition.h" @@ -34,6 +34,6 @@ } // namespace mlir -#include "mlir/Interfaces/Mem2RegInterfaces.h.inc" +#include "mlir/Interfaces/MemorySlotOpInterfaces.h.inc" -#endif // MLIR_INTERFACES_MEM2REGINTERFACES_H +#endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES_H diff --git a/mlir/include/mlir/Interfaces/Mem2RegInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td rename from mlir/include/mlir/Interfaces/Mem2RegInterfaces.td rename to mlir/include/mlir/Interfaces/MemorySlotInterfaces.td --- a/mlir/include/mlir/Interfaces/Mem2RegInterfaces.td +++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td @@ -1,4 +1,4 @@ -//===-- Mem2RegInterfaces.td - Mem2Reg interfaces ----------*- tablegen -*-===// +//===-- MemorySlotInterfaces.td - MemorySlot interfaces ----*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_INTERFACES_MEM2REGINTERFACES -#define MLIR_INTERFACES_MEM2REGINTERFACES +#ifndef MLIR_INTERFACES_MEMORYSLOTINTERFACES +#define MLIR_INTERFACES_MEMORYSLOTINTERFACES include "mlir/IR/OpBase.td" @@ -76,6 +76,9 @@ to memory slots. Loads and stores must be of whole values of the same type as the slot itself. + For a memory operation on a slot to be valid, it must operate on the slot + pointer *only as a pointer to an element of the type of the slot*. + If the same operation does both loads and stores on the same slot, the load must semantically happen first. 
}]; @@ -152,21 +155,21 @@ let methods = [ InterfaceMethod<[{ Checks that this operation can be promoted to no longer use the provided - blocking uses, in the context of promoting `slot`. + blocking uses, in order to allow optimization. If the removal procedure of the use will require that other uses get removed, that dependency should be added to the `newBlockingUses` argument. Dependent uses must only be uses of results of this operation. }], "bool", "canUsesBeRemoved", - (ins "const ::mlir::MemorySlot &":$slot, - "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses, + (ins "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses, "::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses) >, InterfaceMethod<[{ Transforms IR to ensure that the current operation does not use the - provided memory slot anymore. In contrast to `PromotableMemOpInterface`, - operations implementing this interface must not need access to the - reaching definition of the content of the slot. + provided blocking uses anymore. In contrast to + `PromotableMemOpInterface`, operations implementing this interface + must not need access to the reaching definition of the content of the + slot. During the transformation, *no operation should be deleted*. 
The operation can only schedule its own deletion by returning the @@ -186,11 +189,10 @@ }], "::mlir::DeletionKind", "removeBlockingUses", - (ins "const ::mlir::MemorySlot &":$slot, - "const ::llvm::SmallPtrSetImpl &":$blockingUses, + (ins "const ::llvm::SmallPtrSetImpl &":$blockingUses, "::mlir::OpBuilder &":$builder) >, ]; } -#endif // MLIR_INTERFACES_MEM2REGINTERFACES +#endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h --- a/mlir/include/mlir/Transforms/Mem2Reg.h +++ b/mlir/include/mlir/Transforms/Mem2Reg.h @@ -11,10 +11,112 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/Mem2RegInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" namespace mlir { +/// Information computed during promotion analysis used to perform actual +/// promotion. +struct MemorySlotPromotionInfo { + /// Blocks for which at least two definitions of the slot values clash. + SmallPtrSet mergePoints; + /// Contains, for each operation, which uses must be eliminated by promotion. + /// This is a DAG structure because if an operation must eliminate some of + /// its uses, it is because the defining ops of the blocking uses requested + /// it. The defining ops therefore must also have blocking uses or be the + /// starting point of the blocking uses. + DenseMap> userToBlockingUses; +}; + +/// Computes information for basic slot promotion. This will check that direct +/// slot promotion can be performed, and provide the information to execute the +/// promotion. This does not mutate IR. +class MemorySlotPromotionAnalyzer { +public: + MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance) + : slot(slot), dominance(dominance) {} + + /// Computes the information for slot promotion if promotion is possible, + /// returns nothing otherwise. 
+ std::optional computeInfo(); + +private: + /// Computes the transitive uses of the slot that block promotion. This finds + /// uses that would block the promotion, checks that the operation has a + /// solution to remove the blocking use, and potentially forwards the analysis + /// if the operation needs further blocking uses resolved to resolve its own + /// uses (typically, removing its users because it will delete itself to + /// resolve its own blocking uses). This will fail if one of the transitive + /// users cannot remove a requested use, and should prevent promotion. + LogicalResult computeBlockingUses( + DenseMap> &userToBlockingUses); + + /// Computes in which blocks the value stored in the slot is actually used, + /// meaning blocks leading to a load. This method uses `definingBlocks`, the + /// set of blocks containing a store to the slot (defining the value of the + /// slot). + SmallPtrSet + computeSlotLiveIn(SmallPtrSetImpl &definingBlocks); + + /// Computes the points in which multiple re-definitions of the slot's value + /// (stores) may conflict. + void computeMergePoints(SmallPtrSetImpl &mergePoints); + + /// Ensures predecessors of merge points can properly provide their current + /// definition of the value stored in the slot to the merge point. This can + /// notably be an issue if the terminator used does not have the ability to + /// forward values through block operands. + bool areMergePointsUsable(SmallPtrSetImpl &mergePoints); + + MemorySlot slot; + DominanceInfo &dominance; +}; + +/// The MemorySlotPromoter handles the state of promoting a memory slot. It +/// wraps a slot and its associated allocator. This will perform the mutation of +/// IR. +class MemorySlotPromoter { +public: + MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, + OpBuilder &builder, DominanceInfo &dominance, + MemorySlotPromotionInfo info); + + /// Actually promotes the slot by mutating IR. 
Promoting a slot does not + /// invalidate the MemorySlotPromotionInfo of other slots. + void promoteSlot(); + +private: + /// Computes the reaching definition for all the operations that require + /// promotion. `reachingDef` is the value the slot should contain at the + /// beginning of the block. This method returns the reached definition at the + /// end of the block. + Value computeReachingDefInBlock(Block *block, Value reachingDef); + + /// Computes the reaching definition for all the operations that require + /// promotion. `reachingDef` corresponds to the initial value the + /// slot will contain before any write, typically a poison value. + void computeReachingDefInRegion(Region *region, Value reachingDef); + + /// Removes the blocking uses of the slot, in topological order. + void removeBlockingUses(); + + /// Lazily-constructed default value representing the content of the slot when + /// no store has been executed. This function may mutate IR. + Value getLazyDefaultValue(); + + MemorySlot slot; + PromotableAllocationOpInterface allocator; + OpBuilder &builder; + /// Potentially non-initialized default value. Use `getLazyDefaultValue` to + /// initialize it on demand. + Value defaultValue; + /// Contains the reaching definition at this operation. Reaching definitions + /// are only computed for promotable memory operations with blocking uses. + DenseMap reachingDefs; + DominanceInfo &dominance; + MemorySlotPromotionInfo info; +}; + /// Attempts to promote the memory slots of the provided allocators. Succeeds if /// at least one memory slot was promoted. 
LogicalResult diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -6,7 +6,7 @@ IR/LLVMDialect.cpp IR/LLVMInlining.cpp IR/LLVMInterfaces.cpp - IR/LLVMMem2Reg.cpp + IR/LLVMMemorySlot.cpp IR/LLVMTypes.cpp IR/LLVMTypeSyntax.cpp @@ -34,7 +34,7 @@ MLIRDataLayoutInterfaces MLIRInferTypeOpInterface MLIRIR - MLIRMem2RegInterfaces + MLIRMemorySlotInterfaces MLIRSideEffectInterfaces MLIRSupport ) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMem2Reg.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp rename from mlir/lib/Dialect/LLVMIR/IR/LLVMMem2Reg.cpp rename to mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMem2Reg.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -1,4 +1,4 @@ -//===- LLVMMem2Reg.cpp - Mem2Reg Interfaces ---------------------*- C++ -*-===// +//===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +6,13 @@ // //===----------------------------------------------------------------------===// // -// This file implements Mem2Reg-related interfaces for LLVM dialect operations. +// This file implements MemorySlot-related interfaces for LLVM dialect +// operations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Interfaces/Mem2RegInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" using namespace mlir; @@ -19,13 +20,18 @@ // Interfaces for AllocaOp //===----------------------------------------------------------------------===// +static Type getAllocaElementType(LLVM::AllocaOp alloca) { + // This will become trivial once non-opaque pointers are gone. + return alloca.getElemType().has_value() + ? 
*alloca.getElemType() + : alloca.getResult().getType().getElementType(); +} + llvm::SmallVector LLVM::AllocaOp::getPromotableSlots() { if (!getOperation()->getBlock()->isEntryBlock()) return {}; - Type elemType = - getElemType() ? *getElemType() : getResult().getType().getElementType(); - return {MemorySlot{getResult(), elemType}}; + return {MemorySlot{getResult(), getAllocaElementType(*this)}}; } Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot, @@ -128,62 +134,57 @@ } bool LLVM::BitcastOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { return forwardToUsers(*this, newBlockingUses); } DeletionKind LLVM::BitcastOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } bool LLVM::AddrSpaceCastOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { return forwardToUsers(*this, newBlockingUses); } DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } bool LLVM::LifetimeStartOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { return true; } DeletionKind LLVM::LifetimeStartOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } bool LLVM::LifetimeEndOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, 
SmallVectorImpl &newBlockingUses) { return true; } DeletionKind LLVM::LifetimeEndOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } bool LLVM::DbgDeclareOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { return true; } DeletionKind LLVM::DbgDeclareOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } @@ -194,8 +195,12 @@ }); } +//===----------------------------------------------------------------------===// +// Interfaces for GEPOp +//===----------------------------------------------------------------------===// + bool LLVM::GEPOp::canUsesBeRemoved( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, + const SmallPtrSetImpl &blockingUses, SmallVectorImpl &newBlockingUses) { // GEP can be removed as long as it is a no-op and its users can be removed. 
if (!hasAllZeroIndices(*this)) @@ -204,7 +209,6 @@ } DeletionKind LLVM::GEPOp::removeBlockingUses( - const MemorySlot &slot, const SmallPtrSetImpl &blockingUses, - OpBuilder &builder) { + const SmallPtrSetImpl &blockingUses, OpBuilder &builder) { return DeletionKind::Delete; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Interfaces/Mem2RegInterfaces.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -9,7 +9,7 @@ InferIntRangeInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp - Mem2RegInterfaces.cpp + MemorySlotInterfaces.cpp ParallelCombiningOpInterface.cpp RuntimeVerifiableOpInterface.cpp ShapedOpInterfaces.cpp @@ -46,7 +46,7 @@ add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(LoopLikeInterface) -add_mlir_interface_library(Mem2RegInterfaces) +add_mlir_interface_library(MemorySlotInterfaces) add_mlir_interface_library(ParallelCombiningOpInterface) add_mlir_interface_library(RuntimeVerifiableOpInterface) add_mlir_interface_library(ShapedOpInterfaces) diff --git a/mlir/lib/Interfaces/Mem2RegInterfaces.cpp b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp rename from mlir/lib/Interfaces/Mem2RegInterfaces.cpp rename to mlir/lib/Interfaces/MemorySlotInterfaces.cpp --- a/mlir/lib/Interfaces/Mem2RegInterfaces.cpp +++ b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp @@ -1,4 +1,4 @@ -//===-- Mem2RegInterfaces.cpp - Mem2Reg interfaces --------------*- C++ -*-===// +//===-- MemorySlotInterfaces.cpp - 
MemorySlot interfaces --------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,6 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Interfaces/Mem2RegInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" -#include "mlir/Interfaces/Mem2RegInterfaces.cpp.inc" +#include "mlir/Interfaces/MemorySlotOpInterfaces.cpp.inc" diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -28,7 +28,7 @@ MLIRAnalysis MLIRCopyOpInterface MLIRLoopLikeInterface - MLIRMem2RegInterfaces + MLIRMemorySlotInterfaces MLIRPass MLIRRuntimeVerifiableOpInterface MLIRSideEffectInterfaces diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -9,9 +9,12 @@ #include "mlir/Transforms/Mem2Reg.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Dominance.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/GenericIteratedDominanceFrontier.h" namespace mlir { @@ -42,7 +45,10 @@ /// this, the value stored can be well defined at block boundaries, allowing /// the propagation of replacement through blocks. /// -/// This pass computes this transformation in four main steps: +/// This pass computes this transformation in four main steps. The two first +/// steps are performed during an analysis phase that does not mutate IR. 
+/// +/// The two steps of the analysis phase are the following: /// - A first step computes the list of operations that transitively use the /// memory slot we would like to promote. The purpose of this phase is to /// identify which uses must be removed to promote the slot, either by rewiring @@ -60,6 +66,9 @@ /// existing. Computing this information in advance allows making sure the /// terminators that will forward values are capable of doing so (inability to /// do so aborts promotion at this step). +/// +/// At this point, promotion is guaranteed to happen, and the mutation phase can +/// begin with the following steps: /// - A third step computes the reaching definition of the memory slot at each /// blocking user. This is the core of the mem2reg algorithm, also known as /// load-store forwarding. This analyses loads and stores and propagates which @@ -73,10 +82,6 @@ /// - The final fourth step uses the reaching definition to remove blocking uses /// in topological order. /// -/// The two first steps do not mutate IR because promotion can still be aborted -/// at this point. Once the two last steps are reached, promotion is guaranteed -/// to succeed, allowing to start mutating IR. -/// /// For further reading, chapter three of SSA-based Compiler Design [1] /// showcases SSA construction, where mem2reg is an adaptation of the same /// process. @@ -84,100 +89,11 @@ /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022), /// Springer. -namespace { - -/// The SlotPromoter handles the state of promoting a memory slot. It wraps a -/// slot and its associated allocator, along with analysis results related to -/// the slot. -class SlotPromoter { -public: - SlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator, - OpBuilder &builder, DominanceInfo &dominance); - - /// Prepare data for the promotion of the slot while checking if it can be - /// promoted. Succeeds if the slot can be promoted. This method does not - /// mutate IR. 
- LogicalResult prepareSlotPromotion(); - - /// Actually promotes the slot by mutating IR. This method must only be - /// called after a successful call to `SlotPromoter::prepareSlotPromotion`. - /// Promoting a slot does not invalidate the preparation of other slots. - void promoteSlot(); - -private: - /// This is the first step of the promotion algorithm. - /// Computes the transitive uses of the slot that block promotion. This finds - /// uses that would block the promotion, checks that the operation has a - /// solution to remove the blocking use, and potentially forwards the analysis - /// if the operation needs further blocking uses resolved to resolve its own - /// uses (typically, removing its users because it will delete itself to - /// resolve its own blocking uses). This will fail if one of the transitive - /// users cannot remove a requested use, and should prevent promotion. - LogicalResult computeBlockingUses(); - - /// Computes in which blocks the value stored in the slot is actually used, - /// meaning blocks leading to a load. This method uses `definingBlocks`, the - /// set of blocks containing a store to the slot (defining the value of the - /// slot). - SmallPtrSet - computeSlotLiveIn(SmallPtrSetImpl &definingBlocks); - - /// This is the second step of the promotion algorithm. - /// Computes the points in which multiple re-definitions of the slot's value - /// (stores) may conflict. - void computeMergePoints(); - - /// Ensures predecessors of merge points can properly provide their current - /// definition of the value stored in the slot to the merge point. This can - /// notably be an issue if the terminator used does not have the ability to - /// forward values through block operands. - bool areMergePointsUsable(); - - /// Computes the reaching definition for all the operations that require - /// promotion. `reachingDef` is the value the slot should contain at the - /// beginning of the block. 
This method returns the reached definition at the - /// end of the block. - Value computeReachingDefInBlock(Block *block, Value reachingDef); - - /// This is the third step of the promotion algorithm. - /// Computes the reaching definition for all the operations that require - /// promotion. `reachingDef` corresponds to the initial value the - /// slot will contain before any write, typically a poison value. - void computeReachingDefInRegion(Region *region, Value reachingDef); - - /// This is the fourth step of the promotion algorithm. - /// Removes the blocking uses of the slot, in topological order. - void removeBlockingUses(); - - /// Lazily-constructed default value representing the content of the slot when - /// no store has been executed. This function may mutate IR. - Value getLazyDefaultValue(); - - MemorySlot slot; - PromotableAllocationOpInterface allocator; - OpBuilder &builder; - /// Potentially non-initialized default value. Use `lazyDefaultValue` to - /// initialize it on demand. - Value defaultValue; - /// Blocks where multiple definitions of the slot value clash. - SmallPtrSet mergePoints; - /// Contains, for each operation, which uses must be eliminated by promotion. - /// This is a DAG structure because an operation that must eliminate some of - /// its uses always comes from a request from an operation that must - /// eliminate some of its own uses. - DenseMap> userToBlockingUses; - /// Contains the reaching definition at this operation. Reaching definitions - /// are only computed for promotable memory operations with blocking uses. 
- DenseMap reachingDefs; - DominanceInfo &dominance; -}; - -} // namespace - -SlotPromoter::SlotPromoter(MemorySlot slot, - PromotableAllocationOpInterface allocator, - OpBuilder &builder, DominanceInfo &dominance) - : slot(slot), allocator(allocator), builder(builder), dominance(dominance) { +MemorySlotPromoter::MemorySlotPromoter( + MemorySlot slot, PromotableAllocationOpInterface allocator, + OpBuilder &builder, DominanceInfo &dominance, MemorySlotPromotionInfo info) + : slot(slot), allocator(allocator), builder(builder), dominance(dominance), + info(std::move(info)) { #ifndef NDEBUG auto isResultOrNewBlockArgument = [&]() { if (BlockArgument arg = slot.ptr.dyn_cast()) @@ -191,7 +107,7 @@ #endif // NDEBUG } -Value SlotPromoter::getLazyDefaultValue() { +Value MemorySlotPromoter::getLazyDefaultValue() { if (defaultValue) return defaultValue; @@ -200,7 +116,8 @@ return defaultValue = allocator.getDefaultValue(slot, builder); } -LogicalResult SlotPromoter::computeBlockingUses() { +LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses( + DenseMap> &userToBlockingUses) { // The promotion of an operation may require the promotion of further // operations (typically, removing operations that use an operation that must // delete itself). We thus need to start from the use of the slot pointer and @@ -216,7 +133,7 @@ // Then, propagate the requirements for the removal of uses. The // topologically-sorted forward slice allows for all blocking uses of an - // operation to have been computed before we reach it. Operations are + // operation to have been computed before it is reached. Operations are // traversed in topological order of their uses, starting from the slot // pointer. SetVector forwardSlice; @@ -232,7 +149,7 @@ // If the operation decides it cannot deal with removing the blocking uses, // promotion must fail. 
if (auto promotable = dyn_cast(user)) { - if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses)) + if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses)) return failure(); } else if (auto promotable = dyn_cast(user)) { if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses)) @@ -254,9 +171,9 @@ } // Because this pass currently only supports analysing the parent region of - // the slot pointer, if a promotable memory op that needs promotion is - // outside of this region, promotion must fail because it will be impossible - // to provide a valid `reachingDef` for it. + // the slot pointer, if a promotable memory op that needs promotion is outside + // of this region, promotion must fail because it will be impossible to + // provide a valid `reachingDef` for it. for (auto &[toPromote, _] : userToBlockingUses) if (isa(toPromote) && toPromote->getParentRegion() != slot.ptr.getParentRegion()) @@ -265,8 +182,8 @@ return success(); } -SmallPtrSet -SlotPromoter::computeSlotLiveIn(SmallPtrSetImpl &definingBlocks) { +SmallPtrSet MemorySlotPromotionAnalyzer::computeSlotLiveIn( + SmallPtrSetImpl &definingBlocks) { SmallPtrSet liveIn; // The worklist contains blocks in which it is known that the slot value is @@ -323,7 +240,8 @@ } using IDFCalculator = llvm::IDFCalculatorBase; -void SlotPromoter::computeMergePoints() { +void MemorySlotPromotionAnalyzer::computeMergePoints( + SmallPtrSetImpl &mergePoints) { if (slot.ptr.getParentRegion()->hasOneBlock()) return; @@ -346,7 +264,8 @@ mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end()); } -bool SlotPromoter::areMergePointsUsable() { +bool MemorySlotPromotionAnalyzer::areMergePointsUsable( + SmallPtrSetImpl &mergePoints) { for (Block *mergePoint : mergePoints) for (Block *pred : mergePoint->getPredecessors()) if (!isa(pred->getTerminator())) @@ -355,10 +274,36 @@ return true; } -Value SlotPromoter::computeReachingDefInBlock(Block *block, Value reachingDef) { +std::optional 
+MemorySlotPromotionAnalyzer::computeInfo() { + MemorySlotPromotionInfo info; + + // First, find the set of operations that will need to be changed for the + // promotion to happen. These operations need to resolve some of their uses, + // either by rewiring them or simply deleting themselves. If any of them + // cannot find a way to resolve their blocking uses, we abort the promotion. + if (failed(computeBlockingUses(info.userToBlockingUses))) + return {}; + + // Then, compute blocks in which two or more definitions of the allocated + // variable may conflict. These blocks will need a new block argument to + // accommodate this. + computeMergePoints(info.mergePoints); + + // The slot can be promoted if the block arguments to be created can + // actually be populated with values, which may not be possible depending + // on their predecessors. + if (!areMergePointsUsable(info.mergePoints)) + return {}; + + return info; +} + +Value MemorySlotPromoter::computeReachingDefInBlock(Block *block, + Value reachingDef) { for (Operation &op : block->getOperations()) { if (auto memOp = dyn_cast(op)) { - if (userToBlockingUses.contains(memOp)) + if (info.userToBlockingUses.contains(memOp)) reachingDefs.insert({memOp, reachingDef}); if (Value stored = memOp.getStored(slot)) @@ -369,8 +314,8 @@ return reachingDef; } -void SlotPromoter::computeReachingDefInRegion(Region *region, - Value reachingDef) { +void MemorySlotPromoter::computeReachingDefInRegion(Region *region, + Value reachingDef) { if (region->hasOneBlock()) { computeReachingDefInBlock(&region->front(), reachingDef); return; @@ -392,7 +337,7 @@ DfsJob job = dfsStack.pop_back_val(); Block *block = job.block->getBlock(); - if (mergePoints.contains(block)) { + if (info.mergePoints.contains(block)) { BlockArgument blockArgument = block->addArgument(slot.elemType, slot.ptr.getLoc()); builder.setInsertionPointToStart(block); @@ -404,7 +349,7 @@ if (auto terminator = dyn_cast(block->getTerminator())) { for (BlockOperand 
&blockOperand : terminator->getBlockOperands()) { - if (mergePoints.contains(blockOperand.get())) { + if (info.mergePoints.contains(blockOperand.get())) { if (!job.reachingDef) job.reachingDef = getLazyDefaultValue(); terminator.getSuccessorOperands(blockOperand.getOperandNumber()) @@ -418,9 +363,9 @@ } } -void SlotPromoter::removeBlockingUses() { +void MemorySlotPromoter::removeBlockingUses() { llvm::SetVector usersToRemoveUses; - for (auto &user : llvm::make_first_range(userToBlockingUses)) + for (auto &user : llvm::make_first_range(info.userToBlockingUses)) usersToRemoveUses.insert(user); SetVector sortedUsersToRemoveUses = mlir::topologicalSort(usersToRemoveUses); @@ -435,8 +380,8 @@ reachingDef = getLazyDefaultValue(); builder.setInsertionPointAfter(toPromote); - if (toPromoteMemOp.removeBlockingUses(slot, userToBlockingUses[toPromote], - builder, reachingDef) == + if (toPromoteMemOp.removeBlockingUses( + slot, info.userToBlockingUses[toPromote], builder, reachingDef) == DeletionKind::Delete) toErase.push_back(toPromote); @@ -445,7 +390,7 @@ auto toPromoteBasic = cast(toPromote); builder.setInsertionPointAfter(toPromote); - if (toPromoteBasic.removeBlockingUses(slot, userToBlockingUses[toPromote], + if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote], builder) == DeletionKind::Delete) toErase.push_back(toPromote); } @@ -457,7 +402,7 @@ "after promotion, the slot pointer should not be used anymore"); } -void SlotPromoter::promoteSlot() { +void MemorySlotPromoter::promoteSlot() { computeReachingDefInRegion(slot.ptr.getParentRegion(), {}); // Now that reaching definitions are known, remove all users. @@ -465,7 +410,7 @@ // Update terminators in dead branches to forward default if they are // succeeded by a merge points. 
- for (Block *mergePoint : mergePoints) { + for (Block *mergePoint : info.mergePoints) { for (BlockOperand &use : mergePoint->getUses()) { auto user = cast(use.getOwner()); SuccessorOperands succOperands = @@ -480,43 +425,26 @@ allocator.handlePromotionComplete(slot, defaultValue); } -LogicalResult SlotPromoter::prepareSlotPromotion() { - // First, find the set of operations that will need to be changed for the - // promotion to happen. These operations need to resolve some of their uses, - // either by rewiring them or simply deleting themselves. If any of them - // cannot find a way to resolve their blocking uses, we abort the promotion. - if (failed(computeBlockingUses())) - return failure(); - - // Then, compute blocks in which two or more definitions of the allocated - // variable may conflict. These blocks will need a new block argument to - // accomodate this. - computeMergePoints(); - - // The slot can be promoted if the block arguments to be created can - // actually be populated with values, which may not be possible depending - // on their predecessors. - return success(areMergePointsUsable()); -} - LogicalResult mlir::tryToPromoteMemorySlots( ArrayRef allocators, OpBuilder &builder, DominanceInfo &dominance) { // Actual promotion may invalidate the dominance analysis, so slot promotion // is prepated in batches. 
- SmallVector toPromote; + SmallVector toPromote; for (PromotableAllocationOpInterface allocator : allocators) { for (MemorySlot slot : allocator.getPromotableSlots()) { if (slot.ptr.use_empty()) continue; - SlotPromoter promoter(slot, allocator, builder, dominance); - if (succeeded(promoter.prepareSlotPromotion())) - toPromote.emplace_back(std::move(promoter)); + MemorySlotPromotionAnalyzer analyzer(slot, dominance); + std::optional info = analyzer.computeInfo(); + if (info) + toPromote.emplace_back(slot, allocator, builder, dominance, + std::move(*info)); } } - for (SlotPromoter &promoter : toPromote) + for (MemorySlotPromoter &promoter : toPromote) promoter.promoteSlot(); return success(!toPromote.empty());