diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -1319,7 +1319,6 @@
                                      llvm::SmallVectorImpl<llvm::Value *> &) {}
 
 void CodeGenFunction::EmitOMPParallelDirective(const OMPParallelDirective &S) {
-
   if (llvm::OpenMPIRBuilder *OMPBuilder = CGM.getOpenMPIRBuilder()) {
     // Check if we have any if clause associated with the directive.
     llvm::Value *IfCond = nullptr;
@@ -2991,11 +2990,131 @@
 }
 
 void CodeGenFunction::EmitOMPMasterDirective(const OMPMasterDirective &S) {
+  if (llvm::OpenMPIRBuilder *OMPBuilder = CGM.getOpenMPIRBuilder()) {
+    using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+
+    const Stmt *MasterRegionBodyStmt =
+        S.getInnermostCapturedStmt()->getCapturedStmt();
+
+    // Invoked by OpenMPIRBuilder when it emits the finalization block: swap
+    // the builder's branch to the region exit for a branch through cleanups.
+    auto FiniCB = [this](InsertPointTy IP) {
+      CGBuilderTy::InsertPointGuard IPG(Builder);
+      assert(IP.getBlock()->end() != IP.getPoint() &&
+             "OpenMP IR Builder should cause terminated block!");
+
+      llvm::BasicBlock *IPBB = IP.getBlock();
+      // cast<> asserts on a non-branch terminator rather than yielding null.
+      auto *IPBBTI = llvm::cast<llvm::BranchInst>(IPBB->getTerminator());
+      llvm::BasicBlock *DestBB = IPBBTI->getSuccessor(0);
+
+      // Erase and replace with cleanup branch.
+      IPBBTI->eraseFromParent();
+      Builder.SetInsertPoint(IPBB);
+      CodeGenFunction::JumpDest Dest = getJumpDestInCurrentScope(DestBB);
+      EmitBranchThroughCleanup(Dest);
+    };
+
+    auto BodyGenCB = [MasterRegionBodyStmt, this](InsertPointTy AllocaIP,
+                                                  InsertPointTy CodeGenIP,
+                                                  llvm::BasicBlock &FiniBB) {
+      // Temporarily redirect alloca insertion and ReturnBlock into the
+      // region; both are restored once the body has been emitted.
+      auto OldAllocaIP = AllocaInsertPt;
+      if (AllocaIP.isSet())
+        AllocaInsertPt = &*AllocaIP.getPoint();
+      auto OldReturnBlock = ReturnBlock;
+      ReturnBlock = getJumpDestInCurrentScope(&FiniBB);
+
+      llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
+      if (llvm::Instruction *CodeGenIPBBTI = CodeGenIPBB->getTerminator())
+        CodeGenIPBBTI->eraseFromParent();
+
+      Builder.SetInsertPoint(CodeGenIPBB);
+
+      EmitStmt(MasterRegionBodyStmt);
+
+      Builder.CreateBr(&FiniBB);
+
+      AllocaInsertPt = OldAllocaIP;
+      ReturnBlock = OldReturnBlock;
+    };
+
+    Builder.restoreIP(OMPBuilder->CreateMaster(Builder, BodyGenCB, FiniCB));
+
+    return;
+  }
   OMPLexicalScope Scope(*this, S, OMPD_unknown);
   emitMaster(*this, S);
 }
 
 void CodeGenFunction::EmitOMPCriticalDirective(const OMPCriticalDirective &S) {
+  if (llvm::OpenMPIRBuilder *OMPBuilder = CGM.getOpenMPIRBuilder()) {
+    using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+
+    const Stmt *CriticalRegionBodyStmt =
+        S.getInnermostCapturedStmt()->getCapturedStmt();
+    const Expr *Hint = nullptr;
+    if (const auto *HintClause = S.getSingleClause<OMPHintClause>())
+      Hint = HintClause->getHint();
+
+    // TODO: Fix the type when everything about typing is final.
+    llvm::Value *HintInst =
+        Hint ? Builder.CreateIntCast(EmitScalarExpr(Hint), CGM.Int32Ty,
+                                     /*isSigned=*/false)
+             : nullptr;
+
+    // Invoked by OpenMPIRBuilder when it emits the finalization block: swap
+    // the builder's branch to the region exit for a branch through cleanups.
+    auto FiniCB = [this](InsertPointTy IP) {
+      CGBuilderTy::InsertPointGuard IPG(Builder);
+      assert(IP.getBlock()->end() != IP.getPoint() &&
+             "OpenMP IR Builder should cause terminated block!");
+
+      llvm::BasicBlock *IPBB = IP.getBlock();
+      // cast<> asserts on a non-branch terminator rather than yielding null.
+      auto *IPBBTI = llvm::cast<llvm::BranchInst>(IPBB->getTerminator());
+      llvm::BasicBlock *DestBB = IPBBTI->getSuccessor(0);
+
+      // Erase and replace with cleanup branch.
+      IPBBTI->eraseFromParent();
+      Builder.SetInsertPoint(IPBB);
+      CodeGenFunction::JumpDest Dest = getJumpDestInCurrentScope(DestBB);
+      EmitBranchThroughCleanup(Dest);
+    };
+
+    auto BodyGenCB = [CriticalRegionBodyStmt, this](InsertPointTy AllocaIP,
+                                                    InsertPointTy CodeGenIP,
+                                                    llvm::BasicBlock &FiniBB) {
+      // Temporarily redirect alloca insertion and ReturnBlock into the
+      // region; both are restored once the body has been emitted.
+      auto OldAllocaIP = AllocaInsertPt;
+      if (AllocaIP.isSet())
+        AllocaInsertPt = &*AllocaIP.getPoint();
+      auto OldReturnBlock = ReturnBlock;
+      ReturnBlock = getJumpDestInCurrentScope(&FiniBB);
+
+      llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
+      if (llvm::Instruction *CodeGenIPBBTI = CodeGenIPBB->getTerminator())
+        CodeGenIPBBTI->eraseFromParent();
+
+      Builder.SetInsertPoint(CodeGenIPBB);
+
+      EmitStmt(CriticalRegionBodyStmt);
+
+      Builder.CreateBr(&FiniBB);
+
+      AllocaInsertPt = OldAllocaIP;
+      ReturnBlock = OldReturnBlock;
+    };
+
+    Builder.restoreIP(OMPBuilder->CreateCritical(
+        Builder, BodyGenCB, FiniCB, S.getDirectiveName().getAsString(),
+        HintInst));
+
+    return;
+  }
+
   auto &&CodeGen = [&S](CodeGenFunction &CGF, PrePostActionTy &Action) {
     Action.Enter(CGF);
     CGF.EmitStmt(S.getInnermostCapturedStmt()->getCapturedStmt());
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -243,6 +243,120 @@
 
   /// Map to remember existing ident_t*.
   DenseMap<std::pair<Constant *, uint64_t>, GlobalVariable *> IdentMap;
+
+  /// Type for kmp_critical_name[8], and related pointer type. Both are set
+  /// by initialize().
+  llvm::ArrayType *KmpCriticalNameTy = nullptr;
+  llvm::PointerType *KmpCriticalNamePtrTy = nullptr;
+
+  /// Auto-generated internal variables, keyed by their unique names. This
+  /// currently holds the lock variables of "omp critical" directives, named
+  /// ".gomp_critical_user_" + <critical_section_name> + ".var".
+  StringMap<AssertingVH<Constant>, BumpPtrAllocator> InternalVars;
+
+public:
+  /// Runtime call flavours that RTLFuncName can map to a runtime function.
+  /// TODO: add other OMP call types as they come.
+  enum callType { Entry, Entry_Hint, Exit };
+
+  /// Emit a call, at the current insertion point, to the runtime function
+  /// selected by \p OMPD and \p ct, passing \p Args.
+  CallInst *CreateOMPCall(omp::Directive OMPD, callType ct,
+                          ArrayRef<Value *> Args);
+
+  /// Shorthand for CreateOMPCall(OMPD, Entry, Args).
+  CallInst *CreateEntryCall(omp::Directive OMPD, ArrayRef<Value *> Args);
+
+  /// Shorthand for CreateOMPCall(OMPD, Exit, Args).
+  CallInst *CreateExitCall(omp::Directive OMPD, ArrayRef<Value *> Args);
+
+  /// Common interface for generating entry calls for OMP Directives.
+  /// If the directive has a region/body, it will set the insertion
+  /// point to the body.
+  ///
+  /// \param OMPD Directive to generate runtime entry call for
+  /// \param EntryCall The already-emitted runtime entry call
+  /// \param ExitBB block where the region ends.
+  /// \param conditional indicate if the entry call result will
+  ///        be used to evaluate a conditional
+  ///
+  /// \return If \p conditional, an insertion point at the start of
+  ///         \p ExitBB; otherwise the current insertion point.
+  InsertPointTy emitCommonDirectiveEntry(omp::Directive OMPD,
+                                         llvm::Value *EntryCall,
+                                         BasicBlock *ExitBB,
+                                         bool conditional = false);
+
+  /// Common interface to generate exit calls and -if needed- finalize the
+  /// region.
+  ///
+  /// \param IP Insertion point for emitting Finalization code and exit call
+  /// \param ExitBB Exit BasicBlock for OMP region
+  /// \param OMPD Directive to generate runtime exit call for
+  /// \param ExitArgs exit runtime function call args
+  /// \param hasFinalize indicate if the directive will require finalization
+  ///        and has a callback in the finalization stack that should be
+  ///        called.
+  ///
+  /// \return The insertion position of the emitted exit call.
+  InsertPointTy emitCommonDirectiveExit(InsertPointTy IP, BasicBlock *ExitBB,
+                                        omp::Directive OMPD,
+                                        ArrayRef<Value *> ExitArgs,
+                                        bool hasFinalize);
+
+  /// Generator for '#omp master'
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param BodyGenCB Callback that will generate the region code.
+  /// \param FiniCB Callback to finalize variable copies.
+  ///
+  /// \returns The insertion position *after* the master.
+  InsertPointTy CreateMaster(const LocationDescription &Loc,
+                             BodyGenCallbackTy BodyGenCB,
+                             FinalizeCallbackTy FiniCB);
+
+  /// Generator for '#omp critical'
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param BodyGenCB Callback that will generate the region code.
+  /// \param FiniCB Callback to finalize variable copies.
+  /// \param CriticalName name of the lock used by the critical directive
+  /// \param HintInst Value of the 'hint' clause, or nullptr if the directive
+  ///        has no hint clause.
+  ///
+  /// \returns The insertion position *after* the critical.
+  InsertPointTy CreateCritical(const LocationDescription &Loc,
+                               BodyGenCallbackTy BodyGenCB,
+                               FinalizeCallbackTy FiniCB,
+                               StringRef CriticalName, Value *HintInst);
+
+private:
+  /// Emit an inlined region: \p EntryCall (already emitted), the body from
+  /// \p BodyGenCB and, in the finalization block, the exit runtime call.
+  InsertPointTy EmitOMPInlinedRegion(Instruction *EntryCall,
+                                     omp::Directive OMPD,
+                                     BodyGenCallbackTy BodyGenCB,
+                                     FinalizeCallbackTy FiniCB,
+                                     ArrayRef<Value *> ExitArgs,
+                                     bool conditional, bool hasFinalize);
+
+  /// Concatenate \p Parts, emitting \p FirstSeparator before the first part
+  /// and \p Separator before every later one.
+  std::string getName(ArrayRef<StringRef> Parts, StringRef FirstSeparator,
+                      StringRef Separator) const;
+
+  /// Return the internal global \p Name of type \p Ty, creating it
+  /// (zero-initialized, common linkage) on first use.
+  Constant *getOrCreateOMPInternalVariable(llvm::Type *Ty,
+                                           const llvm::Twine &Name,
+                                           unsigned AddressSpace = 0);
+
+  /// Return the lock variable of the critical region \p CriticalName.
+  Value *getOMPCriticalRegionLock(StringRef CriticalName);
+
+  /// Map a directive and call flavour to its runtime function.
+  /// TODO: pick a better name.
+  omp::RuntimeFunction RTLFuncName(omp::Directive OMPD, callType ct);
 };
 
 } // end namespace llvm
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -176,6 +176,14 @@
 
 __OMP_RTL(omp_get_thread_num, false, Int32, )
 
+__OMP_RTL(__kmpc_master, false, Int32, IdentPtr, Int32)
+__OMP_RTL(__kmpc_end_master, false, Void, IdentPtr, Int32)
+__OMP_RTL(__kmpc_critical, false, Void, IdentPtr, Int32, KmpCriticalNamePtrTy)
+__OMP_RTL(__kmpc_critical_with_hint, false, Void, IdentPtr, Int32,
+          KmpCriticalNamePtrTy, Int32)
+__OMP_RTL(__kmpc_end_critical, false, Void, IdentPtr, Int32,
+          KmpCriticalNamePtrTy)
+
 #undef __OMP_RTL
 #undef OMP_RTL
 
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -91,7 +91,13 @@
   return Fn;
 }
 
