diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -21,6 +21,38 @@
 
 namespace llvm {
 
+/// OpenMP schedule type constants. CreateSections passes OMP_sch_static as the
+/// schedule type argument of __kmpc_for_static_init_4.
+enum OpenMPSchedType {
+  /// Lower bound for default (unordered) versions.
+  OMP_sch_lower = 32,
+  OMP_sch_static_chunked = 33,
+  OMP_sch_static = 34,
+  OMP_sch_dynamic_chunked = 35,
+  OMP_sch_guided_chunked = 36,
+  OMP_sch_runtime = 37,
+  OMP_sch_auto = 38,
+  /// static with chunk adjustment (e.g., simd)
+  OMP_sch_static_balanced_chunked = 45,
+  /// Lower bound for 'ordered' versions.
+  OMP_ord_lower = 64,
+  OMP_ord_static_chunked = 65,
+  OMP_ord_static = 66,
+  OMP_ord_dynamic_chunked = 67,
+  OMP_ord_guided_chunked = 68,
+  OMP_ord_runtime = 69,
+  OMP_ord_auto = 70,
+  OMP_sch_default = OMP_sch_static,
+  /// dist_schedule types
+  OMP_dist_sch_static_chunked = 91,
+  OMP_dist_sch_static = 92,
+  /// Support for OpenMP 4.5 monotonic and nonmonotonic schedule modifiers.
+  /// Set if the monotonic schedule modifier was present.
+  OMP_sch_modifier_monotonic = (1 << 29),
+  /// Set if the nonmonotonic schedule modifier was present.
+  OMP_sch_modifier_nonmonotonic = (1 << 30),
+};
+
 /// An interface to create LLVM-IR for OpenMP directives.
 ///
 /// Each OpenMP directive has a corresponding public generator method.
@@ -43,6 +75,16 @@
 
   /// Type used throughout for insertion points.
   using InsertPointTy = IRBuilder<>::InsertPoint;
+
+  /// Callback type for generating the body of one section.
+  ///
+  /// \param AllocaIP is the insertion point intended for allocas; CreateSections
+  ///                 currently passes a default-constructed insertion point.
+  /// \param CodeGenIP is the insertion point at which the body is emitted.
+  /// \param FiniBB is the block control continues in after the section body.
+  using BGenCallbackTy =
+      std::function<void(InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                         BasicBlock &FiniBB)>;
 
   /// Callback type for variable finalization (think destructors).
   ///
@@ -363,12 +405,28 @@
   /// \param CriticalName name of the lock used by the critical directive
   /// \param HintInst Hint Instruction for hint clause associated with critical
   ///
-  /// \returns The insertion position *after* the master.
+  /// \returns The insertion position *after* the critical.
   InsertPointTy CreateCritical(const LocationDescription &Loc,
                                BodyGenCallbackTy BodyGenCB,
                                FinalizeCallbackTy FiniCB,
                                StringRef CriticalName, Value *HintInst);
 
+  /// Generator for '#omp sections'
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param SectionCBs Callbacks that will generate body of each section.
+  /// \param PrivCB Callback to copy a given variable (think copy constructor).
+  /// \param FiniCB Callback to finalize variable copies.
+  /// \param IsCancellable Flag to indicate a cancellable sections region.
+  /// \param IsNoWait If false, a barrier is emitted after the sections.
+  ///
+  /// \returns The insertion position *after* the sections.
+  InsertPointTy CreateSections(const LocationDescription &Loc,
+                               ArrayRef<BGenCallbackTy> SectionCBs,
+                               PrivatizeCallbackTy PrivCB,
+                               FinalizeCallbackTy FiniCB, bool IsCancellable,
+                               bool IsNoWait);
+
   /// Generate conditional branch and relevant BasicBlocks through which private
   /// threads copy the 'copyin' variables from Master copy to threadprivate
   /// copies.
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -1649,6 +1649,15 @@
     return CreateAlignedStore(Val, Ptr, MaybeAlign(), isVolatile);
   }
 
