diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -618,6 +618,19 @@
   /// \param Loc The location where the taskyield directive was encountered.
   void createTaskyield(const LocationDescription &Loc);
 
+  /// Generator for `#omp task`
+  ///
+  /// \param Loc The location where the task construct was encountered.
+  /// \param AllocaIP The insertion point to be used for alloca instructions.
+  /// \param BodyGenCB Callback that will generate the region code.
+  /// \param Tied True if the task is tied, false if the task is untied.
+  ///
+  /// \returns The insertion point at the start of the block following the task
+  ///          region, or an empty one if updateToLocation(Loc) fails.
+  InsertPointTy createTask(const LocationDescription &Loc,
+                           InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
+                           bool Tied = true);
+
   /// Functions used to generate reductions. Such functions take two Values
   /// representing LHS and RHS of the reduction, respectively, and a reference
   /// to the value that is updated to refer to the reduction result.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1253,6 +1253,183 @@
   emitTaskyieldImpl(Loc);
 }
 
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createTask(const LocationDescription &Loc,
+                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
+                            bool Tied) {
+  if (!updateToLocation(Loc))
+    return InsertPointTy();
+
+  // Build the source-location ident now rather than inside PostOutlineCB.
+  // The callback only runs once the region has been outlined, i.e. after this
+  // function has returned, so capturing `Loc` by reference there would leave
+  // it referring to a caller-owned object that may no longer exist.
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+
+  // The current basic block is split into four basic blocks. After outlining,
+  // they will be mapped as follows:
+  // ```
+  // def current_fn() {
+  //   current_basic_block:
+  //     br label %task.exit
+  //   task.exit:
+  //     ; instructions after task
+  // }
+  // def outlined_fn() {
+  //   task.alloca:
+  //     br label %task.body
+  //   task.body:
+  //     ret void
+  // }
+  // ```
+  BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
+  BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
+  BasicBlock *TaskAllocaBB =
+      splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
+
+  OutlineInfo OI;
+  OI.EntryBB = TaskAllocaBB;
+  OI.OuterAllocaBB = AllocaIP.getBlock();
+  OI.ExitBB = TaskExitBB;
+  OI.PostOutlineCB = [this, Ident, Tied](Function &OutlinedFn) {
+    // The input IR here looks like the following-
+    // ```
+    // func @current_fn() {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+    //
+    // This is changed to the following-
+    //
+    // ```
+    // func @current_fn() {
+    //   runtime_call(..., wrapper_fn, ...)
+    // }
+    // func @wrapper_fn(..., %args) {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+
+    // The stale call instruction will be replaced with a new call instruction
+    // for runtime call with a wrapper function.
+    assert(OutlinedFn.getNumUses() == 1 &&
+           "there must be a single user for the outlined function");
+    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+
+    // HasTaskData is true if any variables are captured in the outlined region,
+    // false otherwise.
+    bool HasTaskData = StaleCI->arg_size() > 0;
+    Builder.SetInsertPoint(StaleCI);
+
+    // Gather the arguments for emitting the runtime call for
+    // @__kmpc_omp_task_alloc
+    Function *TaskAllocFn =
+        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
+
+    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID). Ident was built
+    // before outlining; only the thread-id call is emitted here.
+    Value *ThreadID = getOrCreateThreadID(Ident);
+
+    // Argument - `flags`
+    // If task is tied, then (Flags & 1) == 1.
+    // If task is untied, then (Flags & 1) == 0.
+    // TODO: Handle the other flags.
+    Value *Flags = Builder.getInt32(Tied);
+
+    // Argument - `sizeof_kmp_task_t` (TaskSize)
+    // Tasksize refers to the size in bytes of kmp_task_t data structure
+    // including private vars accessed in task.
+    Value *TaskSize = Builder.getInt64(0);
+    if (HasTaskData) {
+      AllocaInst *ArgStructAlloca =
+          dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
+      assert(ArgStructAlloca &&
+             "Unable to find the alloca instruction corresponding to arguments "
+             "for extracted function");
+      StructType *ArgStructType =
+          dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
+      assert(ArgStructType && "Unable to find struct type corresponding to "
+                              "arguments for extracted function");
+      TaskSize =
+          Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
+    }
+
+    // TODO: Argument - sizeof_shareds
+
+    // Argument - task_entry (the wrapper function)
+    // If the outlined function has some captured variables (i.e. HasTaskData is
+    // true), then the wrapper function will have an additional argument (the
+    // struct containing captured variables). Otherwise, no such argument will
+    // be present.
+    SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
+    if (HasTaskData)
+      WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
+    FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
+        (Twine(OutlinedFn.getName()) + ".wrapper").str(),
+        FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
+    // cast<>, not dyn_cast<>: WrapperFunc is dereferenced unconditionally
+    // below, so an unexpected callee should assert instead of yielding null.
+    Function *WrapperFunc = cast<Function>(WrapperFuncVal.getCallee());
+    PointerType *WrapperFuncBitcastType =
+        FunctionType::get(Builder.getInt32Ty(),
+                          {Builder.getInt32Ty(), Builder.getInt8PtrTy()}, false)
+            ->getPointerTo();
+    Value *WrapperFuncBitcast =
+        ConstantExpr::getBitCast(WrapperFunc, WrapperFuncBitcastType);
+
+    // Emit the @__kmpc_omp_task_alloc runtime call
+    // The runtime call returns a pointer to an area where the task captured
+    // variables must be copied before the task is run (NewTaskData)
+    CallInst *NewTaskData = Builder.CreateCall(
+        TaskAllocFn,
+        {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
+         /*sizeof_task=*/TaskSize, /*sizeof_shared=*/Builder.getInt64(0),
+         /*task_func=*/WrapperFuncBitcast});
+
+    // Copy the arguments for outlined function
+    if (HasTaskData) {
+      Value *TaskData = StaleCI->getArgOperand(0);
+      Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
+      Builder.CreateMemCpy(NewTaskData, Alignment, TaskData, Alignment,
+                           TaskSize);
+    }
+
+    // Emit the @__kmpc_omp_task runtime call to spawn the task
+    Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
+    Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+
+    StaleCI->eraseFromParent();
+
+    // Emit the body for wrapper function
+    BasicBlock *WrapperEntryBB =
+        BasicBlock::Create(M.getContext(), "", WrapperFunc);
+    Builder.SetInsertPoint(WrapperEntryBB);
+    if (HasTaskData)
+      Builder.CreateCall(&OutlinedFn, {WrapperFunc->getArg(1)});
+    else
+      Builder.CreateCall(&OutlinedFn);
+    Builder.CreateRet(Builder.getInt32(0));
+  };
+
+  addOutlineInfo(std::move(OI));
+
+  InsertPointTy TaskAllocaIP =
+      InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
+  InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
+  BodyGenCB(TaskAllocaIP, TaskBodyIP);
+
+  // Resume at the *start* of task.exit: the block holds whatever followed the
+  // task (possibly including a terminator), so appending at its end could
+  // place new instructions after that terminator.
+  Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
+
+  return Builder.saveIP();
+}
+
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
     const LocationDescription &Loc, InsertPointTy AllocaIP,
     ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4412,4 +4412,180 @@
   EXPECT_TRUE(MapperCall->getOperand(8)->getType()->isPointerTy());
 }
 