-void OpenMPIRBuilder::initialize() { initializeTypes(M); }
+void OpenMPIRBuilder::initialize() {
+  initializeTypes(M);
+  // kmp_critical_name is an array of eight 32-bit integers.
+  KmpCriticalNameTy =
+      ArrayType::get(Type::getInt32Ty(M.getContext()), /*NumElements=*/8);
+  KmpCriticalNamePtrTy = PointerType::getUnqual(KmpCriticalNameTy);
+}
 
 Value *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                          IdentFlag LocFlags) {
@@ -630,3 +636,251 @@
 
   return AfterIP;
 }
+
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::CreateMaster(const LocationDescription &Loc,
+                              BodyGenCallbackTy BodyGenCB,
+                              FinalizeCallbackTy FiniCB) {
+
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  Directive OMPD = Directive::OMPD_master;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
+  Value *Ident = getOrCreateIdent(SrcLocStr);
+  Value *ThreadId = getOrCreateThreadID(Ident);
+  Value *Args[] = {Ident, ThreadId};
+  Instruction *EntryCall = CreateEntryCall(OMPD, Args);
+
+  // The body only runs when the entry call returns non-zero, hence the
+  // conditional region; the exit call takes the same arguments.
+  return EmitOMPInlinedRegion(EntryCall, OMPD, BodyGenCB, FiniCB, Args,
+                              /*Conditional*/ true, /*hasFinalize*/ true);
+}
+
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::CreateCritical(const LocationDescription &Loc,
+                                BodyGenCallbackTy BodyGenCB,
+                                FinalizeCallbackTy FiniCB,
+                                StringRef CriticalName, Value *HintInst) {
+
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+
+  Directive OMPD = Directive::OMPD_critical;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
+  Value *Ident = getOrCreateIdent(SrcLocStr);
+  Value *ThreadId = getOrCreateThreadID(Ident);
+  Value *LockVar = getOMPCriticalRegionLock(CriticalName);
+  Value *Args[] = {Ident, ThreadId, LockVar};
+
+  // Only the entry call carries the optional hint; the exit call always
+  // takes the three common arguments.
+  SmallVector<Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
+  Instruction *EntryCall = nullptr;
+  if (HintInst) {
+    EnterArgs.push_back(HintInst);
+    EntryCall = CreateOMPCall(OMPD, Entry_Hint, EnterArgs);
+  } else {
+    EntryCall = CreateEntryCall(OMPD, EnterArgs);
+  }
+
+  return EmitOMPInlinedRegion(EntryCall, OMPD, BodyGenCB, FiniCB, Args,
+                              /*Conditional*/ false, /*hasFinalize*/ true);
+}
+
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion(
+    Instruction *EntryCall, omp::Directive OMPD, BodyGenCallbackTy BodyGenCB,
+    FinalizeCallbackTy FiniCB, ArrayRef<Value *> ExitArgs, bool conditional,
+    bool hasFinalize) {
+
+  // emitCommonDirectiveExit pops the callback only when hasFinalize is set,
+  // so only push it in that case to keep the stack balanced.
+  if (hasFinalize)
+    FinalizationStack.push_back({FiniCB, OMPD, /*IsCancellable*/ false});
+
+  // Create the inlined region's entry and body blocks, in preparation
+  // for conditional creation.
+  BasicBlock *EntryBB = Builder.GetInsertBlock();
+  Instruction *SplitPos = EntryBB->getTerminator();
+  if (!isa_and_nonnull<BranchInst>(SplitPos))
+    SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
+  BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
+  BasicBlock *FiniBB =
+      EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");
+
+  Builder.SetInsertPoint(EntryBB->getTerminator());
+  emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, conditional);
+
+  // Generate the body.
+  BodyGenCB(/* AllocaIP */ InsertPointTy(),
+            /* CodeGenIP */ Builder.saveIP(), *FiniBB);
+
+  // Emit the exit call and do any needed finalization.
+  auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
+  assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
+         FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
+         "Unexpected insertion point for finalization call!");
+  emitCommonDirectiveExit(FinIP, ExitBB, OMPD, ExitArgs, hasFinalize);
+
+  assert(SplitPos->getParent() == ExitBB &&
+         "Unexpected Insertion point location!");
+
+  // Continue after the region: in front of the original terminator if the
+  // block had one, otherwise at the end of ExitBB once the placeholder
+  // 'unreachable' created above is removed.
+  if (isa<BranchInst>(SplitPos)) {
+    Builder.SetInsertPoint(SplitPos);
+  } else {
+    SplitPos->eraseFromParent();
+    Builder.SetInsertPoint(ExitBB);
+  }
+
+  return Builder.saveIP();
+}
+
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
+    omp::Directive OMPD, Value *EntryCall, BasicBlock *ExitBB,
+    bool conditional) {
+
+  // Without a condition the body directly follows the entry call.
+  if (!conditional)
+    return Builder.saveIP();
+
+  BasicBlock *EntryBB = Builder.GetInsertBlock();
+  Value *CallBool = Builder.CreateIsNotNull(EntryCall);
+  auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
+  auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);
+
+  // Emit thenBB and set the Builder's insertion point there for
+  // body generation next. Place the block after the current block.
+  Function *CurFn = EntryBB->getParent();
+  CurFn->getBasicBlockList().insertAfter(EntryBB->getIterator(), ThenBB);
+
+  // Move Entry branch to end of ThenBB, and replace with conditional
+  // branch (If-stmt)
+  Instruction *EntryBBTI = EntryBB->getTerminator();
+  Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
+  EntryBBTI->removeFromParent();
+  Builder.SetInsertPoint(UI);
+  Builder.Insert(EntryBBTI);
+  UI->eraseFromParent();
+  Builder.SetInsertPoint(ThenBB->getTerminator());
+
+  // Return an insertion point to ExitBB.
+  return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
+}
+
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit(
+    InsertPointTy IP, BasicBlock *ExitBB, omp::Directive OMPD,
+    ArrayRef<Value *> ExitArgs, bool hasFinalize) {
+
+  IRBuilder<>::InsertPointGuard IPG(Builder);
+  Builder.restoreIP(IP);
+
+  // If there is finalization to do, emit it before the exit call.
+  if (hasFinalize) {
+    assert(!FinalizationStack.empty() &&
+           "Unexpected finalization stack state!");
+
+    FinalizationInfo Fi = FinalizationStack.pop_back_val();
+    assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");
+
+    Fi.FiniCB(IP);
+
+    // If the callback left the block without a terminator, branch to ExitBB
+    // so the exit call has something to be inserted in front of.
+    BasicBlock *InsertBB = IP.getBlock();
+    Instruction *InsertBBTI = InsertBB->getTerminator();
+    if (!InsertBBTI) {
+      Builder.SetInsertPoint(InsertBB);
+      InsertBBTI = Builder.CreateBr(ExitBB);
+    }
+
+    // Set Builder IP for call creation.
+    Builder.SetInsertPoint(InsertBBTI);
+  }
+
+  CallInst *ExitCall = CreateExitCall(OMPD, ExitArgs);
+
+  return IRBuilder<>::InsertPoint(ExitCall->getParent(),
+                                  ExitCall->getIterator());
+}
+
+CallInst *OpenMPIRBuilder::CreateOMPCall(omp::Directive OMPD, callType ct,
+                                         ArrayRef<Value *> Args) {
+  Function *Fn = getOrCreateRuntimeFunction(RTLFuncName(OMPD, ct));
+  return Builder.CreateCall(Fn, Args);
+}
+
+CallInst *OpenMPIRBuilder::CreateEntryCall(omp::Directive OMPD,
+                                           ArrayRef<Value *> Args) {
+  return CreateOMPCall(OMPD, Entry, Args);
+}
+
+CallInst *OpenMPIRBuilder::CreateExitCall(omp::Directive OMPD,
+                                          ArrayRef<Value *> Args) {
+  return CreateOMPCall(OMPD, Exit, Args);
+}
+
+omp::RuntimeFunction OpenMPIRBuilder::RTLFuncName(omp::Directive OMPD,
+                                                  callType ct) {
+  switch (OMPD) {
+  case OMPD_master:
+    if (ct == Entry)
+      return OMPRTL___kmpc_master;
+    if (ct == Exit)
+      return OMPRTL___kmpc_end_master;
+    llvm_unreachable("Unknown master call type");
+  case OMPD_critical:
+    if (ct == Entry)
+      return OMPRTL___kmpc_critical;
+    if (ct == Entry_Hint)
+      return OMPRTL___kmpc_critical_with_hint;
+    if (ct == Exit)
+      return OMPRTL___kmpc_end_critical;
+    llvm_unreachable("Unknown critical call type");
+  default:
+    llvm_unreachable("unknown OMP directive");
+  }
+}
+
+std::string OpenMPIRBuilder::getName(ArrayRef<StringRef> Parts,
+                                     StringRef FirstSeparator,
+                                     StringRef Separator) const {
+  SmallString<128> Buffer;
+  llvm::raw_svector_ostream OS(Buffer);
+  StringRef Sep = FirstSeparator;
+  for (StringRef Part : Parts) {
+    OS << Sep << Part;
+    Sep = Separator;
+  }
+  // StringRef's conversion to std::string is explicit.
+  return std::string(OS.str());
+}
+
+Constant *OpenMPIRBuilder::getOrCreateOMPInternalVariable(
+    llvm::Type *Ty, const llvm::Twine &Name, unsigned AddressSpace) {
+  SmallString<256> Buffer;
+  llvm::raw_svector_ostream Out(Buffer);
+  Out << Name;
+  StringRef RuntimeName = Out.str();
+  auto &Elem = *InternalVars.try_emplace(RuntimeName, nullptr).first;
+  if (Elem.second) {
+    assert(Elem.second->getType()->getPointerElementType() == Ty &&
+           "OMP internal variable has different type than requested");
+    return &*Elem.second;
+  }
+
+  return Elem.second = new llvm::GlobalVariable(
+             M, Ty, /*IsConstant*/ false, llvm::GlobalValue::CommonLinkage,
+             llvm::Constant::getNullValue(Ty), Elem.first(),
+             /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal,
+             AddressSpace);
+}
+
+Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
+  std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
+  std::string Name = getName({Prefix, "var"}, ".", ".");
+  return getOrCreateOMPInternalVariable(KmpCriticalNameTy, Name);
+}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -613,4 +613,167 @@
   }
 }
 