+  /// Create an alloca of type \p Ty named \p Name and, if \p Init is non-null,
+  /// store \p Init into it. Returns the alloca.
+  Value *CreateLVal(Type *Ty, const Twine &Name = "", Value *Init = nullptr) {
+    Value *LVal = CreateAlloca(Ty, nullptr, Name);
+    if (Init)
+      CreateStore(Init, LVal);
+    return LVal;
+  }
+
   LLVM_ATTRIBUTE_DEPRECATED(LoadInst *CreateAlignedLoad(Type *Ty, Value *Ptr,
                                                         unsigned Align,
                                                         const char *Name),
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -780,6 +780,133 @@
   emitTaskyieldImpl(Loc);
 }
 
+// TODO: Handle privatisation callbacks
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::CreateSections(
+    const LocationDescription &Loc, ArrayRef<BGenCallbackTy> SectionCBs,
+    PrivatizeCallbackTy PrivCB, FinalizeCallbackTy FiniCB, bool IsCancellable,
+    bool IsNoWait) {
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
+  Value *Ident = getOrCreateIdent(SrcLocStr);
+  Value *ThreadID = getOrCreateThreadID(Ident);
+
+  BasicBlock *InsertBB = Builder.GetInsertBlock();
+  Function *CurFn = InsertBB->getParent();
+
+  // Allocate and initialize the helper variables of the dispatch loop.
+  Value *LB = Builder.CreateLVal(Int32, "omp.sections.lb", Builder.getInt32(0));
+  ConstantInt *GlobalUBVal = Builder.getInt32(SectionCBs.size() - 1);
+  Value *UB = Builder.CreateLVal(Int32, ".omp.sections.ub", GlobalUBVal);
+  Value *ST =
+      Builder.CreateLVal(Int32, ".omp.sections.st.", Builder.getInt32(1));
+  Value *IL =
+      Builder.CreateLVal(Int32, ".omp.sections.il.", Builder.getInt32(0));
+  Value *IV = Builder.CreateLVal(Int32, ".omp.sections.iv.");
+
+  auto CreateForLoopCB = [&]() {
+    Instruction *LBRef = Builder.CreateLoad(LB);
+    Builder.CreateStore(LBRef, IV);
+
+    // create new basic blocks for emitting the for loop
+    auto *ForBodyBB = BasicBlock::Create(M.getContext(), "omp.inner.for.body");
+    auto *ForExitBB = BasicBlock::Create(M.getContext(), "omp.inner.for.exit");
+    auto *ForIncBB = BasicBlock::Create(M.getContext(), "omp.inner.for.inc");
+    CurFn->getBasicBlockList().insertAfter(InsertBB->getIterator(), ForIncBB);
+    CurFn->getBasicBlockList().insertAfter(InsertBB->getIterator(), ForBodyBB);
+    CurFn->getBasicBlockList().insertAfter(InsertBB->getIterator(), ForExitBB);
+
+    // split InsertBB to create the basic block for emitting for-loop-condition
+    auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
+    BasicBlock *ForCondBB = InsertBB->splitBasicBlock(UI, "omp.inner.for.cond");
+    UI->eraseFromParent();
+    Builder.SetInsertPoint(ForCondBB);
+    Instruction *IVRef = Builder.CreateLoad(IV);
+    Instruction *UBRef = Builder.CreateLoad(UB);
+    Value *CmpRef = Builder.CreateICmpSLE(IVRef, UBRef, "cmp");
+    Builder.CreateCondBr(CmpRef, ForBodyBB, ForExitBB);
+
+    // callback for creating switch statement inside for body.
+    auto CreateSwitchCB = [&]() {
+      auto *SwitchExitBB =
+          BasicBlock::Create(M.getContext(), "omp.switch.exit");
+      CurFn->getBasicBlockList().insertAfter(InsertBB->getIterator(),
+                                             SwitchExitBB);
+
+      SwitchInst *SwitchStmt =
+          Builder.CreateSwitch(Builder.CreateLoad(IV), SwitchExitBB);
+      // Each section is emitted as a switch case.
+      // Iterate through all sections and emit a switch construct:
+      // switch (IV) {
+      //   case 0:
+      //     <SectionStmt[0]>;
+      //     break;
+      // ...
+      //   case <NumSection> - 1:
+      //     <SectionStmt[<NumSection> - 1]>;
+      //     break;
+      // }
+      // .omp.sections.exit:
+      unsigned CaseNumber = 0;
+      for (const auto &SectionCB : SectionCBs) {
+        auto *CaseBB =
+            BasicBlock::Create(M.getContext(), ".omp.sections.case");
+        CurFn->getBasicBlockList().insertAfter(InsertBB->getIterator(), CaseBB);
+        SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
+        Builder.SetInsertPoint(CaseBB);
+        SectionCB(InsertPointTy(), Builder.saveIP(), *SwitchExitBB);
+        ++CaseNumber;
+      }
+      Builder.SetInsertPoint(SwitchExitBB);
+    };
+
+    // for Body
+    Builder.SetInsertPoint(ForBodyBB);
+    CreateSwitchCB();
+    Builder.CreateBr(ForIncBB);
+
+    // for inc
+    Builder.SetInsertPoint(ForIncBB);
+    Instruction *IVRef2 = Builder.CreateLoad(IV);
+    Value *Inc = Builder.CreateNSWAdd(IVRef2, Builder.getInt32(1), "inc");
+    Builder.CreateStore(Inc, IV);
+    Builder.CreateBr(ForCondBB);
+    Builder.SetInsertPoint(ForExitBB);
+  };
+
+  // emit kmpc_for
+  // TODO: Change the following to call the IR Builder function to emit kmpc_for
+  // after for loop is lowered.
+  Value *ForEntryArgs[] = {
+      Ident,
+      ThreadID,
+      Builder.getInt32(OMP_sch_static), // Schedule type
+      IL,                               // &isLastIter
+      LB,                               // &LB
+      UB,                               // &UB
+      ST,                               // &Stride
+      Builder.getInt32(1),              // Incr
+      Builder.getInt32(1)               // Chunk
+  };
+  Function *EntryRTLFn =
+      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_for_static_init_4);
+  Builder.CreateCall(EntryRTLFn, ForEntryArgs);
+
+  CreateForLoopCB();
+
+  Value *ForExitArgs[] = {Ident, ThreadID};
+  Function *ExitRTLFn =
+      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_for_static_fini);
+  Builder.CreateCall(ExitRTLFn, ForExitArgs);
+
+  if (!IsNoWait)
+    CreateBarrier(Builder, OMPD_sections, true, IsCancellable);
+
+  // TODO: Handle PrivCB and FiniCB; neither callback is invoked yet.
+
+  return Builder.saveIP();
+}
+
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::CreateMaster(const LocationDescription &Loc,
                               BodyGenCallbackTy BodyGenCB,
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -1101,4 +1101,58 @@
   EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateSections) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  using BodyGenCallbackTy =
