diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.h b/llvm/lib/Target/AMDGPU/AMDGPU.h --- a/llvm/lib/Target/AMDGPU/AMDGPU.h +++ b/llvm/lib/Target/AMDGPU/AMDGPU.h @@ -336,6 +336,14 @@ void initializeGCNNSAReassignPass(PassRegistry &); extern char &GCNNSAReassignID; +ModulePass *createAMDGPULowerLDSGlobalPass(); +void initializeAMDGPULowerLDSGlobalPass(PassRegistry &); +extern char &AMDGPULowerLDSGlobalID; +struct AMDGPULowerLDSGlobalPass : PassInfoMixin { + AMDGPULowerLDSGlobalPass() {} + PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); +}; + namespace AMDGPU { enum TargetIndex { TI_CONSTDATA_START, diff --git a/llvm/lib/Target/AMDGPU/AMDGPUAlwaysInlinePass.cpp b/llvm/lib/Target/AMDGPU/AMDGPUAlwaysInlinePass.cpp --- a/llvm/lib/Target/AMDGPU/AMDGPUAlwaysInlinePass.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUAlwaysInlinePass.cpp @@ -117,13 +117,15 @@ // should only appear when IPO passes manages to move LDs defined in a kernel // into a single user function. - for (GlobalVariable &GV : M.globals()) { - // TODO: Region address - unsigned AS = GV.getAddressSpace(); - if (AS != AMDGPUAS::LOCAL_ADDRESS && AS != AMDGPUAS::REGION_ADDRESS) - continue; - - recursivelyVisitUsers(GV, FuncsToAlwaysInline); + if (!AMDGPUTargetMachine::EnableLDSGlobalLowering) { + for (GlobalVariable &GV : M.globals()) { + // TODO: Region address + unsigned AS = GV.getAddressSpace(); + if (AS != AMDGPUAS::LOCAL_ADDRESS && AS != AMDGPUAS::REGION_ADDRESS) + continue; + + recursivelyVisitUsers(GV, FuncsToAlwaysInline); + } } if (!AMDGPUTargetMachine::EnableFunctionCalls || StressCalls) { diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerLDSGlobal.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerLDSGlobal.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Target/AMDGPU/AMDGPULowerLDSGlobal.cpp @@ -0,0 +1,1054 @@ +//===-- AMDGPULowerLDSGlobal.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TODO +// +//===----------------------------------------------------------------------===// + +#include "AMDGPU.h" +#include "Utils/AMDGPUBaseInfo.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/ValueMap.h" +#include "llvm/InitializePasses.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include +#include + +#define DEBUG_TYPE "amdgpu-lower-lds-global" + +using namespace llvm; + +// Helper function around `ValueMap` to detect if an element exists within it. +template +static bool contains(R &&VMap, const E &Element) { + return VMap.find(Element) != VMap.end(); +} + +// Helper function to get the required alignment for the input LDS variable +// `GV`. +static Align getAlign(GlobalVariable *GV) { + return GV->getAlign().getValueOr( + GV->getParent()->getDataLayout().getPreferredAlign(GV)); +} + +namespace { + +class LowerLDSGlobalImpl { + Module &M; + LLVMContext &Ctx; + const DataLayout &DL; + CallGraph CG; + Twine PrefixStr; + + // Holds kernels defined within the module `M`. + SmallPtrSet Kernels; + + // Holds LDS globals defined within the module `M`. + SmallPtrSet LDSGlobals; + + // Holds call graph nodes associated with the functions whose addresses are + // taken within the module. + SmallPtrSet AddressTakenSet; + + // Associates LDS global to a list of functions which references that LDS. + ValueMap> LDSGlobalToAccessors; + + // Associates function to a list of LDS globals which are referenced within + // that function. 
+ ValueMap> AccessorToLDSGlobals; + + // Associates kernel to a list of functions which are reachable from that + // kernel. + ValueMap> KernelToCallees; + + // Associates kernel to a list LDS globals which are referenced along the call + // graph paths associated with that kernel. + ValueMap> KernelToLDSGlobals; + + // Associates kernel to a newly created per kernel LDS layout. + ValueMap KernelToLDSLayout; + + // Associates kernel to a map which maps LDS globals to corresponding offsets. + ValueMap> + KernelToLDSToOffset; + + // Associates kernel to instruction insertion point within that kernel. + ValueMap KernelToInstInsertPt; + + // Holds pointers to LDS layouts. + GlobalVariable *LDSLayouts; + + // Holds 2D LDS offset table. + GlobalVariable *LDSOffsetTable; + + // Associates kernel to unique integer id. + std::map KernelToID; + + // Associates an integer id uniquely to kernel. + std::map LDSToID; + + // Associates LDS global to unique integer id. + std::map IDToKernel; + + // Associates an integer id uniquely to LDS global. + std::map IDToLDS; + +public: + explicit LowerLDSGlobalImpl(Module &M) + : M(M), Ctx(M.getContext()), DL(M.getDataLayout()), CG(CallGraph(M)), + PrefixStr(Twine("lds.lower.")) { + // Collect the functions whose address is taken within the module. + collectAddressTakenFunctions(); + } + + // Entry-point function. + bool lower(); + +private: + //===--------------------------------------------------------------------===// + // Methods which aid in creating new instructions which assess LDS layouts. + //===--------------------------------------------------------------------===// + + // Convert `ConstantExpr CE` to `Instruction`, and update users of `CE` to use + // `Instruction`. + Instruction *replaceConstExprByInst(ConstantExpr *CE); + + // Within instruction `I` replace the use(s) of `LDS` by `LayoutAccessInst`. 
+ void replaceInstWhichUsesLDS(Instruction *I, GlobalVariable *LDS, + Value *LayoutAccessInst, + SmallPtrSetImpl &ToBeErasedInsts); + + // Replace all the users of original LDS global `LDS` with new ones which + // access corresponding LDS layouts. + void replaceAllUsersOfLDS( + GlobalVariable *LDS, + ValueMap &FunctionToLDSLayoutAccessInst); + + // Add instructions to access LDS layouts corresponding to LDS global `LDS`. + bool accessLDSLayouts( + GlobalVariable *LDS, + ValueMap &FunctionToLDSLayoutAccessInst); + + // Add instructions within kernels to initialize LDS layouts array to hold the + // starting address of LDS layouts. + void initializeArrayOfLDSLayouts(); + + // Insert various required instructions within kernels and within callees + // which aid accessing of newly inserted LDS layouts. + void insertNewInstructions(); + + //===--------------------------------------------------------------------===// + // Methods which aid in creating new globals like kernel specific LDS layouts, + // 2D LDS offset table, etc. + //===--------------------------------------------------------------------===// + + // Create 2D LDS offset table which will be referenced at run time to access + // the LDS specific offset within kernel specific LDS layout. + void constructLDSOffsetTable(); + + // Create 1D array which holds LDS layouts which will be referenced at run + // time to access kernel specific LDS layout. + void constructArrayOfLDSLayouts(); + + // Create a new LDS layout corresponding to kernel `K` whose size will be of + // `LayoutSize`. + GlobalVariable *constructLDSLayout(Function *K, uint32_t LayoutSize); + + // Given a set of LDS globals, associate each LDS from the set to a + // corresponding offset within the kernel specific LDS layout. + uint32_t computeOffsets(SmallPtrSetImpl &LDSSet, + std::map &LDSToOffSet); + + // Create unified LDS layouts for each kernel. 
+ void constructLDSLayouts(); + + // Insert various required globals (including per kernel specific unified LDS + // layouts) into the module. + void insertNewGlobals(); + + //===--------------------------------------------------------------------===// + // Methods which aid in preparing for lowering. + //===--------------------------------------------------------------------===// + + // Collect all call graph paths from kernel `K` to callee `Callee`. + void collectCallGraphPaths(Function *K, Function *Callee, + SmallVectorImpl> &CGPaths); + + // TODO: What is the neat way to implement this functionality? + void propogateKernelId(); + + // Associate kernel and LDS global to unique ids and vice versa. + void associateKernelAndLDSGlobalToUniqueId(); + + // Associate kernels and LDS globals with unique integer id, and implement a + // mechanism to access the kernel ids within callees at run time. + void prepareForLowering(); + + //===--------------------------------------------------------------------===// + // Methods which aid in creating the various `map` data structures. + //===--------------------------------------------------------------------===// + + // Associate each kernel K with LDS globals which are being accessed by K + // and/or by the callees of K. + void createKernelToLDSGlobalsMap(); + + // Traverse through the call graph nodes associated with the callees of + // current caller, and push them into stack. + void pushCallGraphNodes(CallGraphNode *CGNode, + SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack); + + // The call site associated with `CGNode` is a "direct call site", and is + // associated with a single callee, say,`Callee` represented by `CGNode`. Add + // `Callee` to `CalleeSet`, and push callees of `Callee` to `CGNodeStack` to + // further explore DFS search. 
+ void collectCalleeAssociatedWithDirectCallSite( + CallGraphNode *CGNode, SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack, + SmallPtrSetImpl &CalleeSet); + + // The call site `CB` is an "indirect call site". Resolve `CB` to a set of + // potential callees. + void AssociateIndirectCallSiteWithPotentialCallees( + CallGraphNode *CGNode, CallBase *CB, + SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack, + SmallPtrSetImpl &CalleeSet); + + // Collect callee(s) associated with call site `CB`. If `CB` is a `direct` + // call site, then there is only one callee associated with it, collect it. If + // `CB` is an `indirect` call site, then potentially there could be more than + // one callee associated with it, collect all of them. + void collectCallees(CallGraphNode *CGNode, CallBase *CB, + SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack, + SmallPtrSetImpl &CalleeSet, + bool &CallSiteHandledIsIndirect); + + // Traverse `CallGraph` starting from the `CallGraphNode` associated with each + // kernel `K` in DFS manner and collect all the callees which are reachable + // from K (including indirectly called callees). + void createKernelToCalleesMap(); + + // Associate each kernel/function with the LDS globals which are being + // accessed within them. + void createAccessorToLDSGlobalsMap(); + + // For each `LDS`, recursively visit its user list and find all those + // kernels/functions within which the `LDS` is being accessed. + void createLDSGlobalToAccessorsMap(); + + // For each kernel `K`, collect LDS globals which are being accessed during + // the execution of `K`. + bool collectPerKernelAccessibleLDSGlobals(); + + //===--------------------------------------------------------------------===// + // Methods which aid in creating the various `set` data structures. + //===--------------------------------------------------------------------===// + + // Collect all the amdgpu kernels defined within the current module. 
+ bool collectKernels(); + + // Collect all the (static) LDS globals defined within the current module. + bool collectLDSGlobals(); + + // Collect functions whose address is taken within the module. + void collectAddressTakenFunctions(); +}; + +// Convert `ConstantExpr CE` to `Instruction`, and update users of `CE` to use +// `Instruction`. +Instruction *LowerLDSGlobalImpl::replaceConstExprByInst(ConstantExpr *CE) { + // FIXME: This looks like a hack to me, but, is there any better way of + // handling `ConstantExprs`? I have no idea at the moment, need to revisit it + // later. + for (auto *U : CE->users()) { + auto *I = dyn_cast(U); + + if (!I) + I = replaceConstExprByInst(dyn_cast(U)); + + if (I) { + auto *NI = CE->getAsInstruction(); + NI->insertBefore(I); + unsigned Ind = 0; + for (Use &UU : I->operands()) { + Value *V = UU.get(); + if (V == CE) + I->setOperand(Ind, NI); + ++Ind; + } + return NI; + } + } + + return nullptr; +} + +// Within instruction `I` replace the use(s) of `LDS` by `LayoutAccessInst`. +void LowerLDSGlobalImpl::replaceInstWhichUsesLDS( + Instruction *I, GlobalVariable *LDS, Value *LayoutAccessInst, + SmallPtrSetImpl &ToBeErasedInsts) { + // Create clone of `I`, say, it is `NewI`. Within `NewI`, replace the use(s) + // of `LDS` by `LayoutAccessor`. + // + // FIXME: Instruction cloning is not required, fix it. + Instruction *NewI = I->clone(); + unsigned Ind = 0; + for (Use &UU : NewI->operands()) { + Value *V = UU.get(); + if (V == LDS) + NewI->setOperand(Ind, LayoutAccessInst); + ++Ind; + } + + // Insert `NewI` just before `I`, replace all uses of `I` by `NewI` and mark + // `I` as `to be erased` instruction. + NewI->insertBefore(I); + NewI->copyMetadata(*I); + I->replaceAllUsesWith(NewI); + ToBeErasedInsts.insert(I); +} + +// Replace all the users of original LDS global `LDS` with new ones which access +// corresponding LDS layouts. 
+void LowerLDSGlobalImpl::replaceAllUsersOfLDS( + GlobalVariable *LDS, + ValueMap &FunctionToLDSLayoutAccessInst) { + // Keep track of all the erased to be instructions. + SmallPtrSet ToBeErasedInsts; + + // Traverse through each use `U` of `LDS`, create a new one which replaces `U` + // and accordingly replace `U`. + for (auto *U : LDS->users()) { + // `U` may be using `LDS`, but 'U` itself is not used anywhere, ignore `U`. + if (!U->getNumUses()) + continue; + + if (auto *I = dyn_cast(U)) { + // User is an Instruction, accordingly handle it. + auto *LayoutAccessInst = + FunctionToLDSLayoutAccessInst[I->getParent()->getParent()]; + replaceInstWhichUsesLDS(I, LDS, LayoutAccessInst, ToBeErasedInsts); + } else if (auto *CE = dyn_cast(U)) { + // User is a ConstantExpr, accordingly handle it. + auto *I = replaceConstExprByInst(CE); + auto *LayoutAccessInst = + FunctionToLDSLayoutAccessInst[I->getParent()->getParent()]; + replaceInstWhichUsesLDS(I, LDS, LayoutAccessInst, ToBeErasedInsts); + CE->removeDeadConstantUsers(); + } else + llvm_unreachable("Not Implemented."); // TODO: What else is missing? + } + + // Erase all the instructions which are got replaced by new ones. + for (auto *I : ToBeErasedInsts) + I->eraseFromParent(); +} + +// Add instructions to access LDS layouts corresponding to LDS global `LDS`. +bool LowerLDSGlobalImpl::accessLDSLayouts( + GlobalVariable *LDS, + ValueMap &FunctionToLDSLayoutAccessInst) { + // LDS global is not used anywhere? ignore it. + if (!contains(LDSGlobalToAccessors, LDS)) + return false; + + for (auto *F : LDSGlobalToAccessors[LDS]) { + // Get instruction insertion point. + Instruction *EI = AMDGPU::isModuleEntryFunctionCC(F->getCallingConv()) + ? KernelToInstInsertPt[F] + : &(*F->getEntryBlock().getFirstInsertionPt()); + + // Get indices to access LDS offset table. + // FIXME: explicit cast + Value *KI = AMDGPU::isModuleEntryFunctionCC(F->getCallingConv()) + ? 
(Value *)Constant::getIntegerValue( + Type::getInt32Ty(Ctx), APInt(32, KernelToID[F])) + : (Value *)F->getArg(F->arg_size() - 1); + Value *LI = Constant::getIntegerValue(Type::getInt32Ty(Ctx), + APInt(32, LDSToID[LDS])); + + // Insert GEP instruction to access the LDS offset table at the address + // ((LDSOffsetTable + KI) + LI) which is of type `i32*`. + Value *Indices1[] = {Constant::getNullValue(Type::getInt32Ty(Ctx)), KI, LI}; + Instruction *GEPI1 = GetElementPtrInst::CreateInBounds( + LDSOffsetTable->getValueType(), LDSOffsetTable, Indices1, + PrefixStr + Twine("gep."), EI); + + // Insert LOAD instruction to load the `offset` value from LDS offset table + // at the address ((LDSOffsetTable + KI) + LI) which is of type `i32`. + // FIXME: getPointerElementType() sould not be used. + Instruction *LoadI = new LoadInst(GEPI1->getType()->getPointerElementType(), + GEPI1, PrefixStr + Twine("load."), EI); + + // Insert GEP instruction to access the starting address of LDS layout + // corresponding to function `F` which is of type `i8*`. + Value *Indices2[] = {Constant::getNullValue(Type::getInt32Ty(Ctx)), KI}; + Instruction *GEPI2 = GetElementPtrInst::CreateInBounds( + LDSLayouts->getValueType(), LDSLayouts, Indices2, + PrefixStr + Twine("gep."), EI); + + // Insert GEP instruction to access the base address corresponding to LDS + // global `LDS` within LDS layout which is of type 'i8*`. + // FIXME: getPointerElementType() sould not be used. + Instruction *GEPI3 = GetElementPtrInst::CreateInBounds( + GEPI2->getType()->getPointerElementType(), GEPI2, LoadI, + PrefixStr + Twine("gep."), EI); + + // Insert BITCAST instruction to cast `above base address` from 'i8*` to + // `ldstype*` where `ldstype` is the type of original LDS global `LDS`. + Instruction *CastI = + new BitCastInst(GEPI3, LDS->getType(), PrefixStr + Twine("cast."), EI); + + // Save the above instruction `CastI` which replaces all uses of `LDS` + // within `F`. 
+ FunctionToLDSLayoutAccessInst[F] = CastI; + } + + return true; +} + +// Add instructions within kernels to initialize LDS layouts array to hold the +// starting address of LDS layouts. +void LowerLDSGlobalImpl::initializeArrayOfLDSLayouts() { + for (auto *K : Kernels) { + // Kernel `K` does not associate with LDS layout? skip it. + if (!contains(KernelToLDSLayout, K)) + continue; + + // Insert instructions at the beginning of the entry basic block of the + // kernel `K` which initialize the global array `LDSLayouts` to hold the + // starting address of the LDS layout associated with `K` at LDSLayouts[ID] + // where ID is kernel id of `K`. + Instruction *EI = &(*(K->getEntryBlock().getFirstInsertionPt())); + + // Insert GEP instruction which access the address `LDSLayout + 0`, say, the + // result is `GEP1` which is of type `i8*`. + GlobalVariable *LDSLayout = KernelToLDSLayout[K]; + Value *Indices1[] = { + Constant::getNullValue(Type::getInt32Ty(Ctx)), + Constant::getIntegerValue(Type::getInt32Ty(Ctx), APInt(32, 0))}; + Instruction *GEP1 = GetElementPtrInst::CreateInBounds( + LDSLayout->getValueType(), const_cast(LDSLayout), + Indices1, PrefixStr + Twine("gep."), EI); + + // Insert GEP instruction which access the address `LDSLayouts + ID`, say, + // the result is `GEP2` which is of type `i8**`. + Value *Indices2[] = {Constant::getNullValue(Type::getInt32Ty(Ctx)), + Constant::getIntegerValue(Type::getInt32Ty(Ctx), + APInt(32, KernelToID[K]))}; + Instruction *GEP2 = GetElementPtrInst::CreateInBounds( + LDSLayouts->getValueType(), const_cast(LDSLayouts), + Indices2, PrefixStr + Twine("gep."), EI); + + // Insert STORE instruction which stores `GEP1` at `GEP2`. + new StoreInst(GEP1, GEP2, EI); + + // Save the instruction insertion point which will be later required when + // it is necessary to insert LDS layout access instructions within kernel. 
+ KernelToInstInsertPt[K] = EI; + } +} + +// Insert various required instructions within kernels and within callees which +// aid accessing of newly inserted LDS layouts. +void LowerLDSGlobalImpl::insertNewInstructions() { + // Add instructions within kernels to initialize LDS layouts array to hold the + // starting address of LDS layouts. + initializeArrayOfLDSLayouts(); + + // Add instructions to access LDS layouts and accordingly replace original + // instructions which use original LDS globals. + for (auto *LDS : LDSGlobals) { + ValueMap FunctionToLDSLayoutAccessInst; + + // LDS global is not used anywhere? ignore it. + if (!accessLDSLayouts(LDS, FunctionToLDSLayoutAccessInst)) + continue; + + // Replace all the users of LDS global `LDS` with new ones which access + // corresponding LDS layouts. + replaceAllUsersOfLDS(LDS, FunctionToLDSLayoutAccessInst); + } +} + +// Create 2D LDS offset table which will be referenced at run time to access the +// LDS specific offset within kernel specific LDS layout. +void LowerLDSGlobalImpl::constructLDSOffsetTable() { + // Get type of 2D LDS offset table. + auto *EleTy = Type::getInt32Ty(Ctx); + auto *Arr1DTy = ArrayType::get(EleTy, LDSGlobals.size()); + auto *Arr2DTy = ArrayType::get(Arr1DTy, Kernels.size()); + + // Create offset initialization list. + SmallVector Init2DValues; + + for (unsigned K = 0; K < Kernels.size(); ++K) { + SmallVector Init1DValues; + auto *Kernel = IDToKernel[K]; + + for (unsigned L = 0; L < LDSGlobals.size(); ++L) { + auto *LDS = IDToLDS[L]; + auto Offset = contains(KernelToLDSToOffset, Kernel) + ? contains(KernelToLDSToOffset[Kernel], LDS) + ? 
KernelToLDSToOffset[Kernel][LDS] + : -1 + : -1; + auto *C = Constant::getIntegerValue(EleTy, APInt(32, Offset)); + Init1DValues.push_back(C); + } + + auto *Const1D = ConstantArray::get(Arr1DTy, Init1DValues); + Init2DValues.push_back(Const1D); + } + + auto *Const2D = ConstantArray::get(Arr2DTy, Init2DValues); + + // Create 2D LDS offset table which is initialized with offsets. + LDSOffsetTable = new GlobalVariable( + M, Arr2DTy, false, GlobalValue::InternalLinkage, Const2D, + Twine("__llvm.amdgcn.lds.offset.table__"), nullptr, + GlobalVariable::NotThreadLocal, AMDGPUAS::CONSTANT_ADDRESS); + + // Set proper alignment. + LDSOffsetTable->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); + LDSOffsetTable->setAlignment(getAlign(LDSOffsetTable)); +} + +// Create 1D array which holds LDS layouts which will be referenced at run time +// to access kernel specific LDS layout. +void LowerLDSGlobalImpl::constructArrayOfLDSLayouts() { + auto *EleTy = + PointerType::get(IntegerType::get(Ctx, 8), AMDGPUAS::LOCAL_ADDRESS); + auto *ArrTy = ArrayType::get(EleTy, Kernels.size()); + + LDSLayouts = new GlobalVariable( + M, ArrTy, false, GlobalValue::InternalLinkage, UndefValue::get(ArrTy), + Twine("__llvm.amdgcn.unified.lds.layouts__"), nullptr, + GlobalVariable::NotThreadLocal, AMDGPUAS::CONSTANT_ADDRESS); + + LDSLayouts->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); + LDSLayouts->setAlignment(getAlign(LDSLayouts)); +} + +// Create a new LDS layout corresponding to kernel `K` whose size will be of +// `LayoutSize`. 
+GlobalVariable *LowerLDSGlobalImpl::constructLDSLayout(Function *K, + uint32_t LayoutSize) { + auto LayoutName = Twine("__llvm.amdgcn.unified.lds.layout.") + + Twine(KernelToID[K]) + Twine("__"); + auto *LayoutType = ArrayType::get(IntegerType::get(Ctx, 8), LayoutSize); + + auto *LDSLayout = new GlobalVariable( + M, LayoutType, false, GlobalValue::InternalLinkage, + UndefValue::get(LayoutType), LayoutName, nullptr, + GlobalVariable::NotThreadLocal, AMDGPUAS::LOCAL_ADDRESS); + + LDSLayout->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); + LDSLayout->setAlignment(getAlign(LDSLayout)); + + return LDSLayout; +} + +// Given a set of LDS globals, associate each LDS from the set to a +// corresponding offset within the kernel specific LDS layout. +uint32_t LowerLDSGlobalImpl::computeOffsets( + SmallPtrSetImpl &LDSSet, + std::map &LDSToOffSet) { + // Sort `LDSSet` by alignment in descending order to minimize the padding + // required. On ties, sort by type allocation size, again in descending order, + // and finally, by name in lexicographical order (in ascending order) when + // type allocation sizes are same. + std::vector SortedLDSSet(LDSSet.begin(), LDSSet.end()); + llvm::stable_sort(SortedLDSSet, + [&](GlobalVariable *LHS, GlobalVariable *RHS) -> bool { + auto ALHS = getAlign(LHS); + auto ARHS = getAlign(RHS); + if (ALHS != ARHS) + return ALHS > ARHS; + + auto SLHS = DL.getTypeAllocSize(LHS->getValueType()); + auto SRHS = DL.getTypeAllocSize(RHS->getValueType()); + if (SLHS != SRHS) + return SLHS > SRHS; + + return LHS->getName() < RHS->getName(); + }); + + // Compute `offset` for each LDS global from `SortedLDSSet`. + // FIXME: Alignment is not yet handled properly. + uint32_t CurOffset = 0; + for (auto *LDS : SortedLDSSet) { + LDSToOffSet[LDS] = CurOffset; + CurOffset += DL.getTypeAllocSize(LDS->getValueType()).getFixedValue(); + } + + return CurOffset; +} + +// Create unified LDS layouts for each kernel. 
+void LowerLDSGlobalImpl::constructLDSLayouts() { + // Create a new LDS layout corresponding to each kernel `K`, and associate + // each kernel to a map which maps LDS globals to corresponding offsets. + for (auto *K : Kernels) { + // Kernel `K` does not associate with any LDS global? skip it. + if (!contains(KernelToLDSGlobals, K)) + continue; + + // For the set of LDS globals associated with kernel `K`, map each LDS + // from the set to a corresponding offset within the kernel specific LDS + // layout, and also get the grand total size of LDS layout. + std::map LDSToOffSet; + auto LayoutSize = computeOffsets(KernelToLDSGlobals[K], LDSToOffSet); + KernelToLDSToOffset[K] = LDSToOffSet; + + // Create a new LDS layout corresponding to kernel `K` whose size will be of + // `LayoutSize`. + KernelToLDSLayout[K] = constructLDSLayout(K, LayoutSize); + } +} + +// Insert various required globals (including per kernel specific unified LDS +// layouts) into the module. +void LowerLDSGlobalImpl::insertNewGlobals() { + // Create unified LDS layouts for each kernel. + constructLDSLayouts(); + + // Create 1D array which holds LDS layouts which will be referenced at run + // time to access kernel specific LDS layout. + constructArrayOfLDSLayouts(); + + // Create 2D LDS offset table which will be referenced at run time to access + // the LDS specific offset within kernel specific LDS layout. + constructLDSOffsetTable(); +} + +// TODO: What is the neat way to implement this functionality? +void LowerLDSGlobalImpl::propogateKernelId() { + // TODO +} + +// Associate kernel and LDS global to unique ids and vice versa. +void LowerLDSGlobalImpl::associateKernelAndLDSGlobalToUniqueId() { + // Associate each kernel to unique integer and vice versa. + uint32_t KNum = 0; + for (auto *K : Kernels) { + KernelToID[K] = KNum; + IDToKernel[KNum] = K; + ++KNum; + } + + // Associate each LDS to unique integer and vice versa. 
+ uint32_t LNum = 0; + for (auto *LDSGlobal : LDSGlobals) { + LDSToID[LDSGlobal] = LNum; + IDToLDS[LNum] = LDSGlobal; + ++LNum; + } +} + +// Associate kernels and LDS globals with unique integer id, and implement a +// mechanism to access the kernel ids within callees at run time. +void LowerLDSGlobalImpl::prepareForLowering() { + // Associate kernel and LDS global to unique ids and vice versa. + associateKernelAndLDSGlobalToUniqueId(); + + // TODO: What is the neat way to implement this functionality? + propogateKernelId(); +} + +// Associate each kernel K with LDS globals which are being accessed by K and/or +// by the callees of K. +void LowerLDSGlobalImpl::createKernelToLDSGlobalsMap() { + for (auto *K : Kernels) { + SmallPtrSet LDSSet; + + // Collect all those LDS globals which are being accessed by kernel K + // itself. + if (contains(AccessorToLDSGlobals, K)) + LDSSet.insert(AccessorToLDSGlobals[K].begin(), + AccessorToLDSGlobals[K].end()); + + // Collect all those LDS globals which are being accessed by the callees of + // kernel K. + for (auto *Callee : KernelToCallees[K]) { + if (contains(AccessorToLDSGlobals, Callee)) + LDSSet.insert(AccessorToLDSGlobals[Callee].begin(), + AccessorToLDSGlobals[Callee].end()); + } + + if (!LDSSet.empty()) + KernelToLDSGlobals[K] = LDSSet; + } +} + +// Traverse through the call graph nodes associated with the callees of current +// caller, and push them into stack. 
+void LowerLDSGlobalImpl::pushCallGraphNodes( + CallGraphNode *CGNode, SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack) { + assert(CGNode && + "Call graph node associated with kernel/function definition cannot be " + "null"); + for (auto GI = CGNode->begin(), GE = CGNode->end(); GI != GE; ++GI) { + auto *CGN = GI->second; + assert(CGN && + "Call graph node associated with kernel/function definition cannot " + "be null"); + auto *CB = cast(GI->first.getValue()); + CGNodeStack.push_back(CGN); + CallBaseStack.push_back(CB); + } +} + +// The call site associated with `CGNode` is a "direct call site", and is +// associated with a single callee, say,`Callee` represented by `CGNode`. Add +// `Callee` to `CalleeSet`, and push callees of `Callee` to `CGNodeStack` to +// further explore DFS search. +void LowerLDSGlobalImpl::collectCalleeAssociatedWithDirectCallSite( + CallGraphNode *CGNode, SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack, + SmallPtrSetImpl &CalleeSet) { + auto *Callee = CGNode->getFunction(); + assert(Callee && "Expected a valid callee associated with call site"); + if (!Callee->isDeclaration()) { + CalleeSet.insert(Callee); + pushCallGraphNodes(CGNode, CGNodeStack, CallBaseStack); + } +} + +// The call site `CB` is an "indirect call site". Resolve `CB` to a set of +// potential callees. +void LowerLDSGlobalImpl::AssociateIndirectCallSiteWithPotentialCallees( + CallGraphNode *CGNode, CallBase *CB, + SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack, + SmallPtrSetImpl &CalleeSet) { + if (auto *MD = CB->getMetadata(LLVMContext::MD_callees)) { + // The metadata "!callee" is available at the indirect call site `CB`, which + // means, all the potential target callees for the call site `CB` is + // successfully resolved at compile time. So, push them into stack so that + // they will be handled just like direct callees when they are eventually + // poped out. 
+ for (const auto &Op : MD->operands()) { + auto *CGN = CG[mdconst::extract_or_null(Op)]; + assert(CGN && + "Call graph node associated with kernel/function definition cannot" + " be null"); + assert(CGN->getFunction() && + "Expected a valid function which is included within !callee " + "metadata"); + CGNodeStack.push_back(CGN); + CallBaseStack.push_back(CB); + } + } else { + // The metadata "!callee" is *NOT* available at the indirect call site `CB`, + // which means, `CB` has *NO* information about potential target callees. + // The simplest possible *SAFE* assumption that we can make here is to + // consider all those "address taken" functions whose singature matches with + // that of the call site `CB`, and assume that all these signature matched + // "address taken" functions are possible potential callees. Thus, push all + // the signature matched "address taken" functions into stack so that they + // will be handled just like direct callees when they are eventually poped + // out. + auto *CBFTy = CB->getFunctionType(); + for (auto *CGN : AddressTakenSet) { + auto *F = CGN->getFunction(); + assert(F && "Expected a valid address taken function"); + auto *ADFTy = F->getFunctionType(); + if (ADFTy == CBFTy) { + CGNodeStack.push_back(CGN); + CallBaseStack.push_back(CB); + } + } + } +} + +// Collect callee(s) associated with call site `CB`. If `CB` is a `direct` call +// site, then there is only one callee associated with it, collect it. If `CB` +// is an `indirect` call site, then potentially there could be more than one +// callee associated with it, collect all of them. +void LowerLDSGlobalImpl::collectCallees( + CallGraphNode *CGNode, CallBase *CB, + SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack, + SmallPtrSetImpl &CalleeSet, bool &CallSiteHandledIsIndirect) { + if (!CB->getCalledFunction()) { + // Call site `CB` is an indirect call site. 
But, if the `CGNode` has a + // function defintion, say, `Callee`, associated with it, which means, we + // have already had visited `CB` earlier, and we had resovled it to a set of + // pontential callees, and `Callee` is one among them. Collect `Callee` just + // the way direct callee is collected. Otherwise, `CB` is encoutered for + // first time. Resolve it to a set of potential callees before collecting + // them. + if (!CGNode->getFunction()) { + // Indirect call site `CB` is encoutered first time, resolve it to a set + // of potential callees. + AssociateIndirectCallSiteWithPotentialCallees(CGNode, CB, CGNodeStack, + CallBaseStack, CalleeSet); + CallSiteHandledIsIndirect = true; + } else { + // Indirect call site `CB` is already been resolved to a set of potential + // callees during its first visit. The callee represented by `CGNode` is + // one among them. Collect it just the way direct callee is collected. + collectCalleeAssociatedWithDirectCallSite(CGNode, CGNodeStack, + CallBaseStack, CalleeSet); + } + } else { + // Call site `CB` is a direct call site. Collect a single callee which is + // associated with it. + collectCalleeAssociatedWithDirectCallSite(CGNode, CGNodeStack, + CallBaseStack, CalleeSet); + } +} + +// Traverse `CallGraph` starting from the `CallGraphNode` associated with each +// kernel `K` in DFS manner and collect all the callees which are reachable from +// K (including indirectly called callees). +// +// FIXME: Can we "really" get rid of this function, and is it that llvm +// "CallGraph" infra structure already provides this functionality, especially, +// by including indirect callees within SCC? 
+void LowerLDSGlobalImpl::createKernelToCalleesMap() {
+  for (auto *K : Kernels) {
+    auto *KernCGNode = CG[K];
+    // DFS worklists. The two stacks are always pushed and popped together:
+    // the call site at a given depth of `CallBaseStack` is the one through
+    // which the node at the same depth of `CGNodeStack` was reached.
+    SmallVector CGNodeStack;
+    SmallVector CallBaseStack;
+    SmallPtrSet Visited;
+    SmallPtrSet CalleeSet;
+
+    // Push the `CallGraphNode`s associated with all the callees of kernel `K`
+    // onto `CGNodeStack`, and the corresponding call sites onto
+    // `CallBaseStack`.
+    pushCallGraphNodes(KernCGNode, CGNodeStack, CallBaseStack);
+
+    // Continue the DFS until there are no more call graph nodes to handle.
+    while (!CGNodeStack.empty()) {
+      assert(CGNodeStack.size() == CallBaseStack.size() &&
+             "Stack holding CallBase pointers is currupted");
+      auto *CGNode = CGNodeStack.pop_back_val();
+      auto *CB = CallBaseStack.pop_back_val();
+
+      // `CGNode` has already been visited and handled; skip it and proceed to
+      // the next one.
+      if (!Visited.insert(CGNode).second)
+        continue;
+
+      // Collect the callee(s) associated with call site `CB`. A `direct` call
+      // site has exactly one callee, which is collected. An `indirect` call
+      // site may have more than one potential callee, and all of them are
+      // collected.
+      bool CallSiteHandledIsIndirect = false;
+      collectCallees(CGNode, CB, CGNodeStack, CallBaseStack, CalleeSet,
+                     CallSiteHandledIsIndirect);
+
+      // The call site `CB` just handled is an indirect call site. An indirect
+      // call site does not bind to any particular set of callees, and the same
+      // `CGNode` is shared by all function pointers with the same signature.
+      // Therefore, unlike for a direct call site, `CGNode` must *NOT* be
+      // treated as visited: another indirect call site reaching it later must
+      // be resolved again.
+      if (CallSiteHandledIsIndirect)
+        Visited.erase(CGNode);
+    }
+
+    assert(CallBaseStack.empty() &&
+           "Stack holding CallBase pointers is currupted");
+
+    KernelToCallees[K] = CalleeSet;
+  }
+}
+
+// Associate each kernel/function with the LDS globals which are accessed
+// within it.
+void LowerLDSGlobalImpl::createAccessorToLDSGlobalsMap() {
+  // Invert `LDSGlobalToAccessors`: for every (LDS, accessor) pair, record the
+  // LDS against the accessor. `operator[]` default-constructs an empty set the
+  // first time an accessor is seen, so no explicit existence check is needed.
+  for (const auto &Entry : LDSGlobalToAccessors) {
+    auto *LDS = Entry.first;
+    for (auto *Accessor : Entry.second)
+      AccessorToLDSGlobals[Accessor].insert(LDS);
+  }
+}
+
+// For each `LDS`, recursively visit its user list and find all those
+// kernels/functions within which the `LDS` is accessed.
+//
+// An LDS global may legitimately have no users at all (for example, a dead
+// definition that survives at -O0). Such a global is simply mapped to an empty
+// accessor set instead of being treated as an error.
+void LowerLDSGlobalImpl::createLDSGlobalToAccessorsMap() {
+  for (auto *LDS : LDSGlobals) {
+    auto &LDSAccessors = LDSGlobalToAccessors[LDS];
+    SmallVector<User *, 8> UserStack(LDS->users());
+    SmallPtrSet<User *, 8> Visited;
+
+    while (!UserStack.empty()) {
+      auto *U = UserStack.pop_back_val();
+      if (!Visited.insert(U).second)
+        continue;
+
+      // We have found a kernel/function within which the `LDS` is accessed.
+      if (auto *I = dyn_cast<Instruction>(U)) {
+        LDSAccessors.insert(I->getFunction());
+        continue;
+      }
+
+      // If `U` is not an `Instruction`, it should be a `Constant` which is
+      // (eventually) used by an `Instruction`. Push the users of `U` and keep
+      // exploring until an `Instruction` is found.
+      // NOTE(review): a GlobalVariable whose initializer refers to `LDS` is
+      // also a `Constant`, so its users are walked as if they accessed `LDS`
+      // directly; confirm this is the intended treatment.
+      assert(isa<Constant>(U) && "Expected a constant expression");
+      append_range(UserStack, U->users());
+    }
+  }
+}
+
+// For each kernel `K`, collect the LDS globals which are accessed during the
+// execution of `K`. Returns false if no kernel accesses any LDS global.
+bool LowerLDSGlobalImpl::collectPerKernelAccessibleLDSGlobals() {
+  // Associate each LDS with the kernels/functions within which the LDS is
+  // accessed.
+  createLDSGlobalToAccessorsMap();
+
+  // Associate each kernel/function with the LDS globals which are accessed
+  // within it.
+  createAccessorToLDSGlobalsMap();
+
+  // Associate each kernel K with the callees which are reachable from K
+  // (including indirectly called callees).
+  createKernelToCalleesMap();
+
+  // Associate each kernel K with the LDS globals which are accessed by K
+  // and/or by the callees of K.
+  createKernelToLDSGlobalsMap();
+
+  // If *none* of the kernels is associated with any LDS global, then there is
+  // nothing to do.
+  return !KernelToLDSGlobals.empty();
+}
+
+// Collect all the amdgpu kernels (module entry functions with a body) defined
+// within the current module. Returns false if there are none.
+bool LowerLDSGlobalImpl::collectKernels() {
+  for (auto &F : M.functions()) {
+    if (!F.isDeclaration() &&
+        AMDGPU::isModuleEntryFunctionCC(F.getCallingConv()))
+      Kernels.insert(&F);
+  }
+
+  return !Kernels.empty();
+}
+
+// Collect all the (static) LDS globals defined within the current module.
+// Declarations and zero-sized globals are skipped. Returns false if there are
+// none.
+bool LowerLDSGlobalImpl::collectLDSGlobals() {
+  for (auto &GV : M.globals()) {
+    // The emptiness check must look at the value type: `GV.getType()` is the
+    // global's pointer type, which is never empty.
+    if (GV.getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS &&
+        !GV.isDeclaration() && !GV.getValueType()->isEmptyTy())
+      LDSGlobals.insert(&GV);
+  }
+
+  return !LDSGlobals.empty();
+}
+
+// Collect the functions that may be the target of an indirect call within the
+// module. These are taken from the callees of the call graph's external calling
+// node. In LLVM's `CallGraph`, that node has an edge to every function that has
+// its address taken *or* has non-local linkage. The resulting set is therefore
+// a conservative superset of the strictly address-taken functions. Kernels and
+// declarations are excluded.
+void LowerLDSGlobalImpl::collectAddressTakenFunctions() {
+  auto *ExternalCallingNode = CG.getExternalCallingNode();
+  assert(ExternalCallingNode &&
+         "Call graph node associated with kernel/function definition cannot be "
+         "null");
+
+  for (const auto &CallRecord : *ExternalCallingNode) {
+    auto *CGN = CallRecord.second;
+    assert(CGN &&
+           "Call graph node associated with kernel/function definition cannot "
+           "be null");
+    auto *F = CGN->getFunction();
+    // FIXME: Anything else need to be excluded?
+    if (!F || F->isDeclaration() ||
+        AMDGPU::isModuleEntryFunctionCC(F->getCallingConv()))
+      continue;
+    AddressTakenSet.insert(CGN);
+  }
+}
+
+// Entry-point function.
+// Returns true iff the module was modified. It returns false, without touching
+// the module, when there are no LDS globals, no kernels, or no kernel that
+// accesses an LDS global.
+bool LowerLDSGlobalImpl::lower() {
+  // If there are *no* LDS globals defined within the module, or *no* kernels
+  // defined within the module, or *no* kernel *execution* which accesses LDS
+  // globals at run time, then there is nothing to do.
+  if (!collectLDSGlobals() || !collectKernels() ||
+      !collectPerKernelAccessibleLDSGlobals())
+    return false;
+
+  // Associate kernels and LDS globals with unique integer ids, and implement a
+  // mechanism to access the kernel ids within callees at run time.
+  prepareForLowering();
+
+  // Insert the various required globals (including per-kernel unified LDS
+  // layouts) into the module.
+  insertNewGlobals();
+
+  // Insert the various required instructions within kernels and callees that
+  // aid in accessing the newly inserted LDS layouts.
+  insertNewInstructions();
+
+  return true;
+}
+
+// Legacy pass manager wrapper around `LowerLDSGlobalImpl`.
+class AMDGPULowerLDSGlobal : public ModulePass {
+public:
+  static char ID;
+
+  AMDGPULowerLDSGlobal() : ModulePass(ID) {
+    initializeAMDGPULowerLDSGlobalPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnModule(Module &M) override;
+};
+
+} // namespace
+
+char AMDGPULowerLDSGlobal::ID = 0;
+char &llvm::AMDGPULowerLDSGlobalID = AMDGPULowerLDSGlobal::ID;
+
+INITIALIZE_PASS(AMDGPULowerLDSGlobal, "amdgpu-lower-lds-global",
+                "Lower LDS Global Variables", false /*only look at the cfg*/,
+                false /*analysis pass*/)
+
+bool AMDGPULowerLDSGlobal::runOnModule(Module &M) {
+  LowerLDSGlobalImpl LDSLowerer{M};
+  return LDSLowerer.lower();
+}
+
+ModulePass *llvm::createAMDGPULowerLDSGlobalPass() {
+  return new AMDGPULowerLDSGlobal();
+}
+
+// New pass manager entry point. The lowering adds globals and rewrites
+// instructions, so all analyses are claimed preserved only when `lower()`
+// reports that nothing was changed.
+PreservedAnalyses AMDGPULowerLDSGlobalPass::run(Module &M,
+                                                ModuleAnalysisManager &AM) {
+  LowerLDSGlobalImpl LDSLowerer{M};
+  if (!LDSLowerer.lower())
+    return PreservedAnalyses::all();
+  return PreservedAnalyses::none();
+}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h
+++
b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h @@ -35,6 +35,7 @@ static bool EnableLateStructurizeCFG; static bool EnableFunctionCalls; static bool EnableFixedFunctionABI; + static bool EnableLDSGlobalLowering; AMDGPUTargetMachine(const Target &T, const Triple &TT, StringRef CPU, StringRef FS, TargetOptions Options, diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp @@ -193,6 +193,12 @@ cl::desc("Enable workarounds for the StructurizeCFG pass"), cl::init(true), cl::Hidden); +static cl::opt EnableLDSGlobalLowering( + "amdgpu-enable-lds-global-lowering", + cl::desc("Enable LDS global variable lowering pass"), + cl::location(AMDGPUTargetMachine::EnableLDSGlobalLowering), cl::init(true), + cl::Hidden); + extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAMDGPUTarget() { // Register the target RegisterTargetMachine X(getTheAMDGPUTarget()); @@ -259,6 +265,7 @@ initializeGCNRegBankReassignPass(*PR); initializeGCNNSAReassignPass(*PR); initializeSIAddIMGInitPass(*PR); + initializeAMDGPULowerLDSGlobalPass(*PR); } static std::unique_ptr createTLOF(const Triple &TT) { @@ -388,6 +395,7 @@ bool AMDGPUTargetMachine::EnableLateStructurizeCFG = false; bool AMDGPUTargetMachine::EnableFunctionCalls = false; bool AMDGPUTargetMachine::EnableFixedFunctionABI = false; +bool AMDGPUTargetMachine::EnableLDSGlobalLowering = false; AMDGPUTargetMachine::~AMDGPUTargetMachine() = default; @@ -501,6 +509,10 @@ PM.addPass(AMDGPUAlwaysInlinePass()); return true; } + if (PassName == "amdgpu-lower-lds-global") { + PM.addPass(AMDGPULowerLDSGlobalPass()); + return true; + } return false; }); PB.registerPipelineParsingCallback( @@ -847,6 +859,12 @@ disablePass(&FuncletLayoutID); disablePass(&PatchableFunctionID); + // We expect to run this pass as a first AMDGPU IR pass so that new + // instructions being added in this pass can possibly 
undergo further + // transformations via subsequent passes. + if (EnableLDSGlobalLowering) + addPass(createAMDGPULowerLDSGlobalPass()); + addPass(createAMDGPUPrintfRuntimeBinding()); // This must occur before inlining, as the inliner will not look through diff --git a/llvm/lib/Target/AMDGPU/CMakeLists.txt b/llvm/lib/Target/AMDGPU/CMakeLists.txt --- a/llvm/lib/Target/AMDGPU/CMakeLists.txt +++ b/llvm/lib/Target/AMDGPU/CMakeLists.txt @@ -67,6 +67,7 @@ AMDGPULowerIntrinsics.cpp AMDGPULowerKernelArguments.cpp AMDGPULowerKernelAttributes.cpp + AMDGPULowerLDSGlobal.cpp AMDGPUMachineCFGStructurizer.cpp AMDGPUMachineFunction.cpp AMDGPUMachineModuleInfo.cpp