+// Outlines a task region that captures one pointer and one i128 value, then
+// checks the arguments of the emitted __kmpc_omp_task_alloc/__kmpc_omp_task
+// calls, the generated wrapper function and the copy of the captured data.
+TEST_F(OpenMPIRBuilderTest, CreateTask) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
+  AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
+  Value *Val128 =
+      Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "bodygen.load");
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(AllocaIP);
+    AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
+                                                "bodygen.alloca128");
+
+    Builder.restoreIP(CodeGenIP);
+    // Loading and storing captured pointer and values
+    Builder.CreateStore(Val128, Local128);
+    Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
+                                      "bodygen.load32");
+
+    LoadInst *PrivLoad128 = Builder.CreateLoad(
+        Local128->getAllocatedType(), Local128, "bodygen.local.load128");
+    Value *Cmp = Builder.CreateICmpNE(
+        Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType()));
+    Instruction *ThenTerm, *ElseTerm;
+    SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
+                                  &ThenTerm, &ElseTerm);
+  };
+
+  BasicBlock *AllocaBB = Builder.GetInsertBlock();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  OpenMPIRBuilder::LocationDescription Loc(
+      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+  Builder.restoreIP(OMPBuilder.createTask(
+      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()),
+      BodyGenCB));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+          ->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+
+  // Verify the Ident argument
+  GlobalVariable *Ident =
+      dyn_cast<GlobalVariable>(TaskAllocCall->getArgOperand(0));
+  ASSERT_NE(Ident, nullptr);
+  EXPECT_TRUE(Ident->hasInitializer());
+  Constant *Initializer = Ident->getInitializer();
+  GlobalVariable *SrcStrGlob =
+      dyn_cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
+  ASSERT_NE(SrcStrGlob, nullptr);
+  ConstantDataArray *SrcSrc =
+      dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
+  ASSERT_NE(SrcSrc, nullptr);
+
+  // Verify the gtid (global thread id) argument.
+  CallInst *GTID = dyn_cast<CallInst>(TaskAllocCall->getArgOperand(1));
+  ASSERT_NE(GTID, nullptr);
+  EXPECT_EQ(GTID->arg_size(), 1U);
+  EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
+
+  // Verify the flags
+  // TODO: Check for others flags. Currently testing only for tiedness.
+  ConstantInt *Flags = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(2));
+  ASSERT_NE(Flags, nullptr);
+  EXPECT_EQ(Flags->getSExtValue(), 1);
+
+  // Verify the data size
+  ConstantInt *DataSize =
+      dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
+  ASSERT_NE(DataSize, nullptr);
+  EXPECT_EQ(DataSize->getSExtValue(), 24); // 64-bit pointer + 128-bit integer
+
+  // TODO: Verify size of shared clause variables
+
+  // Verify Wrapper function
+  Function *WrapperFunc =
+      dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
+  ASSERT_NE(WrapperFunc, nullptr);
+  EXPECT_FALSE(WrapperFunc->isDeclaration());
+  CallInst *OutlinedFnCall = dyn_cast<CallInst>(WrapperFunc->begin()->begin());
+  ASSERT_NE(OutlinedFnCall, nullptr);
+  EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty());
+  EXPECT_EQ(OutlinedFnCall->getArgOperand(0), WrapperFunc->getArg(1));
+
+  // Verify the presence of `trunc` and `icmp` instructions in Outlined function
+  Function *OutlinedFn = OutlinedFnCall->getCalledFunction();
+  ASSERT_NE(OutlinedFn, nullptr);
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<TruncInst>(&inst); }));
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
+
+  // Verify the execution of the task
+  CallInst *TaskCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task)
+          ->user_back());
+  ASSERT_NE(TaskCall, nullptr);
+  EXPECT_EQ(TaskCall->getArgOperand(0), Ident);
+  EXPECT_EQ(TaskCall->getArgOperand(1), GTID);
+  EXPECT_EQ(TaskCall->getArgOperand(2), TaskAllocCall);
+
+  // Verify that the argument data has been copied
+  bool FoundMemCpy = false;
+  for (User *in : TaskAllocCall->users()) {
+    if (MemCpyInst *memCpyInst = dyn_cast<MemCpyInst>(in)) {
+      FoundMemCpy = true;
+      EXPECT_EQ(memCpyInst->getDest(), TaskAllocCall);
+    }
+  }
+  EXPECT_TRUE(FoundMemCpy);
+}
+
+// A task with an empty body (nothing captured) must still produce valid IR.
+TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+
+  BasicBlock *AllocaBB = Builder.GetInsertBlock();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  OpenMPIRBuilder::LocationDescription Loc(
+      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+  Builder.restoreIP(OMPBuilder.createTask(
+      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()),
+      BodyGenCB));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
+// Passing Tied=false must clear bit 0 of the flags given to task_alloc.
+TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+  BasicBlock *AllocaBB = Builder.GetInsertBlock();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  OpenMPIRBuilder::LocationDescription Loc(
+      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+  Builder.restoreIP(OMPBuilder.createTask(
+      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB,
+      /*Tied=*/false));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  // Check for the `Tied` argument
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+          ->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+  ConstantInt *Flags = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(2));
+  ASSERT_NE(Flags, nullptr);
+  EXPECT_EQ(Flags->getZExtValue() & 1U, 0U);
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
 } // namespace