+      std::function<void(InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                         BasicBlock &FiniBB)>;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  llvm::SmallVector<BodyGenCallbackTy, 4> SectionCBVector;
+
+  AllocaInst *PrivAI = nullptr;
+
+  unsigned NumBodiesGenerated = 0;
+  unsigned NumFiniCBCalls = 0;
+  PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
+
+  auto FiniCB = [&](InsertPointTy IP) { ++NumFiniCBCalls; };
+
+  auto SectionCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &FiniBB) {
+    ++NumBodiesGenerated;
+    Builder.restoreIP(CodeGenIP);
+    Builder.CreateStore(F->arg_begin(), PrivAI);
+    Builder.CreateLoad(PrivAI, "local.use");
+    Builder.CreateBr(&FiniBB);
+  };
+  auto PrivCB = [](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                   llvm::Value &Val, llvm::Value *&ReplVal) {
+    // TODO: Privatization not implemented in OMPIRBuilder yet
+    return CodeGenIP;
+  };
+
+  SectionCBVector.push_back(SectionCB);
+  SectionCBVector.push_back(SectionCB);
+  ArrayRef<BodyGenCallbackTy> SectionCBs = makeArrayRef(SectionCBVector);
+
+  Builder.restoreIP(
+      OMPBuilder.CreateSections(Loc, SectionCBs, PrivCB, FiniCB, false, false));
+  Builder.CreateRetVoid(); // Required at the end of the function
+
+  EXPECT_NE(PrivAI, nullptr);
+  Function *OutlinedFn = PrivAI->getFunction();
+  EXPECT_EQ(F, OutlinedFn);
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  EXPECT_EQ(OutlinedFn->arg_size(), 1U);
+  EXPECT_EQ(OutlinedFn->getBasicBlockList().size(), 8U);
+
+  ASSERT_EQ(NumBodiesGenerated, 2U);
+  ASSERT_EQ(NumFiniCBCalls, 0U);
+}
+
 } // namespace