Index: lib/CodeGen/CGDecl.cpp
===================================================================
--- lib/CodeGen/CGDecl.cpp
+++ lib/CodeGen/CGDecl.cpp
@@ -1016,9 +1016,17 @@
   }
 
   // A normal fixed sized variable becomes an alloca in the entry block,
-  // unless it's an NRVO variable.
-
-  if (NRVO) {
+  // unless:
+  // - it's an NRVO variable.
+  // - we are compiling OpenMP and it's an OpenMP local variable.
+
+  Address OpenMPLocalAddr =
+      getLangOpts().OpenMP
+          ? CGM.getOpenMPRuntime().getAddressOfLocalVariable(*this, &D)
+          : Address::invalid();
+  if (getLangOpts().OpenMP && OpenMPLocalAddr.isValid()) {
+    address = OpenMPLocalAddr;
+  } else if (NRVO) {
     // The named return value optimization: allocate this variable in the
     // return slot, so that we can elide the copy when returning this
     // variable (C++0x [class.copy]p34).
@@ -1820,9 +1828,18 @@
       pushDestroy(QualType::DK_cxx_destructor, DeclPtr, Ty);
     }
   } else {
-    // Otherwise, create a temporary to hold the value.
-    DeclPtr = CreateMemTemp(Ty, getContext().getDeclAlign(&D),
-                            D.getName() + ".addr");
+    // Check if the parameter address is controlled by OpenMP runtime.
+    Address OpenMPLocalAddr =
+        getLangOpts().OpenMP
+            ? CGM.getOpenMPRuntime().getAddressOfLocalVariable(*this, &D)
+            : Address::invalid();
+    if (getLangOpts().OpenMP && OpenMPLocalAddr.isValid()) {
+      DeclPtr = OpenMPLocalAddr;
+    } else {
+      // Otherwise, create a temporary to hold the value.
+      DeclPtr = CreateMemTemp(Ty, getContext().getDeclAlign(&D),
+                              D.getName() + ".addr");
+    }
     DoStore = true;
   }
Index: lib/CodeGen/CGOpenMPRuntime.h
===================================================================
--- lib/CodeGen/CGOpenMPRuntime.h
+++ lib/CodeGen/CGOpenMPRuntime.h
@@ -675,7 +675,7 @@
 
   /// \brief Cleans up references to the objects in finished function.
   ///
-  void functionFinished(CodeGenFunction &CGF);
+  virtual void functionFinished(CodeGenFunction &CGF);
 
   /// \brief Emits code for parallel or serial call of the \a OutlinedFn with
   /// variables captured in a record which address is stored in \a
@@ -1362,6 +1362,14 @@
   emitOutlinedFunctionCall(CodeGenFunction &CGF, SourceLocation Loc,
                            llvm::Value *OutlinedFn,
                            ArrayRef<llvm::Value *> Args = llvm::None) const;
+
+  /// Emits OpenMP-specific function prolog.
+  /// Required for device constructs.
+  virtual void emitFunctionProlog(CodeGenFunction &CGF, const Decl *D) {}
+
+  /// Gets the OpenMP-specific address of the local variable.
+  virtual Address getAddressOfLocalVariable(CodeGenFunction &CGF,
+                                            const VarDecl *VD);
 };
 
 /// Class supports emissionof SIMD-only code.
Index: lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- lib/CodeGen/CGOpenMPRuntime.cpp
+++ lib/CodeGen/CGOpenMPRuntime.cpp
@@ -8019,6 +8019,11 @@
   return CGF.GetAddrOfLocalVar(NativeParam);
 }
 
+Address CGOpenMPRuntime::getAddressOfLocalVariable(CodeGenFunction &CGF,
+                                                   const VarDecl *VD) {
+  return Address::invalid();
+}
+
 llvm::Value *CGOpenMPSIMDRuntime::emitParallelOutlinedFunction(
     const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
     OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
Index: lib/CodeGen/CGOpenMPRuntimeNVPTX.h
===================================================================
--- lib/CodeGen/CGOpenMPRuntimeNVPTX.h
+++ lib/CodeGen/CGOpenMPRuntimeNVPTX.h
@@ -288,6 +288,14 @@
                                 CodeGenFunction &CGF, SourceLocation Loc,
                                 llvm::Value *OutlinedFn,
                                 ArrayRef<llvm::Value *> Args = llvm::None) const override;
+
+  /// Emits OpenMP-specific function prolog.
+  /// Required for device constructs.
+  void emitFunctionProlog(CodeGenFunction &CGF, const Decl *D) override;
+
+  /// Gets the OpenMP-specific address of the local variable.
+  Address getAddressOfLocalVariable(CodeGenFunction &CGF,
+                                    const VarDecl *VD) override;
+
   /// Target codegen is specialized based on two programming models: the
   /// 'generic' fork-join model of OpenMP, and a more GPU efficient 'spmd'
   /// model for constructs like 'target parallel' that support it.
@@ -299,12 +307,39 @@
     Unknown,
   };
 
+  /// Cleans up references to the objects in finished function.
+  ///
+  void functionFinished(CodeGenFunction &CGF) override;
+
 private:
   // Track the execution mode when codegening directives within a target
   // region. The appropriate mode (generic/spmd) is set on entry to the
   // target region and used by containing directives such as 'parallel'
   // to emit optimized code.
   ExecutionMode CurrentExecutionMode;
+
+  /// Map between an outlined function and its wrapper.
+  llvm::DenseMap<llvm::Function *, llvm::Function *> WrapperFunctionsMap;
+
+  /// Emit function which wraps the outlined parallel region
+  /// and controls the parameters which are passed to this function.
+  /// The wrapper ensures that the outlined function is called
+  /// with the correct arguments when data is shared.
+  llvm::Function *
+  createDataSharingWrapper(llvm::Function *OutlinedParallelFn,
+                           const OMPExecutableDirective &D);
+
+  /// The map of local variables to their addresses in the global memory.
+  typedef llvm::MapVector<const Decl *,
+                          std::pair<const FieldDecl *, Address>>
+      DeclToAddrMapTy;
+  /// Maps the function to the list of the globalized variables with their
+  /// addresses.
+  llvm::DenseMap<llvm::Function *,
+                 std::pair<const RecordDecl *, DeclToAddrMapTy>>
+      FunctionGlobalizedDecls;
+  /// Map from function to global record pointer.
+  llvm::DenseMap<llvm::Function *, llvm::Value *> FunctionToGlobalRecPtr;
 };
 
 } // CodeGen namespace.
Index: lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
===================================================================
--- lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
+++ lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
@@ -13,9 +13,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "CGOpenMPRuntimeNVPTX.h"
-#include "clang/AST/DeclOpenMP.h"
 #include "CodeGenFunction.h"
+#include "clang/AST/DeclOpenMP.h"
 #include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtVisitor.h"
+#include "llvm/ADT/SmallPtrSet.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -70,7 +72,18 @@
   /// index, int32_t width, int32_t reduce))
   OMPRTL_NVPTX__kmpc_teams_reduce_nowait,
   /// \brief Call to __kmpc_nvptx_end_reduce_nowait(int32_t global_tid);
-  OMPRTL_NVPTX__kmpc_end_reduce_nowait
+  OMPRTL_NVPTX__kmpc_end_reduce_nowait,
+  /// \brief Call to void* __kmpc_data_sharing_push_stack(size_t size);
+  OMPRTL_NVPTX__kmpc_data_sharing_push_stack,
+  /// \brief Call to void __kmpc_data_sharing_pop_stack(void *a);
+  OMPRTL_NVPTX__kmpc_data_sharing_pop_stack,
+  /// \brief Call to void __kmpc_begin_sharing_variables(void ***args,
+  /// size_t n_args);
+  OMPRTL_NVPTX__kmpc_begin_sharing_variables,
+  /// \brief Call to void __kmpc_end_sharing_variables();
+  OMPRTL_NVPTX__kmpc_end_sharing_variables,
+  /// \brief Call to void __kmpc_get_shared_variables(void ***GlobalArgs)
+  OMPRTL_NVPTX__kmpc_get_shared_variables,
 };
 
 /// Pre(post)-action for different OpenMP constructs specialized for NVPTX.
@@ -149,6 +162,246 @@
   /// barrier.
   NB_Parallel = 1,
 };
+
+/// Get the list of variables that can escape their declaration context.
+class CheckVarsEscapingDeclContext final
+    : public ConstStmtVisitor<CheckVarsEscapingDeclContext> {
+  CodeGenFunction &CGF;
+  llvm::SetVector<const ValueDecl *> EscapedDecls;
+  llvm::SmallPtrSet<const ValueDecl *, 4> IgnoredDecls;
+  bool AllEscaped = false;
+  RecordDecl *GlobalizedRD = nullptr;
+  llvm::SmallDenseMap<const ValueDecl *, const FieldDecl *> MappedDeclsFields;
+
+  void markAsEscaped(const ValueDecl *VD) {
+    if (IgnoredDecls.count(VD) ||
+        (CGF.CapturedStmtInfo &&
+         CGF.CapturedStmtInfo->lookup(cast<VarDecl>(VD))))
+      return;
+    EscapedDecls.insert(VD);
+  }
+
+  void VisitValueDecl(const ValueDecl *VD) {
+    if (VD->getType()->isLValueReferenceType()) {
+      markAsEscaped(VD);
+      if (const auto *VarD = dyn_cast<VarDecl>(VD)) {
+        if (!isa<ParmVarDecl>(VarD) && VarD->hasInit()) {
+          const bool SavedAllEscaped = AllEscaped;
+          AllEscaped = true;
+          Visit(VarD->getInit());
+          AllEscaped = SavedAllEscaped;
+        }
+      }
+    }
+  }
+  void VisitOpenMPCapturedStmt(const CapturedStmt *S) {
+    if (!S)
+      return;
+    for (const auto &C : S->captures()) {
+      if (C.capturesVariable() && !C.capturesVariableByCopy()) {
+        const ValueDecl *VD = C.getCapturedVar();
+        markAsEscaped(VD);
+        if (isa<OMPCapturedExprDecl>(VD))
+          VisitValueDecl(VD);
+      }
+    }
+  }
+
+  typedef std::pair<CharUnits, const ValueDecl *> VarsDataTy;
+  static bool stable_sort_comparator(const VarsDataTy P1,
+                                     const VarsDataTy P2) {
+    return P1.first > P2.first;
+  }
+
+  void buildRecordForGlobalizedVars() {
+    assert(!GlobalizedRD &&
+           "Record for globalized variables is built already.");
+    if (EscapedDecls.empty())
+      return;
+    ASTContext &C = CGF.getContext();
+    SmallVector<VarsDataTy, 4> GlobalizedVars;
+    for (const auto *D : EscapedDecls)
+      GlobalizedVars.emplace_back(C.getDeclAlign(D), D);
+    std::stable_sort(GlobalizedVars.begin(), GlobalizedVars.end(),
+                     stable_sort_comparator);
+    // Build struct _globalized_locals_ty {
+    //         /* globalized vars */
+    //       };
+    GlobalizedRD = C.buildImplicitRecord("_globalized_locals_ty");
+    GlobalizedRD->startDefinition();
+    for (const auto &Pair : GlobalizedVars) {
+      const ValueDecl *VD = Pair.second;
+      QualType Type = VD->getType();
+      if (Type->isLValueReferenceType())
+        Type = C.getPointerType(Type.getNonReferenceType());
+      else
+        Type = Type.getNonReferenceType();
+      SourceLocation Loc = VD->getLocation();
+      auto *Field = FieldDecl::Create(
+          C, GlobalizedRD, Loc, Loc, VD->getIdentifier(), Type,
+          C.getTrivialTypeSourceInfo(Type, SourceLocation()),
+          /*BW=*/nullptr, /*Mutable=*/false,
+          /*InitStyle=*/ICIS_NoInit);
+      Field->setAccess(AS_public);
+      GlobalizedRD->addDecl(Field);
+      if (VD->hasAttrs()) {
+        for (specific_attr_iterator<AlignedAttr> I(VD->getAttrs().begin()),
+             E(VD->getAttrs().end());
+             I != E; ++I)
+          Field->addAttr(*I);
+      }
+      MappedDeclsFields.try_emplace(VD, Field);
+    }
+    GlobalizedRD->completeDefinition();
+  }
+
+public:
+  CheckVarsEscapingDeclContext(CodeGenFunction &CGF,
+                               ArrayRef<const ValueDecl *> IgnoredDecls)
+      : CGF(CGF), IgnoredDecls(IgnoredDecls.begin(), IgnoredDecls.end()) {}
+  virtual ~CheckVarsEscapingDeclContext() = default;
+  void VisitDeclStmt(const DeclStmt *S) {
+    if (!S)
+      return;
+    for (const auto *D : S->decls())
+      if (const auto *VD = dyn_cast_or_null<ValueDecl>(D))
+        VisitValueDecl(VD);
+  }
+  void VisitOMPExecutableDirective(const OMPExecutableDirective *D) {
+    if (!D)
+      return;
+    if (D->hasAssociatedStmt()) {
+      if (const auto *S =
+              dyn_cast_or_null<CapturedStmt>(D->getAssociatedStmt()))
+        VisitOpenMPCapturedStmt(S);
+    }
+  }
+  void VisitCapturedStmt(const CapturedStmt *S) {
+    if (!S)
+      return;
+    for (const auto &C : S->captures()) {
+      if (C.capturesVariable() && !C.capturesVariableByCopy()) {
+        const ValueDecl *VD = C.getCapturedVar();
+        markAsEscaped(VD);
+        if (isa<OMPCapturedExprDecl>(VD))
+          VisitValueDecl(VD);
+      }
+    }
+  }
+  void VisitLambdaExpr(const LambdaExpr *E) {
+    if (!E)
+      return;
+    for (const auto &C : E->captures()) {
+      if (C.capturesVariable()) {
+        if (C.getCaptureKind() == LCK_ByRef) {
+          const ValueDecl *VD = C.getCapturedVar();
+          markAsEscaped(VD);
+          if (E->isInitCapture(&C) || isa<OMPCapturedExprDecl>(VD))
+            VisitValueDecl(VD);
+        }
+      }
+    }
+  }
+  void VisitBlockExpr(const BlockExpr *E) {
+    if (!E)
+      return;
+    for (const auto &C : E->getBlockDecl()->captures()) {
+      if (C.isByRef()) {
+        const VarDecl *VD = C.getVariable();
+        markAsEscaped(VD);
+        if (isa<OMPCapturedExprDecl>(VD) || VD->isInitCapture())
+          VisitValueDecl(VD);
+      }
+    }
+  }
+  void VisitCallExpr(const CallExpr *E) {
+    if (!E)
+      return;
+    for (const Expr *Arg : E->arguments()) {
+      if (!Arg)
+        continue;
+      if (Arg->isLValue()) {
+        const bool SavedAllEscaped = AllEscaped;
+        AllEscaped = true;
+        Visit(Arg);
+        AllEscaped = SavedAllEscaped;
+      } else
+        Visit(Arg);
+    }
+    Visit(E->getCallee());
+  }
+  void VisitDeclRefExpr(const DeclRefExpr *E) {
+    if (!E)
+      return;
+    const ValueDecl *VD = E->getDecl();
+    if (AllEscaped)
+      markAsEscaped(VD);
+    if (isa<OMPCapturedExprDecl>(VD))
+      VisitValueDecl(VD);
+    else if (const auto *VarD = dyn_cast<VarDecl>(VD))
+      if (VarD->isInitCapture())
+        VisitValueDecl(VD);
+  }
+  void VisitUnaryOperator(const UnaryOperator *E) {
+    if (!E)
+      return;
+    if (E->getOpcode() == UO_AddrOf) {
+      const bool SavedAllEscaped = AllEscaped;
+      AllEscaped = true;
+      Visit(E->getSubExpr());
+      AllEscaped = SavedAllEscaped;
+    } else
+      Visit(E->getSubExpr());
+  }
+  void VisitImplicitCastExpr(const ImplicitCastExpr *E) {
+    if (!E)
+      return;
+    if (E->getCastKind() == CK_ArrayToPointerDecay) {
+      const bool SavedAllEscaped = AllEscaped;
+      AllEscaped = true;
+      Visit(E->getSubExpr());
+      AllEscaped = SavedAllEscaped;
+    } else
+      Visit(E->getSubExpr());
+  }
+  void VisitExpr(const Expr *E) {
+    if (!E)
+      return;
+    bool SavedAllEscaped = AllEscaped;
+    if (!E->isLValue())
+      AllEscaped = false;
+    for (const auto *Child : E->children())
+      if (Child)
+        Visit(Child);
+    AllEscaped = SavedAllEscaped;
+  }
+  void VisitStmt(const Stmt *S) {
+    if (!S)
+      return;
+    for (const auto *Child : S->children())
+      if (Child)
+        Visit(Child);
+  }
+
+  const RecordDecl *getGlobalizedRecord() {
+    if (!GlobalizedRD)
+      buildRecordForGlobalizedVars();
+    return GlobalizedRD;
+  }
+
+  const FieldDecl *getFieldForGlobalizedVar(const ValueDecl *VD) const {
+    assert(GlobalizedRD &&
+           "Record for globalized variables must be generated already.");
+    auto I = MappedDeclsFields.find(VD);
+    if (I == MappedDeclsFields.end())
+      return nullptr;
+    return I->getSecond();
+  }
+
+  ArrayRef<const ValueDecl *> getEscapedDecls() const {
+    return EscapedDecls.getArrayRef();
+  }
+};
 } // anonymous namespace
 
 /// Get the GPU warp size.
@@ -299,6 +552,7 @@
   EntryFunctionState EST;
   WorkerFunctionState WST(CGM);
   Work.clear();
+  WrapperFunctionsMap.clear();
 
   // Emit target region as a standalone region.
   class NVPTXPrePostActionTy : public PrePostActionTy {
@@ -362,10 +616,58 @@
                          Bld.getInt16(/*RequiresOMPRuntime=*/1)};
   CGF.EmitRuntimeCall(
       createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_init), Args);
+
+  const auto I = FunctionGlobalizedDecls.find(CGF.CurFn);
+  if (I == FunctionGlobalizedDecls.end())
+    return;
+  const RecordDecl *GlobalizedVarsRecord = I->getSecond().first;
+  QualType RecTy = CGM.getContext().getRecordType(GlobalizedVarsRecord);
+
+  // Recover pointer to this function's global record. The runtime will
+  // handle the specifics of the allocation of the memory.
+  // Use actual memory size of the record including the padding
+  // for alignment purposes.
+  unsigned Alignment =
+      CGM.getContext().getTypeAlignInChars(RecTy).getQuantity();
+  unsigned GlobalRecordSize =
+      CGM.getContext().getTypeSizeInChars(RecTy).getQuantity();
+  GlobalRecordSize = llvm::alignTo(GlobalRecordSize, Alignment);
+  llvm::Value *GlobalRecordSizeArg[] = {
+      llvm::ConstantInt::get(CGM.SizeTy, GlobalRecordSize)};
+  llvm::Value *GlobalRecValue = CGF.EmitRuntimeCall(
+      createNVPTXRuntimeFunction(
+          OMPRTL_NVPTX__kmpc_data_sharing_push_stack),
+      GlobalRecordSizeArg);
+  auto GlobalRecCastAddr = Bld.CreatePointerBitCastOrAddrSpaceCast(
+      GlobalRecValue, CGF.ConvertTypeForMem(RecTy)->getPointerTo());
+  FunctionToGlobalRecPtr.try_emplace(CGF.CurFn, GlobalRecValue);
+
+  // Emit the "global alloca" which is a GEP from the global declaration
+  // record using the pointer returned by the runtime.
+  LValue Base =
+      CGF.MakeNaturalAlignPointeeAddrLValue(GlobalRecCastAddr, RecTy);
+  auto &Res = I->getSecond().second;
+  for (auto &Rec : Res) {
+    const FieldDecl *FD = Rec.second.first;
+    LValue VarAddr = CGF.EmitLValueForField(Base, FD);
+    Rec.second.second = VarAddr.getAddress();
+  }
 }
 
 void CGOpenMPRuntimeNVPTX::emitGenericEntryFooter(CodeGenFunction &CGF,
                                                   EntryFunctionState &EST) {
+  const auto I = FunctionGlobalizedDecls.find(CGF.CurFn);
+  if (I != FunctionGlobalizedDecls.end()) {
+    if (!CGF.HaveInsertPoint())
+      return;
+    auto I = FunctionToGlobalRecPtr.find(CGF.CurFn);
+    if (I != FunctionToGlobalRecPtr.end()) {
+      llvm::Value *Args[] = {I->getSecond()};
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(
+              OMPRTL_NVPTX__kmpc_data_sharing_pop_stack),
+          Args);
+    }
+  }
+
   if (!EST.ExitBB)
     EST.ExitBB = CGF.createBasicBlock(".exit");
@@ -554,14 +856,12 @@
   // Execute this outlined function.
   CGF.EmitBlock(ExecuteFNBB);
 
-  // Insert call to work function.
-  // FIXME: Pass arguments to outlined function from master thread.
-  auto *Fn = cast<llvm::Function>(W);
-  Address ZeroAddr =
-      CGF.CreateDefaultAlignTempAlloca(CGF.Int32Ty, /*Name=*/".zero.addr");
-  CGF.InitTempAlloca(ZeroAddr, CGF.Builder.getInt32(/*C=*/0));
-  llvm::Value *FnArgs[] = {ZeroAddr.getPointer(), ZeroAddr.getPointer()};
-  emitCall(CGF, Fn, FnArgs);
+  // Insert call to work function via shared wrapper. The shared
+  // wrapper takes two arguments:
+  //   - the parallelism level;
+  //   - the master thread ID;
+  emitCall(CGF, W, {Bld.getInt16(/*ParallelLevel=*/0),
+                    getMasterThreadID(CGF)});
 
   // Go to end of parallel region.
   CGF.EmitBranch(TerminateBB);
 
@@ -630,8 +930,7 @@
   case OMPRTL_NVPTX__kmpc_kernel_prepare_parallel: {
     /// Build void __kmpc_kernel_prepare_parallel(
     /// void *outlined_function, int16_t IsOMPRuntimeInitialized);
-    llvm::Type *TypeParams[] = {CGM.Int8PtrTy,
-                                CGM.Int16Ty};
+    llvm::Type *TypeParams[] = {CGM.Int8PtrTy, CGM.Int16Ty};
     llvm::FunctionType *FnTy =
         llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
     RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_prepare_parallel");
@@ -769,6 +1068,48 @@
     FnTy, /*Name=*/"__kmpc_nvptx_end_reduce_nowait");
     break;
   }
+  case OMPRTL_NVPTX__kmpc_data_sharing_push_stack: {
+    // Build void *__kmpc_data_sharing_push_stack(size_t size);
+    llvm::Type *TypeParams[] = {CGM.SizeTy};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidPtrTy, TypeParams, /*isVarArg=*/false);
+    RTLFn = CGM.CreateRuntimeFunction(
+        FnTy, /*Name=*/"__kmpc_data_sharing_push_stack");
+    break;
+  }
+  case OMPRTL_NVPTX__kmpc_data_sharing_pop_stack: {
+    // Build void __kmpc_data_sharing_pop_stack(void *a);
+    llvm::Type *TypeParams[] = {CGM.VoidPtrTy};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg=*/false);
+    RTLFn = CGM.CreateRuntimeFunction(
+        FnTy, /*Name=*/"__kmpc_data_sharing_pop_stack");
+    break;
+  }
+  case OMPRTL_NVPTX__kmpc_begin_sharing_variables: {
+    /// Build void __kmpc_begin_sharing_variables(void ***args,
+    /// size_t n_args);
+    llvm::Type *TypeParams[] = {CGM.Int8PtrPtrTy->getPointerTo(), CGM.SizeTy};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_begin_sharing_variables");
+    break;
+  }
+  case OMPRTL_NVPTX__kmpc_end_sharing_variables: {
+    /// Build void __kmpc_end_sharing_variables();
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, {}, /*isVarArg*/ false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_end_sharing_variables");
+    break;
+  }
+  case OMPRTL_NVPTX__kmpc_get_shared_variables: {
+    /// Build void __kmpc_get_shared_variables(void ***GlobalArgs);
+    llvm::Type *TypeParams[] = {CGM.Int8PtrPtrTy->getPointerTo()};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_get_shared_variables");
+    break;
+  }
   }
   return RTLFn;
 }
@@ -859,8 +1200,16 @@
 llvm::Value *CGOpenMPRuntimeNVPTX::emitParallelOutlinedFunction(
     const OMPExecutableDirective &D, const VarDecl *ThreadIDVar,
     OpenMPDirectiveKind InnermostKind, const RegionCodeGenTy &CodeGen) {
-  return CGOpenMPRuntime::emitParallelOutlinedFunction(D, ThreadIDVar,
-                                                       InnermostKind, CodeGen);
+  auto *OutlinedFun = cast<llvm::Function>(
+      CGOpenMPRuntime::emitParallelOutlinedFunction(
+          D, ThreadIDVar, InnermostKind, CodeGen));
+  if (!isInSpmdExecutionMode()) {
+    llvm::Function *WrapperFun =
+        createDataSharingWrapper(OutlinedFun, D);
+    WrapperFunctionsMap[OutlinedFun] = WrapperFun;
+  }
+
+  return OutlinedFun;
 }
 
 llvm::Value *CGOpenMPRuntimeNVPTX::emitTeamsOutlinedFunction(
@@ -912,16 +1261,61 @@
     CodeGenFunction &CGF, SourceLocation Loc, llvm::Value *OutlinedFn,
     ArrayRef<llvm::Value *> CapturedVars, const Expr *IfCond) {
   llvm::Function *Fn = cast<llvm::Function>(OutlinedFn);
+  llvm::Function *WFn = WrapperFunctionsMap[Fn];
+
+  assert(WFn && "Wrapper function does not exist!");
 
-  auto &&L0ParallelGen = [this, Fn](CodeGenFunction &CGF, PrePostActionTy &) {
+  // Force inline this outlined function at its call site.
+  Fn->setLinkage(llvm::GlobalValue::InternalLinkage);
+
+  auto &&L0ParallelGen = [this, WFn, &CapturedVars](CodeGenFunction &CGF,
+                                                    PrePostActionTy &) {
     CGBuilderTy &Bld = CGF.Builder;
-    // TODO: Optimize runtime initialization.
-    llvm::Value *Args[] = {Bld.CreateBitOrPointerCast(Fn, CGM.Int8PtrTy),
-                           /*RequiresOMPRuntime=*/Bld.getInt16(1)};
-    CGF.EmitRuntimeCall(
-        createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_prepare_parallel),
-        Args);
+    llvm::Value *ID = Bld.CreateBitOrPointerCast(WFn, CGM.Int8PtrTy);
+
+    // There's something to share.
+    if (!CapturedVars.empty()) {
+      // Prepare for parallel region. Indicate the outlined function.
+      Address SharedArgs =
+          CGF.CreateDefaultAlignTempAlloca(CGF.VoidPtrPtrTy,
+                                           "shared_arg_refs");
+      llvm::Value *SharedArgsPtr = SharedArgs.getPointer();
+
+      // Handle the sharing of variables using global memory.
+      // Prepare for parallel region. Indicate the outlined function.
+      llvm::Value *Args[] = {ID, /*RequiresOMPRuntime=*/Bld.getInt16(1)};
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(
+              OMPRTL_NVPTX__kmpc_kernel_prepare_parallel),
+          Args);
+
+      llvm::Value *DataSharingArgs[] = {
+          SharedArgsPtr,
+          llvm::ConstantInt::get(CGM.SizeTy, CapturedVars.size())};
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(
+              OMPRTL_NVPTX__kmpc_begin_sharing_variables),
+          DataSharingArgs);
+
+      // Store variable address in a list of references to pass to workers.
+      unsigned Idx = 0;
+      ASTContext &Ctx = CGF.getContext();
+      for (llvm::Value *V : CapturedVars) {
+        Address Dst = Bld.CreateConstInBoundsGEP(
+            CGF.EmitLoadOfPointer(
+                SharedArgs,
+                Ctx.getPointerType(Ctx.getPointerType(Ctx.VoidPtrTy))
+                    .castAs<PointerType>()),
+            Idx, CGF.getPointerSize());
+        llvm::Value *PtrV = Bld.CreateBitCast(V, CGF.VoidPtrTy);
+        CGF.EmitStoreOfScalar(PtrV, Dst, /*Volatile=*/false,
+                              Ctx.getPointerType(Ctx.VoidPtrTy));
+        Idx++;
+      }
+    } else {
+      // TODO: Optimize runtime initialization and pass in correct value.
+      llvm::Value *Args[] = {ID, /*RequiresOMPRuntime=*/Bld.getInt16(1)};
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(
+              OMPRTL_NVPTX__kmpc_kernel_prepare_parallel),
+          Args);
+    }
 
     // Activate workers. This barrier is used by the master to signal
     // work for the workers.
@@ -935,8 +1329,12 @@
     // The master waits at this barrier until all workers are done.
     syncCTAThreads(CGF);
 
+    if (!CapturedVars.empty())
+      CGF.EmitRuntimeCall(
+          createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_end_sharing_variables));
+
     // Remember for post-processing in worker loop.
-    Work.emplace_back(Fn);
+    Work.emplace_back(WFn);
   };
 
   auto *RTLoc = emitUpdateLocation(CGF, Loc);
@@ -2340,3 +2738,154 @@
   }
   CGOpenMPRuntime::emitOutlinedFunctionCall(CGF, Loc, OutlinedFn, TargetArgs);
 }
+
+/// Emit function which wraps the outlined parallel region
+/// and controls the arguments which are passed to this function.
+/// The wrapper ensures that the outlined function is called
+/// with the correct arguments when data is shared.
+llvm::Function *CGOpenMPRuntimeNVPTX::createDataSharingWrapper(
+    llvm::Function *OutlinedParallelFn, const OMPExecutableDirective &D) {
+  ASTContext &Ctx = CGM.getContext();
+  const auto &CS = *cast<CapturedStmt>(D.getAssociatedStmt());
+
+  // Create a function that takes as argument the source thread.
+  FunctionArgList WrapperArgs;
+  QualType Int16QTy =
+      Ctx.getIntTypeForBitwidth(/*DestWidth=*/16, /*Signed=*/false);
+  QualType Int32QTy =
+      Ctx.getIntTypeForBitwidth(/*DestWidth=*/32, /*Signed=*/false);
+  ImplicitParamDecl ParallelLevelArg(Ctx, Int16QTy, ImplicitParamDecl::Other);
+  ImplicitParamDecl WrapperArg(Ctx, Int32QTy, ImplicitParamDecl::Other);
+  WrapperArgs.emplace_back(&ParallelLevelArg);
+  WrapperArgs.emplace_back(&WrapperArg);
+
+  auto &CGFI =
+      CGM.getTypes().arrangeBuiltinFunctionDeclaration(Ctx.VoidTy, WrapperArgs);
+
+  auto *Fn = llvm::Function::Create(
+      CGM.getTypes().GetFunctionType(CGFI), llvm::GlobalValue::InternalLinkage,
+      OutlinedParallelFn->getName() + "_wrapper", &CGM.getModule());
+  CGM.SetInternalFunctionAttributes(/*D=*/nullptr, Fn, CGFI);
+  Fn->setLinkage(llvm::GlobalValue::InternalLinkage);
+
+  CodeGenFunction CGF(CGM, /*suppressNewContext=*/true);
+  CGF.StartFunction(GlobalDecl(), Ctx.VoidTy, Fn, CGFI, WrapperArgs);
+
+  const auto *RD = CS.getCapturedRecordDecl();
+  auto CurField = RD->field_begin();
+
+  // Get the array of arguments.
+  SmallVector<llvm::Value *, 8> Args;
+
+  // TODO: support SIMD and pass actual values
+  Args.emplace_back(llvm::ConstantPointerNull::get(
+      CGM.Int32Ty->getPointerTo()));
+  Args.emplace_back(llvm::ConstantPointerNull::get(
+      CGM.Int32Ty->getPointerTo()));
+
+  CGBuilderTy &Bld = CGF.Builder;
+  auto CI = CS.capture_begin();
+
+  // Use global memory for data sharing.
+  // Handle passing of global args to workers.
+  Address GlobalArgs =
+      CGF.CreateDefaultAlignTempAlloca(CGF.VoidPtrPtrTy, "global_args");
+  llvm::Value *GlobalArgsPtr = GlobalArgs.getPointer();
+  llvm::Value *DataSharingArgs[] = {GlobalArgsPtr};
+  CGF.EmitRuntimeCall(
+      createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_get_shared_variables),
+      DataSharingArgs);
+
+  // Retrieve the shared variables from the list of references returned
+  // by the runtime. Pass the variables to the outlined function.
+  ASTContext &CGFContext = CGF.getContext();
+  for (unsigned I = 0, E = CS.capture_size(); I < E; ++I, ++CI, ++CurField) {
+    QualType ElemTy = CurField->getType();
+    // If this is a capture by copy the element type has to be the pointer to
+    // the data.
+    if (CI->capturesVariableByCopy())
+      ElemTy = CGFContext.getPointerType(ElemTy);
+
+    Address Src = Bld.CreateConstInBoundsGEP(
+        CGF.EmitLoadOfPointer(
+            GlobalArgs,
+            CGFContext.getPointerType(
+                           CGFContext.getPointerType(CGFContext.VoidPtrTy))
+                .castAs<PointerType>()),
+        I, CGF.getPointerSize());
+    Address TypedAddress = Bld.CreateBitCast(
+        Src, CGF.ConvertTypeForMem(CGFContext.getPointerType(ElemTy)));
+    llvm::Value *Arg = CGF.EmitLoadOfScalar(TypedAddress,
+                                            /*Volatile=*/false,
+                                            CGFContext.getPointerType(ElemTy),
+                                            CurField->getInnerLocStart());
+    Args.emplace_back(Arg);
+  }
+
+  emitCall(CGF, OutlinedParallelFn, Args);
+  CGF.FinishFunction();
+  return Fn;
+}
+
+void CGOpenMPRuntimeNVPTX::emitFunctionProlog(CodeGenFunction &CGF,
+                                              const Decl *D) {
+  assert(D && "Expected function or captured|block decl.");
+  assert(FunctionGlobalizedDecls.count(CGF.CurFn) == 0 &&
+         "Function is registered already.");
+  SmallVector<const ValueDecl *, 4> IgnoredDecls;
+  const Stmt *Body = nullptr;
+  if (const auto *FD = dyn_cast<FunctionDecl>(D)) {
+    Body = FD->getBody();
+  } else if (const auto *BD = dyn_cast<BlockDecl>(D)) {
+    Body = BD->getBody();
+  } else if (const auto *CD = dyn_cast<CapturedDecl>(D)) {
+    Body = CD->getBody();
+    if (CGF.CapturedStmtInfo->getKind() == CR_OpenMP) {
+      const auto *CS = dyn_cast<CapturedStmt>(Body);
+      const auto *LastCS = CS;
+      while (CS && CS->getCapturedRegionKind() == CR_OpenMP) {
+        Body = CS->getCapturedStmt();
+        LastCS = CS;
+        CS = dyn_cast<CapturedStmt>(Body);
+      }
+      if (LastCS) {
+        IgnoredDecls.reserve(LastCS->capture_size());
+        for (const auto &Capture : LastCS->captures())
+          if (Capture.capturesVariable())
+            IgnoredDecls.emplace_back(Capture.getCapturedVar());
+      }
+    }
+  }
+  if (!Body)
+    return;
+  CheckVarsEscapingDeclContext VarChecker(CGF, IgnoredDecls);
+  VarChecker.Visit(Body);
+  const RecordDecl *GlobalizedVarsRecord =
+      VarChecker.getGlobalizedRecord();
+  if (!GlobalizedVarsRecord)
+    return;
+  auto &Res = FunctionGlobalizedDecls
+                  .try_emplace(CGF.CurFn, GlobalizedVarsRecord,
+                               DeclToAddrMapTy())
+                  .first->getSecond()
+                  .second;
+  for (const ValueDecl *VD : VarChecker.getEscapedDecls()) {
+    const FieldDecl *FD = VarChecker.getFieldForGlobalizedVar(VD);
+    Res.insert(std::make_pair(VD,
+                              std::make_pair(FD, Address::invalid())));
+  }
+}
+
+Address CGOpenMPRuntimeNVPTX::getAddressOfLocalVariable(
+    CodeGenFunction &CGF, const VarDecl *VD) {
+  auto I = FunctionGlobalizedDecls.find(CGF.CurFn);
+  if (I == FunctionGlobalizedDecls.end())
+    return Address::invalid();
+  auto VDI = I->getSecond().second.find(VD);
+  if (VDI == I->getSecond().second.end())
+    return Address::invalid();
+  return VDI->second.second;
+}
+
+void CGOpenMPRuntimeNVPTX::functionFinished(CodeGenFunction &CGF) {
+  FunctionToGlobalRecPtr.erase(CGF.CurFn);
+  FunctionGlobalizedDecls.erase(CGF.CurFn);
+  CGOpenMPRuntime::functionFinished(CGF);
+}
Index: lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- lib/CodeGen/CGStmtOpenMP.cpp
+++ lib/CodeGen/CGStmtOpenMP.cpp
@@ -589,6 +589,7 @@
                                      /*RegisterCastedArgsOnly=*/true,
                                      CapturedStmtInfo->getHelperName());
   CodeGenFunction WrapperCGF(CGM, /*suppressNewContext=*/true);
+  WrapperCGF.CapturedStmtInfo = CapturedStmtInfo;
   Args.clear();
   LocalAddrs.clear();
   VLASizes.clear();
Index: lib/CodeGen/CodeGenFunction.cpp
===================================================================
--- lib/CodeGen/CodeGenFunction.cpp
+++ lib/CodeGen/CodeGenFunction.cpp
@@ -1058,6 +1058,11 @@
   EmitStartEHSpec(CurCodeDecl);
   PrologueCleanupDepth = EHStack.stable_begin();
+
+  // Emit OpenMP specific initialization of the device functions.
+  if (getLangOpts().OpenMP && CurCodeDecl)
+    CGM.getOpenMPRuntime().emitFunctionProlog(*this, CurCodeDecl);
+
   EmitFunctionProlog(*CurFnInfo, CurFn, Args);
 
   if (D && isa<CXXMethodDecl>(D) && cast<CXXMethodDecl>(D)->isInstance()) {
Index: test/OpenMP/nvptx_parallel_codegen.cpp
===================================================================
--- test/OpenMP/nvptx_parallel_codegen.cpp
+++ test/OpenMP/nvptx_parallel_codegen.cpp
@@ -92,20 +92,20 @@
   //
   // CHECK: [[EXEC_PARALLEL]]
   // CHECK: [[WF1:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
-  // CHECK: [[WM1:%.+]] = icmp eq i8* [[WF1]], bitcast (void (i32*, i32*)* [[PARALLEL_FN1:@.+]] to i8*)
+  // CHECK: [[WM1:%.+]] = icmp eq i8* [[WF1]], bitcast (void (i16, i32)* [[PARALLEL_FN1:@.+]]_wrapper to i8*)
   // CHECK: br i1 [[WM1]], label {{%?}}[[EXEC_PFN1:.+]], label {{%?}}[[CHECK_NEXT1:.+]]
   //
   // CHECK: [[EXEC_PFN1]]
-  // CHECK: call void [[PARALLEL_FN1]](
+  // CHECK: call void [[PARALLEL_FN1]]_wrapper(
   // CHECK: br label {{%?}}[[TERM_PARALLEL:.+]]
   //
   // CHECK: [[CHECK_NEXT1]]
   // CHECK: [[WF2:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
-  // CHECK: [[WM2:%.+]] = icmp eq i8* [[WF2]], bitcast (void (i32*, i32*)* [[PARALLEL_FN2:@.+]] to i8*)
+  // CHECK: [[WM2:%.+]] = icmp eq i8* [[WF2]], bitcast (void (i16, i32)* [[PARALLEL_FN2:@.+]]_wrapper to i8*)
   // CHECK: br i1 [[WM2]], label {{%?}}[[EXEC_PFN2:.+]], label {{%?}}[[CHECK_NEXT2:.+]]
   //
   // CHECK: [[EXEC_PFN2]]
-  // CHECK: call void [[PARALLEL_FN2]](
+  // CHECK: call void [[PARALLEL_FN2]]_wrapper(
   // CHECK: br label {{%?}}[[TERM_PARALLEL:.+]]
   //
   // CHECK: [[CHECK_NEXT2]]
@@ -152,13 +152,13 @@
   // CHECK-DAG: [[MWS:%.+]] = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
   // CHECK: [[MTMP1:%.+]] = sub i32 [[MNTH]], [[MWS]]
   // CHECK: call void @__kmpc_kernel_init(i32 [[MTMP1]]
-  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i32*, i32*)* [[PARALLEL_FN1]] to i8*),
+  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i16, i32)* [[PARALLEL_FN1]]_wrapper to i8*),
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @__kmpc_serialized_parallel(
   // CHECK: {{call|invoke}} void [[PARALLEL_FN3:@.+]](
   // CHECK: call void @__kmpc_end_serialized_parallel(
-  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i32*, i32*)* [[PARALLEL_FN2]] to i8*),
+  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i16, i32)* [[PARALLEL_FN2]]_wrapper to i8*),
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK-64-DAG: load i32, i32* [[REF_A]]
@@ -203,7 +203,7 @@
   //
   // CHECK: [[AWAIT_WORK]]
   // CHECK: call void @llvm.nvvm.barrier0()
-  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]]
+  // CHECK: [[KPR:%.+]] = call i1 @__kmpc_kernel_parallel(i8** [[OMP_WORK_FN]],
   // CHECK: [[KPRB:%.+]] = zext i1 [[KPR]] to i8
   // store i8 [[KPRB]], i8* [[OMP_EXEC_STATUS]], align 1
   // CHECK: [[WORK:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
@@ -217,11 +217,11 @@
   //
   // CHECK: [[EXEC_PARALLEL]]
   // CHECK: [[WF:%.+]] = load i8*, i8** [[OMP_WORK_FN]],
-  // CHECK: [[WM:%.+]] = icmp eq i8* [[WF]], bitcast (void (i32*, i32*)* [[PARALLEL_FN4:@.+]] to i8*)
+  // CHECK: [[WM:%.+]] = icmp eq i8* [[WF]], bitcast (void (i16, i32)* [[PARALLEL_FN4:@.+]]_wrapper to i8*)
   // CHECK: br i1 [[WM]], label {{%?}}[[EXEC_PFN:.+]], label {{%?}}[[CHECK_NEXT:.+]]
   //
   // CHECK: [[EXEC_PFN]]
-  // CHECK: call void [[PARALLEL_FN4]](
+  // CHECK: call void [[PARALLEL_FN4]]_wrapper(
   // CHECK: br label {{%?}}[[TERM_PARALLEL:.+]]
   //
   // CHECK: [[CHECK_NEXT]]
@@ -283,7 +283,7 @@
   // CHECK: br i1 [[CMP]], label {{%?}}[[IF_THEN:.+]], label {{%?}}[[IF_ELSE:.+]]
   //
   // CHECK: [[IF_THEN]]
-  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i32*, i32*)* [[PARALLEL_FN4]] to i8*),
+  // CHECK: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void (i16, i32)* [[PARALLEL_FN4]]_wrapper to i8*),
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: call void @llvm.nvvm.barrier0()
   // CHECK: br label {{%?}}[[IF_END:.+]]
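
Note on the runtime protocol introduced by this patch: the master/worker handshake is split across several codegen functions above, so it can be hard to follow from the diff alone. The sketch below is a minimal, single-threaded C++ model of that handshake, using mock definitions of the five entry points whose signatures this patch declares (__kmpc_data_sharing_push_stack, __kmpc_data_sharing_pop_stack, __kmpc_begin_sharing_variables, __kmpc_get_shared_variables, __kmpc_end_sharing_variables). It is illustration only: the real implementations live in the NVPTX device runtime and coordinate the threads of a CUDA block, and the mock bodies here (heap allocations, a direct call standing in for worker execution) are assumptions made for the sake of a runnable example.

```cpp
// Minimal single-threaded model of the data-sharing protocol. Mocks only;
// the real __kmpc_* functions are provided by the NVPTX device runtime.
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <vector>

static std::vector<void *> DataSharingStack; // mock globalization stack
static void **SharedArgs = nullptr;          // mock shared-argument list

void *__kmpc_data_sharing_push_stack(std::size_t Size) {
  void *Frame = std::malloc(Size); // real runtime: block-visible memory
  DataSharingStack.push_back(Frame);
  return Frame;
}
void __kmpc_data_sharing_pop_stack(void *A) {
  DataSharingStack.pop_back();
  std::free(A);
}
void __kmpc_begin_sharing_variables(void ***Args, std::size_t NArgs) {
  SharedArgs = new void *[NArgs];
  *Args = SharedArgs; // master fills this list with captured addresses
}
void __kmpc_get_shared_variables(void ***GlobalArgs) {
  *GlobalArgs = SharedArgs; // workers read the same list
}
void __kmpc_end_sharing_variables() {
  delete[] SharedArgs;
  SharedArgs = nullptr;
}

// Outlined parallel region for: #pragma omp parallel { a += 1; }
static void outlined(int *A) { *A += 1; }

// Wrapper in the style of createDataSharingWrapper: unpack the shared
// argument list and forward the arguments to the outlined function.
static void outlined_wrapper(short /*ParallelLevel*/, int /*MasterTID*/) {
  void **GlobalArgs;
  __kmpc_get_shared_variables(&GlobalArgs);
  outlined(static_cast<int *>(GlobalArgs[0]));
}

int main() {
  // Master side: globalize the escaping local 'a', share its address,
  // then let the "worker" (here: a direct call) run the wrapper.
  int *A = static_cast<int *>(__kmpc_data_sharing_push_stack(sizeof(int)));
  *A = 41;
  void **Args;
  __kmpc_begin_sharing_variables(&Args, /*NArgs=*/1);
  Args[0] = A;
  outlined_wrapper(/*ParallelLevel=*/0, /*MasterTID=*/0);
  __kmpc_end_sharing_variables();
  std::printf("a = %d\n", *A); // prints: a = 42
  __kmpc_data_sharing_pop_stack(A);
  return 0;
}
```

In the generated device code the same roles are split across the functions this patch touches: emitGenericEntryHeader emits the push_stack call and the field GEPs into the _globalized_locals_ty record, emitGenericEntryFooter emits the matching pop_stack, emitGenericParallelCall emits begin/end_sharing_variables on the master side, and createDataSharingWrapper emits the get_shared_variables unpacking executed by the workers.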