diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.h b/llvm/lib/Target/AMDGPU/AMDGPU.h --- a/llvm/lib/Target/AMDGPU/AMDGPU.h +++ b/llvm/lib/Target/AMDGPU/AMDGPU.h @@ -337,6 +337,15 @@ void initializeGCNNSAReassignPass(PassRegistry &); extern char &GCNNSAReassignID; +ModulePass *createAMDGPULowerFunctionLocalLDSPass(); +void initializeAMDGPULowerFunctionLocalLDSPass(PassRegistry &); +extern char &AMDGPULowerFunctionLocalLDSID; +struct AMDGPULowerFunctionLocalLDSPass + : PassInfoMixin { + AMDGPULowerFunctionLocalLDSPass() {} + PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); +}; + namespace AMDGPU { enum TargetIndex { TI_CONSTDATA_START, diff --git a/llvm/lib/Target/AMDGPU/AMDGPUAlwaysInlinePass.cpp b/llvm/lib/Target/AMDGPU/AMDGPUAlwaysInlinePass.cpp --- a/llvm/lib/Target/AMDGPU/AMDGPUAlwaysInlinePass.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUAlwaysInlinePass.cpp @@ -117,13 +117,15 @@ // should only appear when IPO passes manages to move LDs defined in a kernel // into a single user function. 
- for (GlobalVariable &GV : M.globals()) { - // TODO: Region address - unsigned AS = GV.getAddressSpace(); - if (AS != AMDGPUAS::LOCAL_ADDRESS && AS != AMDGPUAS::REGION_ADDRESS) - continue; - - recursivelyVisitUsers(GV, FuncsToAlwaysInline); + if (!AMDGPUTargetMachine::EnableFunctionLocalLDSLowering) { + for (GlobalVariable &GV : M.globals()) { + // TODO: Region address + unsigned AS = GV.getAddressSpace(); + if (AS != AMDGPUAS::LOCAL_ADDRESS && AS != AMDGPUAS::REGION_ADDRESS) + continue; + + recursivelyVisitUsers(GV, FuncsToAlwaysInline); + } } if (!AMDGPUTargetMachine::EnableFunctionCalls || StressCalls) { diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerFunctionLocalLDS.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerFunctionLocalLDS.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Target/AMDGPU/AMDGPULowerFunctionLocalLDS.cpp @@ -0,0 +1,592 @@ +//===-- AMDGPULowerFunctionLocalLDS.cpp -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +//===----------------------------------------------------------------------===// + +#include "AMDGPU.h" +#include "Utils/AMDGPUBaseInfo.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/ValueMap.h" +#include "llvm/InitializePasses.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include +#include + +#define DEBUG_TYPE "amdgpu-lower-function-local-lds" + +using namespace llvm; + +static bool isKernel(Function *F) { + if (AMDGPU::isModuleEntryFunctionCC(F->getCallingConv())) + return true; + + return false; +} + +// Asserts that call graph node associated with a kernel or a function defintion +// cannot be null. +static void assertCallGraphNodePtrIsNonNull(CallGraphNode *CGN) { + assert(CGN && "Call graph node associated with kernel/function definition " + "cannot be null.\n"); +} + +// Traverse through the call graph nodes associated with the callees of current +// caller, and push them into stack. +static void pushCallGraphNodes(CallGraphNode *CGNode, + SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack) { + assertCallGraphNodePtrIsNonNull(CGNode); + for (auto GI = CGNode->begin(), GE = CGNode->end(); GI != GE; ++GI) { + auto *CGN = GI->second; + auto *CB = dyn_cast(GI->first.getValue()); + assertCallGraphNodePtrIsNonNull(CGN); + assert(CB && "Call base associated with a call site within call graph " + "cannot be null\n"); + CGNodeStack.push_back(CGN); + CallBaseStack.push_back(CB); + } +} + +namespace { + +class LowerFunctionLocalLDSImpl { + // Object constructors and entry-point functions. 
+public: + // Constructs a LowerFunctionLocalLDSImpl object for the given Module. + explicit LowerFunctionLocalLDSImpl(Module &M) : M(M), CG(CallGraph(M)) {} + + // Entry-point function. + bool lower(); + + // Internal data structures which are being constructed while constructing + // the class object itself. +private: + Module &M; + CallGraph CG; + + // Internal data structures which are being built-up by the member function - + // `initialize()` after the class object is constructed. +private: + SmallPtrSet Kernels; + SmallPtrSet LDSGlobals; + SmallPtrSet AddressTakenSet; + ValueMap LDSToFunction; + ValueMap> FunctionToLDS; + ValueMap> KernelToCallee; + ValueMap> KernelToLDS; + + // Helper private member functions. +private: + // Collect functions whose address is taken by excluding kernels and AMDGPU + // specific library functions. + void collectAddressTakenFunctions() { + auto *ExternalCallingNode = CG.getExternalCallingNode(); + assertCallGraphNodePtrIsNonNull(ExternalCallingNode); + + for (auto GI = ExternalCallingNode->begin(), + GE = ExternalCallingNode->end(); + GI != GE; ++GI) { + auto *CGN = GI->second; + assertCallGraphNodePtrIsNonNull(CGN); + auto *F = CGN->getFunction(); + if (!F || F->getName().startswith("llvm.amdgcn.") || isKernel(F)) + continue; + AddressTakenSet.insert(CGN); + } + } + + // Filter out unhanlded kernels. Unhandled kernels are those which do not have + // any function local LDS to be lowered w.r.t them. + void filterOutUnhandledKernels() { + SmallPtrSet ToBeRemovedKernels; + for (auto *K : Kernels) { + if (KernelToLDS.find(K) == KernelToLDS.end()) + ToBeRemovedKernels.insert(K); + } + + for (auto *K : ToBeRemovedKernels) + Kernels.erase(K); + } + + // Filter out all LDS which are not handled, for example, all those which are + // defined within kernels. 
+ void filterOutUnhandledLDS() { + SmallPtrSet ToBeRemovedLDSList; + for (auto *LDS : LDSGlobals) + if (LDSToFunction.find(LDS) == LDSToFunction.end()) + ToBeRemovedLDSList.insert(LDS); + for (auto *LDS : ToBeRemovedLDSList) + LDSGlobals.erase(LDS); + } + + // Remove all function entries within `FunctionToLDS` map which are unused, + // that is, all those functions which are not called from any of the kernels. + void filterFunctionToLDSMap(std::set &ActiveEndCallees) { + std::set ToBeRemovedFunctions; + for (auto FI = FunctionToLDS.begin(), FE = FunctionToLDS.end(); FI != FE; + ++FI) + if (ActiveEndCallees.find(FI->first) == ActiveEndCallees.end()) + ToBeRemovedFunctions.insert(FI->first); + for (auto *F : ToBeRemovedFunctions) + FunctionToLDS.erase(F); + } + + // Remove all function entries within `LDSToFunction` map which are unused, + // that is, all those functions which are not called from any of the kernels. + void filterLDSToFunctionMap(std::set &ActiveEndCallees) { + std::set ToBeRemovedLDSGlobals; + for (auto LI = LDSToFunction.begin(), LE = LDSToFunction.end(); LI != LE; + ++LI) + if (ActiveEndCallees.find(LI->second) == ActiveEndCallees.end()) + ToBeRemovedLDSGlobals.insert(LI->first); + for (auto *LDS : ToBeRemovedLDSGlobals) + LDSToFunction.erase(LDS); + } + + // Collect together all the callees (those which define LDS within them) + // which are associated with all the kernels. + void collectTogetherAllEndCallees(std::set &ActiveEndCallees) { + for (auto KI = KernelToCallee.begin(), KE = KernelToCallee.end(); KI != KE; + ++KI) { + auto &CalleeSet = KI->second; + for (auto *Callee : CalleeSet) + ActiveEndCallees.insert(Callee); + } + } + + // Private member functions which build different data structures. +private: + // Associate current kernel K with LDS set which are supposed to be lowered + // w.r.t K. + void pairUpKernelWithLDSList(Function *K); + + // Associate each kernel K with LDS set which are supposed to be lowered w.r.t + // K. 
+ void pairUpKernelWithLDSList() { + for (auto *K : Kernels) + pairUpKernelWithLDSList(K); + } + + // There might exist functions with LDS defined within them, but without a + // call graph path from any of the kernels. Filter out such functions and + // associated LDS. + void filterOutUnusedFunctions(); + + // The call site associated with `CGNode` is a "direct call site", and the + // information about the corresponding callee, say, `Callee` is available. + // Check if `Callee` defines LDS variables within it, if so, add it to + // `CalleeSet`, and push callees of `Callee` to `CGNodeStack` to continue the + // DFS search. + void handleDirectCallSite(CallGraphNode *CGNode, + SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack, + std::set &CalleeSet); + + // The call site `CB` associated with the call graph node `CGNode` is an + // "indirect call site". Depending on whether the metadata `!callee` is + // available at `CB` or not, we need to handle it accordingly. + void handleIndirectCallSite(CallGraphNode *CGNode, CallBase *CB, + SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack, + std::set &CalleeSet); + + // Handle call site `CB` depending on whether it is a direct or an indirect + // call site, return true if an indirect call site is being handled. + bool handleCallSite(CallGraphNode *CGNode, CallBase *CB, + SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack, + std::set &CalleeSet); + + // Traverse `CallGraph` starting from the `CallGraphNode` which is associated + // with the kernel `K` in DFS manner and collect all the callees which define + // LDS variable(s). + void pairUpKernelWithCalleeList(Function *K); + + // Associate kernels with callees which define LDS and there exist call graph + // paths from kernels to these callees. 
+ void pairUpKernelWithCalleeList() { + for (auto *K : Kernels) + pairUpKernelWithCalleeList(K); + } + + // Create a reverse map from function to LDS set which maps a given function F + // to a set of LDS which are defined within F. + void createFunctionToLDSMap(); + + // Recursively visit user list of LDS and find the function within which the + // `LDS` is defined, and this function should always be successfully found. + void pairUpLDSWithItsAssociatedFunction(GlobalVariable *LDS); + + // Pair up each LDS with the function within which the LDS is defined. + void pairUpLDSWithItsAssociatedFunction() { + for (auto *LDS : LDSGlobals) + pairUpLDSWithItsAssociatedFunction(LDS); + } + + // Collect all the amdgpu kernels defined within the current module. + bool collectKernels(); + + // Collect all the (static) LDS globals defined within the current module. + bool collectLDSGlobals(); + + // Build necessary data structures which will later aid the lowering process. + bool initialize(); +}; + +// Associate current kernel K with LDS set which are supposed to be lowered +// w.r.t K. +void LowerFunctionLocalLDSImpl::pairUpKernelWithLDSList(Function *K) { + std::set LDSSet; + auto Callees = KernelToCallee[K]; + + for (auto *Callee : Callees) { + if (FunctionToLDS.find(Callee) == FunctionToLDS.end()) + continue; + std::set CalleeLDSSet = FunctionToLDS[Callee]; + for (auto *CalleeLDS : CalleeLDSSet) + LDSSet.insert(CalleeLDS); + } + + if (!LDSSet.empty()) + KernelToLDS[K] = LDSSet; +} + +// There might exist functions with LDS defined within them, but without a call +// graph path from any of the kernels. Filter out such functions and associated +// LDS. +void LowerFunctionLocalLDSImpl::filterOutUnusedFunctions() { + // Collect together all the callees (those which define LDS within them) which + // are associated with all the kernels. 
+ std::set ActiveEndCallees; + collectTogetherAllEndCallees(ActiveEndCallees); + + // Remove all function entries within `FunctionToLDS` map which are unused, + // that is, not called from any of the kernels. + filterFunctionToLDSMap(ActiveEndCallees); + + // Remove all function entries within `LDSToFunction` map which are unused, + // that is all those functions which are not called from any of the kernels. + filterLDSToFunctionMap(ActiveEndCallees); + + // Finally, further filter `LDSGlobals` data structure by removing all those + // LDS which have become inactive because of above filtering process. + filterOutUnhandledLDS(); +} + +// The call site associated with `CGNode` is a "direct call site", and the +// information about the corresponding callee, say, `Callee` is available. +// Check if `Callee` defines LDS variables within it, if so, add it to +// `CalleeSet`, and push callees of `Callee` to `CGNodeStack` to continue the +// DFS search. +void LowerFunctionLocalLDSImpl::handleDirectCallSite( + CallGraphNode *CGNode, SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack, + std::set &CalleeSet) { + auto *Callee = CGNode->getFunction(); + assert(Callee && "Exptected a valid callee associated with call site.\n"); + if (Callee->isDeclaration()) + return; + if (FunctionToLDS.find(Callee) != FunctionToLDS.end()) + CalleeSet.insert(Callee); + pushCallGraphNodes(CGNode, CGNodeStack, CallBaseStack); +} + +// The call site `CB` associated with the call graph node `CGNode` is an +// "indirect call site". Depending on whether the metadata `!callee` is +// available at `CB` or not, we need to handle it accordingly. 
+void LowerFunctionLocalLDSImpl::handleIndirectCallSite(
+    CallGraphNode *CGNode, CallBase *CB,
+    SmallVectorImpl<CallGraphNode *> &CGNodeStack,
+    SmallVectorImpl<CallBase *> &CallBaseStack,
+    std::set<Function *> &CalleeSet) {
+  // `CGNode` and `CalleeSet` are not read here; candidate callees are only
+  // queued, paired with `CB`, and classified once they are popped.
+  if (auto *MD = CB->getMetadata(LLVMContext::MD_callees)) {
+    // `!callees` metadata is attached to `CB`, so its operands name every
+    // potential target. Push them so that they are handled just like direct
+    // callees when they are eventually popped.
+    for (const auto &Op : MD->operands()) {
+      // A null operand names no function, so there is no call graph node to
+      // look up for it; skip it instead of indexing `CG` with null.
+      auto *Callee = mdconst::extract_or_null<Function>(Op);
+      if (!Callee)
+        continue;
+      auto *CGN = CG[Callee];
+      assertCallGraphNodePtrIsNonNull(CGN);
+      assert(CGN->getFunction() && "Expected the definition of the function "
+                                   "included within !callee metadata.\n");
+      CGNodeStack.push_back(CGN);
+      CallBaseStack.push_back(CB);
+    }
+  } else {
+    // No `!callees` metadata is attached to `CB`, so nothing is known about
+    // its targets. The simplest *SAFE* assumption is that every "address
+    // taken" function whose signature matches that of `CB` is a potential
+    // callee; push all of them so that they are handled just like direct
+    // callees when they are eventually popped.
+    auto *CBFTy = CB->getFunctionType();
+    for (auto *CGN : AddressTakenSet) {
+      auto *F = CGN->getFunction();
+      assert(F && "Expected the definition of address taken function.\n");
+      if (F->getFunctionType() == CBFTy) {
+        CGNodeStack.push_back(CGN);
+        CallBaseStack.push_back(CB);
+      }
+    }
+  }
+}
+
+// Handle the call site `CB` depending on whether it is a direct or an indirect
+// call site, return true if an indirect call site is being handled.
+bool LowerFunctionLocalLDSImpl::handleCallSite( + CallGraphNode *CGNode, CallBase *CB, + SmallVectorImpl &CGNodeStack, + SmallVectorImpl &CallBaseStack, + std::set &CalleeSet) { + bool IndirectCallSite = false; + if (CGNode->getFunction()) { + handleDirectCallSite(CGNode, CGNodeStack, CallBaseStack, CalleeSet); + } else { + // FIXME: Is it guaranteed that it represents indirect call sites? looks + // like not. To fix it as soon as possible. + handleIndirectCallSite(CGNode, CB, CGNodeStack, CallBaseStack, CalleeSet); + IndirectCallSite = true; + } + return IndirectCallSite; +} + +// Traverse `CallGraph` starting from the `CallGraphNode` which is associated +// with the kernel `K` in DFS manner and collect all the callees which define +// LDS variable(s). +void LowerFunctionLocalLDSImpl::pairUpKernelWithCalleeList(Function *K) { + auto *KernCGNode = CG[K]; + SmallVector CGNodeStack; + SmallVector CallBaseStack; + SmallPtrSet Visited; + std::set CalleeSet; + + // Push the `CallGraphNode` associated with all the callees of the kernel`K` + // into into `CGNodeStack`, and the corresponding call sites into + // `CallBaseStack`. + pushCallGraphNodes(KernCGNode, CGNodeStack, CallBaseStack); + + // Continue DFS search until no more call graph nodes to handle. + while (!CGNodeStack.empty()) { + assert(CGNodeStack.size() == CallBaseStack.size() && + "Stack holding CallBase pointers is currupted.\n"); + auto *CGNode = CGNodeStack.pop_back_val(); + auto *CB = CallBaseStack.pop_back_val(); + if (!Visited.insert(CGNode).second) + continue; + + if (handleCallSite(CGNode, CB, CGNodeStack, CallBaseStack, CalleeSet)) { + // The call site `CB` being handled is an inditect call site. The indirect + // call site does not bind to any particular function, and the + // `CallGraphNode` is same for all the function pointers which have same + // signature. Hence, we should *NOT* assume that `CGNode` is visited + // unlike in case of direct call site. 
+ Visited.erase(CGNode); + } + } + + assert(CallBaseStack.empty() && "Stack holding CallBase pointers is " + "currupted.\n"); + + KernelToCallee[K] = CalleeSet; +} + +// Create a reverse map from function to LDS set which maps a given function F +// to a set of LDS which are defined within F. +void LowerFunctionLocalLDSImpl::createFunctionToLDSMap() { + for (auto LI = LDSToFunction.begin(), LE = LDSToFunction.end(); LI != LE; + ++LI) { + auto *LDS = LI->first; + auto *F = LI->second; + auto FI = FunctionToLDS.find(F); + if (FI == FunctionToLDS.end()) { + std::set LDSSet; + LDSSet.insert(LDS); + FunctionToLDS[F] = LDSSet; + } else + FunctionToLDS[F].insert(LDS); + } +} + +// Recursively visit user list of LDS and find the function within which the +// `LDS` is defined, and this function should always be successfully found. +void LowerFunctionLocalLDSImpl::pairUpLDSWithItsAssociatedFunction( + GlobalVariable *LDS) { + assert(!LDS->user_empty() && + "LDS user list cannot be empty since it must have been successfully " + "defined within either kernel or function"); + + SmallVector UserStack; + SmallPtrSet Visited; + + for (auto *U : LDS->users()) + UserStack.push_back(U); + + while (!UserStack.empty()) { + auto *U = UserStack.pop_back_val(); + if (!Visited.insert(U).second) + continue; + + auto *I = dyn_cast(U); + + // `U` is either a `ConstantExpr` or a `Const` which is nested within an + // `Instruction`. Push-back users of `U`, and continue further exploring + // the stack until an `Instruction` is found. + if (!I) { + for (auto *UU : U->users()) + UserStack.push_back(UU); + continue; + } + + // We are only interested in LDS defined within function. Hence a new + // entry within `LDSToFunction` map will be created only if `F` is a + // function. 
+ auto *F = I->getParent()->getParent(); + if (!isKernel(F)) + LDSToFunction[LDS] = F; + + return; + } + + llvm_unreachable("Control is not expected to reach this point"); +} + +// Collect all the amdgpu kernels defined within the current module. +bool LowerFunctionLocalLDSImpl::collectKernels() { + for (auto &F : M.functions()) { + if (isKernel(&F) && !F.isDeclaration()) + Kernels.insert(&F); + } + + if (Kernels.empty()) + return false; + + return true; +} + +// Collect all the (static) LDS globals defined within the current module. +bool LowerFunctionLocalLDSImpl::collectLDSGlobals() { + for (auto &GV : M.globals()) { + if (GV.getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS && + !GV.hasExternalLinkage()) + LDSGlobals.insert(&GV); + } + + if (LDSGlobals.empty()) + return false; + + return true; +} + +// Build necessary data structures which will later aid the lowering process. +bool LowerFunctionLocalLDSImpl::initialize() { + // No LDS globals defined within the module? then, nothing to do. + if (!collectLDSGlobals()) + return false; + + // No kernels defined within the module? then, nothing to do. + if (!collectKernels()) + return false; + + // Collect the functions whose address is taken by excluding kernels and + // AMDGPU specific library functions. + collectAddressTakenFunctions(); + + // Associate each LDS with the function within which the LDS is defined. + pairUpLDSWithItsAssociatedFunction(); + + // Filter out all LDS which are not handled, for example, all those which are + // defined within kernels. + filterOutUnhandledLDS(); + + // Create a reverse map from function to LDS set which maps a given function + // F to a set of LDS which are defined within F. + createFunctionToLDSMap(); + + // Associate kernels with callees which define LDS and there exist call graph + // paths from kernels to these callees. 
+ pairUpKernelWithCalleeList(); + + // There might exist functions with LDS defined within them, but without a + // call graph path from any of the kernels. Filter out such functions and + // associated LDS. + filterOutUnusedFunctions(); + + // Associate each kernel K with LDS set which are supposed to be lowered w.r.t + // K. + pairUpKernelWithLDSList(); + + // Filter out unhanlded kernels. Unhandled kernels are those which do not have + // any function local LDS to be lowered w.r.t them. + filterOutUnhandledKernels(); + + // After all the above filtering, if we left with no kernel to handle, then, + // othing to do. + if (Kernels.empty()) + return false; + + return true; +} + +// Entry-point function. +bool LowerFunctionLocalLDSImpl::lower() { + // Build necessary data structures which will be later used in the lowering + // process. + if (!initialize()) + return false; + + return false; +} + +class AMDGPULowerFunctionLocalLDS : public ModulePass { +public: + static char ID; + + AMDGPULowerFunctionLocalLDS() : ModulePass(ID) { + initializeAMDGPULowerFunctionLocalLDSPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override; +}; + +} // namespace + +char AMDGPULowerFunctionLocalLDS::ID = 0; +char &llvm::AMDGPULowerFunctionLocalLDSID = AMDGPULowerFunctionLocalLDS::ID; + +INITIALIZE_PASS(AMDGPULowerFunctionLocalLDS, "amdgpu-lower-function-local-lds", + "Lower LDS Defined Within AMDGPU Non-kernel Device Function", + false /*only look at the cfg*/, false /*analysis pass*/) + +bool AMDGPULowerFunctionLocalLDS::runOnModule(Module &M) { + LowerFunctionLocalLDSImpl LDSLowerer{M}; + return LDSLowerer.lower(); +} + +ModulePass *llvm::createAMDGPULowerFunctionLocalLDSPass() { + return new AMDGPULowerFunctionLocalLDS(); +} + +PreservedAnalyses +AMDGPULowerFunctionLocalLDSPass::run(Module &M, ModuleAnalysisManager &AM) { + LowerFunctionLocalLDSImpl LDSLowerer{M}; + LDSLowerer.lower(); + return PreservedAnalyses::all(); +} diff --git 
a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h @@ -35,6 +35,7 @@ static bool EnableLateStructurizeCFG; static bool EnableFunctionCalls; static bool EnableFixedFunctionABI; + static bool EnableFunctionLocalLDSLowering; AMDGPUTargetMachine(const Target &T, const Triple &TT, StringRef CPU, StringRef FS, TargetOptions Options, diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp --- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp @@ -193,6 +193,13 @@ cl::desc("Enable workarounds for the StructurizeCFG pass"), cl::init(true), cl::Hidden); +static cl::opt EnableFunctionLocalLDSLowering( + "amdgpu-enable-function-local-lds-lowering", + cl::desc( + "Enable lowering of LDS defined within non-kernel device function"), + cl::location(AMDGPUTargetMachine::EnableFunctionLocalLDSLowering), + cl::init(true), cl::Hidden); + extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAMDGPUTarget() { // Register the target RegisterTargetMachine X(getTheAMDGPUTarget()); @@ -259,6 +266,7 @@ initializeGCNRegBankReassignPass(*PR); initializeGCNNSAReassignPass(*PR); initializeSIAddIMGInitPass(*PR); + initializeAMDGPULowerFunctionLocalLDSPass(*PR); } static std::unique_ptr createTLOF(const Triple &TT) { @@ -388,6 +396,7 @@ bool AMDGPUTargetMachine::EnableLateStructurizeCFG = false; bool AMDGPUTargetMachine::EnableFunctionCalls = false; bool AMDGPUTargetMachine::EnableFixedFunctionABI = false; +bool AMDGPUTargetMachine::EnableFunctionLocalLDSLowering = false; AMDGPUTargetMachine::~AMDGPUTargetMachine() = default; @@ -501,6 +510,10 @@ PM.addPass(AMDGPUAlwaysInlinePass()); return true; } + if (PassName == "amdgpu-lower-function-local-lds") { + PM.addPass(AMDGPULowerFunctionLocalLDSPass()); + return true; + } return false; }); 
PB.registerPipelineParsingCallback( @@ -847,6 +860,12 @@ disablePass(&FuncletLayoutID); disablePass(&PatchableFunctionID); + // We expect to run this pass as a first AMDGPU IR pass so that new + // instructions being added in this pass can possibly undergo further + // transformations via subsequent passes. + if (EnableFunctionLocalLDSLowering) + addPass(createAMDGPULowerFunctionLocalLDSPass()); + addPass(createAMDGPUPrintfRuntimeBinding()); // This must occur before inlining, as the inliner will not look through diff --git a/llvm/lib/Target/AMDGPU/CMakeLists.txt b/llvm/lib/Target/AMDGPU/CMakeLists.txt --- a/llvm/lib/Target/AMDGPU/CMakeLists.txt +++ b/llvm/lib/Target/AMDGPU/CMakeLists.txt @@ -50,6 +50,7 @@ AMDGPUAtomicOptimizer.cpp AMDGPUCallLowering.cpp AMDGPUCodeGenPrepare.cpp + AMDGPULowerFunctionLocalLDS.cpp AMDGPUExportClustering.cpp AMDGPUFixFunctionBitcasts.cpp AMDGPUFrameLowering.cpp