+TEST_F(OpenMPIRBuilderTest, MasterDirective) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  AllocaInst *PrivAI = nullptr;
+
+  BasicBlock *EntryBB = nullptr;
+  BasicBlock *FinalBB = nullptr;
+  BasicBlock *ExitBB = nullptr;
+  BasicBlock *ThenBB = nullptr;
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &FiniBB) {
+    if (AllocaIP.isSet())
+      Builder.restoreIP(AllocaIP);
+    else
+      Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
+    PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
+    Builder.CreateStore(F->arg_begin(), PrivAI);
+
+    llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
+    llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
+    EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
+
+    Builder.restoreIP(CodeGenIP);
+
+    // Collect some info for checks later.
+    FinalBB = &FiniBB;
+    ExitBB = FiniBB.getUniqueSuccessor();
+    ThenBB = Builder.GetInsertBlock();
+    EntryBB = ThenBB->getUniquePredecessor();
+
+    // Simple instructions for body.
+    Value *PrivLoad = Builder.CreateLoad(PrivAI, "local.use");
+    Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
+  };
+
+  auto FiniCB = [&](InsertPointTy IP) {
+    BasicBlock *IPBB = IP.getBlock();
+    EXPECT_NE(IPBB->end(), IP.getPoint());
+  };
+
+  Builder.restoreIP(OMPBuilder.CreateMaster(Builder, BodyGenCB, FiniCB));
+  Value *EntryBBTI = EntryBB->getTerminator();
+  ASSERT_NE(EntryBBTI, nullptr);
+  EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
+  BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
+  EXPECT_TRUE(EntryBr->isConditional());
+  EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
+  EXPECT_EQ(ThenBB->getUniqueSuccessor(), FinalBB);
+  EXPECT_EQ(FinalBB->getUniqueSuccessor(), ExitBB);
+  EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
+
+  CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
+  EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
+
+  CallInst *MasterEntryCI = cast<CallInst>(CondInst->getOperand(0));
+  EXPECT_EQ(MasterEntryCI->getNumArgOperands(), 2U);
+  EXPECT_EQ(MasterEntryCI->getCalledFunction()->getName(), "__kmpc_master");
+  EXPECT_TRUE(isa<GlobalVariable>(MasterEntryCI->getArgOperand(0)));
+
+  CallInst *MasterEndCI = nullptr;
+  for (Instruction &I : *FinalBB) {
+    auto *CI = dyn_cast<CallInst>(&I);
+    if (CI && CI->getCalledFunction()->getName() == "__kmpc_end_master") {
+      MasterEndCI = CI;
+      break;
+    }
+  }
+  ASSERT_NE(MasterEndCI, nullptr);
+  EXPECT_EQ(MasterEndCI->getNumArgOperands(), 2U);
+  EXPECT_TRUE(isa<GlobalVariable>(MasterEndCI->getArgOperand(0)));
+  EXPECT_EQ(MasterEndCI->getArgOperand(1), MasterEntryCI->getArgOperand(1));
+}
+
+TEST_F(OpenMPIRBuilderTest, CriticalDirective) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  AllocaInst *PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
+
+  BasicBlock *EntryBB = nullptr;
+  BasicBlock *FinalBB = nullptr;
+  BasicBlock *ExitBB = nullptr;
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &FiniBB) {
+    // Collect some info for checks later.
+    FinalBB = &FiniBB;
+    ExitBB = FiniBB.getUniqueSuccessor();
+    EntryBB = FinalBB->getUniquePredecessor();
+
+    // Actual start for bodyCB.
+    llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
+    llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
+    EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
+    EXPECT_EQ(EntryBB, CodeGenIPBB);
+
+    // Body begin.
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateStore(F->arg_begin(), PrivAI);
+    Value *PrivLoad = Builder.CreateLoad(PrivAI, "local.use");
+    Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
+  };
+
+  auto FiniCB = [&](InsertPointTy IP) {
+    BasicBlock *IPBB = IP.getBlock();
+    EXPECT_NE(IPBB->end(), IP.getPoint());
+  };
+
+  Builder.restoreIP(OMPBuilder.CreateCritical(Builder, BodyGenCB, FiniCB,
+                                              "testCRT", nullptr));
+
+  Value *EntryBBTI = EntryBB->getTerminator();
+  ASSERT_NE(EntryBBTI, nullptr);
+  EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
+  BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
+  EXPECT_FALSE(EntryBr->isConditional());
+  EXPECT_EQ(EntryBB->getUniqueSuccessor(), FinalBB);
+  EXPECT_EQ(FinalBB->getUniqueSuccessor(), ExitBB);
+
+  CallInst *CriticalEntryCI = nullptr;
+  for (Instruction &I : *EntryBB) {
+    auto *CI = dyn_cast<CallInst>(&I);
+    if (CI && CI->getCalledFunction()->getName() == "__kmpc_critical") {
+      CriticalEntryCI = CI;
+      break;
+    }
+  }
+  ASSERT_NE(CriticalEntryCI, nullptr);
+  EXPECT_EQ(CriticalEntryCI->getNumArgOperands(), 3U);
+  EXPECT_TRUE(isa<GlobalVariable>(CriticalEntryCI->getArgOperand(0)));
+
+  CallInst *CriticalEndCI = nullptr;
+  for (Instruction &I : *FinalBB) {
+    auto *CI = dyn_cast<CallInst>(&I);
+    if (CI && CI->getCalledFunction()->getName() == "__kmpc_end_critical") {
+      CriticalEndCI = CI;
+      break;
+    }
+  }
+  ASSERT_NE(CriticalEndCI, nullptr);
+  EXPECT_EQ(CriticalEndCI->getNumArgOperands(), 3U);
+  EXPECT_TRUE(isa<GlobalVariable>(CriticalEndCI->getArgOperand(0)));
+  EXPECT_EQ(CriticalEndCI->getArgOperand(1),
+            CriticalEntryCI->getArgOperand(1));
+  PointerType *CriticalNamePtrTy =
+      PointerType::getUnqual(ArrayType::get(Type::getInt32Ty(Ctx), 8));
+  EXPECT_EQ(CriticalEndCI->getArgOperand(2),
+            CriticalEntryCI->getArgOperand(2));
+  EXPECT_EQ(CriticalEndCI->getArgOperand(2)->getType(), CriticalNamePtrTy);
+}
+
 } // namespace