diff --git a/llvm/include/llvm/ADT/GenericCycleImpl.h b/llvm/include/llvm/ADT/GenericCycleImpl.h --- a/llvm/include/llvm/ADT/GenericCycleImpl.h +++ b/llvm/include/llvm/ADT/GenericCycleImpl.h @@ -354,11 +354,11 @@ template void GenericCycleInfo::compute(FunctionT &F) { GenericCycleInfoCompute Compute(*this); - Context.setFunction(F); + Context = ContextT(&F); LLVM_DEBUG(errs() << "Computing cycles for function: " << F.getName() << "\n"); - Compute.run(ContextT::getEntryBlock(F)); + Compute.run(&F.front()); assert(validateTree()); } diff --git a/llvm/include/llvm/ADT/GenericCycleInfo.h b/llvm/include/llvm/ADT/GenericCycleInfo.h --- a/llvm/include/llvm/ADT/GenericCycleInfo.h +++ b/llvm/include/llvm/ADT/GenericCycleInfo.h @@ -256,7 +256,7 @@ void clear(); void compute(FunctionT &F); - FunctionT *getFunction() const { return Context.getFunction(); } + const FunctionT *getFunction() const { return Context.getFunction(); } const ContextT &getSSAContext() const { return Context; } CycleT *getCycle(const BlockT *Block) const; diff --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h --- a/llvm/include/llvm/ADT/GenericSSAContext.h +++ b/llvm/include/llvm/ADT/GenericSSAContext.h @@ -21,87 +21,78 @@ namespace llvm { +template class DominatorTreeBase; +template class SmallVectorImpl; + +// Specializations of this template should provide the types used by the +// template GenericSSAContext below. +template struct GenericSSATraits; + +// Ideally this should have been a stateless traits class. But the print methods +// for Machine IR need access to the owning function. So we track that state in +// the template itself. +// +// We use FunctionT as a template argument and not GenericSSATraits to allow +// forward declarations using well-known typenames. template class GenericSSAContext { -public: - // Specializations should provide the following types that are similar to how - // LLVM IR is structured: + using SSATraits = GenericSSATraits<_FunctionT>; + const typename SSATraits::FunctionT *F; +public: // The smallest unit of the IR is a ValueT. The SSA context uses a ValueRefT, // which is a pointer to a ValueT, since Machine IR does not have the // equivalent of a ValueT. - // - // using ValueRefT = ... - // + using ValueRefT = typename SSATraits::ValueRefT; + // The ConstValueRefT is needed to work with "const Value *", where const // needs to bind to the pointee and not the pointer. - // - // using ConstValueRefT = ... - // - // The null value for ValueRefT. - // - // static constexpr ValueRefT ValueRefNull; + using ConstValueRefT = typename SSATraits::ConstValueRefT; + + // The null value for ValueRefT. For LLVM IR and MIR, this is simply the + // default constructed value. + static constexpr ValueRefT *ValueRefNull = {}; // An InstructionT usually defines one or more ValueT objects. - // - // using InstructionT = ... must be a subclass of Value + using InstructionT = typename SSATraits::InstructionT; // A UseT represents a data-edge from the defining instruction to the using // instruction. - // - // using UseT = ... + using UseT = typename SSATraits::UseT; // A BlockT is a sequence of InstructionT, and forms a node of the CFG. It // has global methods predecessors() and successors() that return // the list of incoming CFG edges and outgoing CFG edges // respectively. - // - // using BlockT = ... + using BlockT = typename SSATraits::BlockT; // A FunctionT represents a CFG along with arguments and return values. It is // the smallest complete unit of code in a Module. - // - // The compiler produces an error here if this class is implicitly - // specialized due to an instantiation. An explicit specialization - // of this template needs to be added before the instantiation point - // indicated by the compiler. - using FunctionT = typename _FunctionT::invalidTemplateInstanceError; + using FunctionT = typename SSATraits::FunctionT; // A dominator tree provides the dominance relation between basic blocks in // a given funciton. - // - // using DominatorTreeT = ... - - // Initialize the SSA context with information about the FunctionT being - // processed. - // - // void setFunction(FunctionT &function); - // FunctionT* getFunction() const; - - // Every FunctionT has a unique BlockT marked as its entry. - // - // static BlockT* getEntryBlock(FunctionT &F); - - // Methods to examine basic blocks and values - // - // static void appendBlockDefs(SmallVectorImpl &defs, - // BlockT &block); - // static void appendBlockDefs(SmallVectorImpl &defs, - // const BlockT &block); - - // static void appendBlockTerms(SmallVectorImpl &terms, - // BlockT &block); - // static void appendBlockTerms(SmallVectorImpl &terms, - // const BlockT &block); - // - // static bool comesBefore(const InstructionT *lhs, const InstructionT *rhs); - // static bool isConstantOrUndefValuePhi(const InstructionT &Instr); - // const BlockT *getDefBlock(const ValueRefT value) const; - - // Methods to print various objects. - // - // Printable print(BlockT *block) const; - // Printable print(InstructionT *inst) const; - // Printable print(ValueRefT value) const; + using DominatorTreeT = DominatorTreeBase; + + GenericSSAContext() = default; + GenericSSAContext(const FunctionT *F) : F(F) {} + + const FunctionT *getFunction() const { return F; } + + static void appendBlockDefs(SmallVectorImpl &defs, BlockT &block); + static void appendBlockDefs(SmallVectorImpl &defs, + const BlockT &block); + + static void appendBlockTerms(SmallVectorImpl &terms, + BlockT &block); + static void appendBlockTerms(SmallVectorImpl &terms, + const BlockT &block); + + static bool isConstantOrUndefValuePhi(const InstructionT &Instr); + const BlockT *getDefBlock(ConstValueRefT value) const; + + Printable print(const BlockT *block) const; + Printable print(const InstructionT *inst) const; + Printable print(ConstValueRefT value) const; }; } // namespace llvm diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h --- a/llvm/include/llvm/ADT/GenericUniformityImpl.h +++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h @@ -129,11 +129,11 @@ const ContextT &Context; void computeCyclePO(const CycleInfoT &CI, const CycleT *Cycle, - SmallPtrSetImpl &Finalized); + SmallPtrSetImpl &Finalized); - void computeStackPO(SmallVectorImpl &Stack, const CycleInfoT &CI, - const CycleT *Cycle, - SmallPtrSetImpl &Finalized); + void computeStackPO(SmallVectorImpl &Stack, + const CycleInfoT &CI, const CycleT *Cycle, + SmallPtrSetImpl &Finalized); }; template class DivergencePropagator; @@ -342,11 +342,10 @@ typename SyncDependenceAnalysisT::DivergenceDescriptor; using BlockLabelMapT = typename SyncDependenceAnalysisT::BlockLabelMap; - GenericUniformityAnalysisImpl(const FunctionT &F, const DominatorTreeT &DT, - const CycleInfoT &CI, + GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI, const TargetTransformInfo *TTI) - : Context(CI.getSSAContext()), F(F), CI(CI), TTI(TTI), DT(DT), - SDA(Context, DT, CI) {} + : Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI), + TTI(TTI), DT(DT), SDA(Context, DT, CI) {} void initialize(); @@ -1135,10 +1134,9 @@ template GenericUniformityInfo::GenericUniformityInfo( - FunctionT &Func, const DominatorTreeT &DT, const CycleInfoT &CI, - const TargetTransformInfo *TTI) - : F(&Func) { - DA.reset(new ImplT{Func, DT, CI, TTI}); + const DominatorTreeT &DT, const CycleInfoT &CI, + const TargetTransformInfo *TTI) { + DA.reset(new ImplT{DT, CI, TTI}); } template @@ -1214,6 +1212,12 @@ return DA->hasDivergence(); } +template +const typename ContextT::FunctionT & +GenericUniformityInfo::getFunction() const { + return DA->getFunction(); +} + /// Whether \p V is divergent at its definition. template bool GenericUniformityInfo::isDivergent(ConstValueRefT V) const { @@ -1243,8 +1247,8 @@ template void llvm::ModifiedPostOrder::computeStackPO( - SmallVectorImpl &Stack, const CycleInfoT &CI, const CycleT *Cycle, - SmallPtrSetImpl &Finalized) { + SmallVectorImpl &Stack, const CycleInfoT &CI, + const CycleT *Cycle, SmallPtrSetImpl &Finalized) { LLVM_DEBUG(dbgs() << "inside computeStackPO\n"); while (!Stack.empty()) { auto *NextBB = Stack.back(); @@ -1313,9 +1317,9 @@ template void ModifiedPostOrder::computeCyclePO( const CycleInfoT &CI, const CycleT *Cycle, - SmallPtrSetImpl &Finalized) { + SmallPtrSetImpl &Finalized) { LLVM_DEBUG(dbgs() << "inside computeCyclePO\n"); - SmallVector Stack; + SmallVector Stack; auto *CycleHeader = Cycle->getHeader(); LLVM_DEBUG(dbgs() << " noted header: " @@ -1352,11 +1356,11 @@ /// \brief Generically compute the modified post order. template void llvm::ModifiedPostOrder::compute(const CycleInfoT &CI) { - SmallPtrSet Finalized; - SmallVector Stack; + SmallPtrSet Finalized; + SmallVector Stack; auto *F = CI.getFunction(); Stack.reserve(24); // FIXME made-up number - Stack.push_back(GraphTraits::getEntryNode(F)); + Stack.push_back(&F->front()); computeStackPO(Stack, CI, nullptr, Finalized); } diff --git a/llvm/include/llvm/ADT/GenericUniformityInfo.h b/llvm/include/llvm/ADT/GenericUniformityInfo.h --- a/llvm/include/llvm/ADT/GenericUniformityInfo.h +++ b/llvm/include/llvm/ADT/GenericUniformityInfo.h @@ -40,8 +40,7 @@ using CycleInfoT = GenericCycleInfo; using CycleT = typename CycleInfoT::CycleT; - GenericUniformityInfo(FunctionT &F, const DominatorTreeT &DT, - const CycleInfoT &CI, + GenericUniformityInfo(const DominatorTreeT &DT, const CycleInfoT &CI, const TargetTransformInfo *TTI = nullptr); GenericUniformityInfo() = default; GenericUniformityInfo(GenericUniformityInfo &&) = default; @@ -56,7 +55,7 @@ bool hasDivergence() const; /// The GPU kernel this analysis result is for - const FunctionT &getFunction() const { return *F; } + const FunctionT &getFunction() const; /// Whether \p V is divergent at its definition. bool isDivergent(ConstValueRefT V) const; @@ -82,7 +81,6 @@ private: using ImplT = GenericUniformityAnalysisImpl; - FunctionT *F; std::unique_ptr> DA; GenericUniformityInfo(const GenericUniformityInfo &) = delete; diff --git a/llvm/include/llvm/CodeGen/MachineSSAContext.h b/llvm/include/llvm/CodeGen/MachineSSAContext.h --- a/llvm/include/llvm/CodeGen/MachineSSAContext.h +++ b/llvm/include/llvm/CodeGen/MachineSSAContext.h @@ -15,6 +15,7 @@ #ifndef LLVM_CODEGEN_MACHINESSACONTEXT_H #define LLVM_CODEGEN_MACHINESSACONTEXT_H +#include "llvm/ADT/GenericSSAContext.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/Support/Printable.h" @@ -23,8 +24,6 @@ class MachineInstr; class MachineFunction; class Register; -template class GenericSSAContext; -template class DominatorTreeBase; inline unsigned succ_size(const MachineBasicBlock *BB) { return BB->succ_size(); @@ -34,37 +33,13 @@ } inline auto instrs(const MachineBasicBlock &BB) { return BB.instrs(); } -template <> class GenericSSAContext { - const MachineRegisterInfo *RegInfo = nullptr; - MachineFunction *MF = nullptr; - -public: +template <> struct GenericSSATraits { using BlockT = MachineBasicBlock; using FunctionT = MachineFunction; using InstructionT = MachineInstr; using ValueRefT = Register; using ConstValueRefT = Register; using UseT = MachineOperand; - using DominatorTreeT = DominatorTreeBase; - - static constexpr Register ValueRefNull = 0; - - void setFunction(MachineFunction &Fn); - MachineFunction *getFunction() const { return MF; } - - static MachineBasicBlock *getEntryBlock(MachineFunction &F); - static void appendBlockDefs(SmallVectorImpl &defs, - const MachineBasicBlock &block); - static void appendBlockTerms(SmallVectorImpl &terms, - MachineBasicBlock &block); - static void appendBlockTerms(SmallVectorImpl &terms, - const MachineBasicBlock &block); - MachineBasicBlock *getDefBlock(Register) const; - static bool isConstantOrUndefValuePhi(const MachineInstr &Phi); - - Printable print(const MachineBasicBlock *Block) const; - Printable print(const MachineInstr *Inst) const; - Printable print(Register Value) const; }; using MachineSSAContext = GenericSSAContext; diff --git a/llvm/include/llvm/IR/SSAContext.h b/llvm/include/llvm/IR/SSAContext.h --- a/llvm/include/llvm/IR/SSAContext.h +++ b/llvm/include/llvm/IR/SSAContext.h @@ -17,60 +17,24 @@ #include "llvm/ADT/GenericSSAContext.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/ModuleSlotTracker.h" -#include "llvm/Support/Printable.h" - -#include namespace llvm { class BasicBlock; class Function; class Instruction; class Value; -template class SmallVectorImpl; -template class DominatorTreeBase; inline auto instrs(const BasicBlock &BB) { return llvm::make_range(BB.begin(), BB.end()); } -template <> class GenericSSAContext { - Function *F; - -public: +template <> struct GenericSSATraits { using BlockT = BasicBlock; using FunctionT = Function; using InstructionT = Instruction; using ValueRefT = Value *; using ConstValueRefT = const Value *; using UseT = Use; - using DominatorTreeT = DominatorTreeBase; - - static constexpr Value *ValueRefNull = nullptr; - - void setFunction(Function &Fn); - Function *getFunction() const { return F; } - - static BasicBlock *getEntryBlock(Function &F); - static const BasicBlock *getEntryBlock(const Function &F); - - static void appendBlockDefs(SmallVectorImpl &defs, - BasicBlock &block); - static void appendBlockDefs(SmallVectorImpl &defs, - const BasicBlock &block); - - static void appendBlockTerms(SmallVectorImpl &terms, - BasicBlock &block); - static void appendBlockTerms(SmallVectorImpl &terms, - const BasicBlock &block); - - static bool comesBefore(const Instruction *lhs, const Instruction *rhs); - static bool isConstantOrUndefValuePhi(const Instruction &Instr); - const BasicBlock *getDefBlock(const Value *value) const; - - Printable print(const BasicBlock *Block) const; - Printable print(const Instruction *Inst) const; - Printable print(const Value *Value) const; }; using SSAContext = GenericSSAContext; diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp --- a/llvm/lib/Analysis/UniformityAnalysis.cpp +++ b/llvm/lib/Analysis/UniformityAnalysis.cpp @@ -118,7 +118,7 @@ auto &DT = FAM.getResult(F); auto &TTI = FAM.getResult(F); auto &CI = FAM.getResult(F); - UniformityInfo UI{F, DT, CI, &TTI}; + UniformityInfo UI{DT, CI, &TTI}; // Skip computation if we can assume everything is uniform. if (TTI.hasBranchDivergence(&F)) UI.compute(); @@ -171,8 +171,7 @@ getAnalysis().getTTI(F); m_function = &F; - m_uniformityInfo = - UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo}; + m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo}; // Skip computation if we can assume everything is uniform. if (targetTransformInfo.hasBranchDivergence(m_function)) diff --git a/llvm/lib/CodeGen/MachineSSAContext.cpp b/llvm/lib/CodeGen/MachineSSAContext.cpp --- a/llvm/lib/CodeGen/MachineSSAContext.cpp +++ b/llvm/lib/CodeGen/MachineSSAContext.cpp @@ -21,15 +21,23 @@ using namespace llvm; -void MachineSSAContext::setFunction(MachineFunction &Fn) { - MF = &Fn; - RegInfo = &MF->getRegInfo(); +template <> +void MachineSSAContext::appendBlockDefs(SmallVectorImpl &defs, + const MachineBasicBlock &block) { + for (auto &instr : block.instrs()) { + for (auto &op : instr.all_defs()) + defs.push_back(op.getReg()); + } } -MachineBasicBlock *MachineSSAContext::getEntryBlock(MachineFunction &F) { - return &F.front(); +template <> +void MachineSSAContext::appendBlockTerms(SmallVectorImpl &terms, + MachineBasicBlock &block) { + for (auto &T : block.terminators()) + terms.push_back(&T); } +template <> void MachineSSAContext::appendBlockTerms( SmallVectorImpl &terms, const MachineBasicBlock &block) { @@ -37,37 +45,32 @@ terms.push_back(&T); } -void MachineSSAContext::appendBlockDefs(SmallVectorImpl &defs, - const MachineBasicBlock &block) { - for (const MachineInstr &instr : block.instrs()) { - for (const MachineOperand &op : instr.all_defs()) - defs.push_back(op.getReg()); - } -} - /// Get the defining block of a value. -MachineBasicBlock *MachineSSAContext::getDefBlock(Register value) const { +template <> +const MachineBasicBlock *MachineSSAContext::getDefBlock(Register value) const { if (!value) return nullptr; - return RegInfo->getVRegDef(value)->getParent(); + return F->getRegInfo().getVRegDef(value)->getParent(); } +template <> bool MachineSSAContext::isConstantOrUndefValuePhi(const MachineInstr &Phi) { return Phi.isConstantValuePHI(); } +template <> Printable MachineSSAContext::print(const MachineBasicBlock *Block) const { if (!Block) return Printable([](raw_ostream &Out) { Out << ""; }); return Printable([Block](raw_ostream &Out) { Block->printName(Out); }); } -Printable MachineSSAContext::print(const MachineInstr *I) const { +template <> Printable MachineSSAContext::print(const MachineInstr *I) const { return Printable([I](raw_ostream &Out) { I->print(Out); }); } -Printable MachineSSAContext::print(Register Value) const { - auto *MRI = RegInfo; +template <> Printable MachineSSAContext::print(Register Value) const { + auto *MRI = &F->getRegInfo(); return Printable([MRI, Value](raw_ostream &Out) { Out << printReg(Value, MRI->getTargetRegisterInfo(), 0, MRI); diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp --- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp +++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp @@ -157,7 +157,7 @@ MachineFunction &F, const MachineCycleInfo &cycleInfo, const MachineDomTree &domTree, bool HasBranchDivergence) { assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!"); - MachineUniformityInfo UI(F, domTree, cycleInfo); + MachineUniformityInfo UI(domTree, cycleInfo); if (HasBranchDivergence) UI.compute(); return UI; diff --git a/llvm/lib/IR/SSAContext.cpp b/llvm/lib/IR/SSAContext.cpp --- a/llvm/lib/IR/SSAContext.cpp +++ b/llvm/lib/IR/SSAContext.cpp @@ -19,31 +19,21 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/IR/ModuleSlotTracker.h" using namespace llvm; -void SSAContext::setFunction(Function &Fn) { F = &Fn; } - -BasicBlock *SSAContext::getEntryBlock(Function &F) { - return &F.getEntryBlock(); -} - -const BasicBlock *SSAContext::getEntryBlock(const Function &F) { - return &F.getEntryBlock(); -} - +template <> void SSAContext::appendBlockDefs(SmallVectorImpl &defs, BasicBlock &block) { - for (auto &instr : block.instructionsWithoutDebug(/*SkipPseudoOp=*/true)) { + for (auto &instr : block) { if (instr.isTerminator()) break; - if (instr.getType()->isVoidTy()) - continue; - auto *def = &instr; - defs.push_back(def); + defs.push_back(&instr); } } +template <> void SSAContext::appendBlockDefs(SmallVectorImpl &defs, const BasicBlock &block) { for (auto &instr : block) { @@ -53,41 +43,41 @@ } } +template <> void SSAContext::appendBlockTerms(SmallVectorImpl &terms, BasicBlock &block) { terms.push_back(block.getTerminator()); } +template <> void SSAContext::appendBlockTerms(SmallVectorImpl &terms, const BasicBlock &block) { terms.push_back(block.getTerminator()); } +template <> const BasicBlock *SSAContext::getDefBlock(const Value *value) const { if (const auto *instruction = dyn_cast(value)) return instruction->getParent(); return nullptr; } -bool SSAContext::comesBefore(const Instruction *lhs, const Instruction *rhs) { - return lhs->comesBefore(rhs); -} - +template <> bool SSAContext::isConstantOrUndefValuePhi(const Instruction &Instr) { if (auto *Phi = dyn_cast(&Instr)) return Phi->hasConstantOrUndefValue(); return false; } -Printable SSAContext::print(const Value *V) const { +template <> Printable SSAContext::print(const Value *V) const { return Printable([V](raw_ostream &Out) { V->print(Out); }); } -Printable SSAContext::print(const Instruction *Inst) const { +template <> Printable SSAContext::print(const Instruction *Inst) const { return print(cast(Inst)); } -Printable SSAContext::print(const BasicBlock *BB) const { +template <> Printable SSAContext::print(const BasicBlock *BB) const { if (!BB) return Printable([](raw_ostream &Out) { Out << ""; }); if (BB->hasName())