diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -563,6 +563,18 @@
   /// \param Loc The location where the taskyield directive was encountered.
   void createTaskyield(const LocationDescription &Loc);
 
+  /// Generator for `#omp task`
+  ///
+  /// \param Loc The location where the task construct was encountered.
+  /// \param AllocaIP The insertion point to be used for alloca instructions.
+  /// \param BodyGenCB Callback that will generate the region code.
+  ///
+  /// \returns The insertion point at the end of the block that follows the
+  ///          task region (`task.exit`), or a default-constructed
+  ///          InsertPointTy if updateToLocation(\p Loc) returns false.
+  InsertPointTy createTask(const LocationDescription &Loc,
+                           InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB);
+
   /// Functions used to generate reductions. Such functions take two Values
   /// representing LHS and RHS of the reduction, respectively, and a reference
   /// to the value that is updated to refer to the reduction result.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1220,6 +1220,154 @@
   emitTaskyieldImpl(Loc);
 }
 
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createTask(const LocationDescription &Loc,
+                            InsertPointTy AllocaIP,
+                            BodyGenCallbackTy BodyGenCB) {
+  if (!updateToLocation(Loc))
+    return InsertPointTy();
+
+  UnreachableInst *UI = Builder.CreateUnreachable();
+
+  // Basic block mapping for task generation:
+  // ```
+  // current_fn() {
+  //   currbb
+  //   task.exit
+  // }
+  // outlined_fn() {
+  //   task.alloca
+  //   task.body
+  // }
+  // ```
+  BasicBlock *currbb = Builder.GetInsertBlock();
+  BasicBlock *TaskAllocaBB = currbb->splitBasicBlock(UI, "task.alloca");
+  BasicBlock *TaskBodyBB = TaskAllocaBB->splitBasicBlock(UI, "task.body");
+  BasicBlock *TaskExitBB = TaskBodyBB->splitBasicBlock(UI, "task.exit");
+
+  InsertPointTy TaskAllocaIP =
+      InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
+  InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
+  BodyGenCB(TaskAllocaIP, TaskBodyIP, *TaskExitBB);
+
+  OutlineInfo OI;
+  OI.EntryBB = TaskAllocaBB;
+  OI.OuterAllocaBB = AllocaIP.getBlock();
+  OI.ExitBB = TaskExitBB;
+  OI.PostOutlineCB = [this](Function &OutlinedFn) {
+    // The input IR here looks like the following-
+    // ```
+    // func @current_fn() {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+    //
+    // This is changed to the following-
+    //
+    // ```
+    // func @current_fn() {
+    //   runtime_call(..., wrapper_fn, ...)
+    // }
+    // func @wrapper_fn(..., %args) {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+
+    // The stale call instruction will be replaced with a new call instruction
+    // for runtime call with a wrapper function.
+    assert(OutlinedFn.getNumUses() == 1 &&
+           "there must be a single user for the outlined function");
+    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+    bool HasTaskData = StaleCI->arg_size() > 0;
+    Builder.SetInsertPoint(StaleCI);
+
+    // Gather the arguments for emitting the runtime call for
+    // @__kmpc_omp_task_alloc
+    Function *TaskAllocFn =
+        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
+
+    // Arguments - loc_ref, gtid
+    uint32_t SrcLocStrSize;
+    Value *Ident = getOrCreateIdent(
+        getOrCreateSrcLocStr(LocationDescription(Builder), SrcLocStrSize),
+        SrcLocStrSize);
+    Value *ThreadID = getOrCreateThreadID(Ident);
+
+    // Argument - sizeof_kmp_task_t
+    Value *TaskSize = Builder.getInt64(0);
+    if (HasTaskData) {
+      AllocaInst *ArgStructAlloca =
+          dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
+      assert(ArgStructAlloca &&
+             "Unable to find the alloca instruction corresponding to arguments "
+             "for extracted function");
+      StructType *ArgStructType =
+          dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
+      assert(ArgStructType && "Unable to find struct type corresponding to "
+                              "arguments for extracted function");
+      TaskSize =
+          Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
+    }
+
+    // TODO: Argument - sizeof_shareds
+
+    // Argument - task_entry (the wrapper function)
+    SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
+    if (HasTaskData)
+      WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
+    std::string WrapperFuncName = OutlinedFn.getName().str() + ".wrapper";
+    FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
+        WrapperFuncName,
+        FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
+    Function *WrapperFunc = cast<Function>(WrapperFuncVal.getCallee());
+    PointerType *WrapperFuncBitcastType =
+        FunctionType::get(Builder.getInt32Ty(),
+                          {Builder.getInt32Ty(), Builder.getInt8PtrTy()}, false)
+            ->getPointerTo();
+    Value *WrapperFuncBitcast =
+        ConstantExpr::getBitCast(WrapperFunc, WrapperFuncBitcastType);
+
+    // Emit the @__kmpc_omp_task_alloc runtime call
+    CallInst *NewTaskData = Builder.CreateCall(
+        TaskAllocFn,
+        {Ident, ThreadID, /*flags=*/Builder.getInt32(1),
+         /*sizeof_task=*/TaskSize, /*sizeof_shared=*/Builder.getInt64(0),
+         /*task_func=*/WrapperFuncBitcast});
+
+    // Copy the arguments for outlined function
+    if (HasTaskData) {
+      Value *TaskData = StaleCI->getArgOperand(0);
+      Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
+      Builder.CreateMemCpy(NewTaskData, Alignment, TaskData, Alignment,
+                           TaskSize);
+    }
+
+    // Emit the @__kmpc_omp_task runtime call to spawn the task
+    Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
+    Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+
+    StaleCI->eraseFromParent();
+
+    // Emit the body for wrapper function
+    BasicBlock *WrapperEntryBB = BasicBlock::Create(M.getContext());
+    WrapperFunc->getBasicBlockList().push_back(WrapperEntryBB);
+    Builder.SetInsertPoint(WrapperEntryBB);
+    if (HasTaskData)
+      Builder.CreateCall(&OutlinedFn, {WrapperFunc->getArg(1)});
+    else
+      Builder.CreateCall(&OutlinedFn);
+    Builder.CreateRet(Builder.getInt32(0));
+  };
+
+  addOutlineInfo(std::move(OI));
+  Builder.SetInsertPoint(UI->getParent());
+  UI->eraseFromParent();
+
+  return Builder.saveIP();
+}
+
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
     const LocationDescription &Loc, InsertPointTy AllocaIP,
     ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4476,4 +4476,151 @@
   EXPECT_TRUE(MapperCall->getOperand(8)->getType()->isPointerTy());
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTask) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
+  AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
+  Value *Val128 =
+      Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "bodygen.load");
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &ContinuationIP) {
+    Builder.restoreIP(AllocaIP);
+    AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
+                                                "bodygen.alloca128");
+
+    Builder.restoreIP(CodeGenIP);
+    // Loading and storing captured pointer and values
+    Builder.CreateStore(Val128, Local128);
+    Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
+                                      "bodygen.load32");
+
+    LoadInst *PrivLoad128 = Builder.CreateLoad(
+        Local128->getAllocatedType(), Local128, "bodygen.local.load128");
+    Value *Cmp = Builder.CreateICmpNE(
+        Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType()));
+    Instruction *ThenTerm, *ElseTerm;
+    SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
+                                  &ThenTerm, &ElseTerm);
+
+    Builder.SetInsertPoint(ThenTerm);
+    Builder.CreateBr(&ContinuationIP);
+    ThenTerm->eraseFromParent();
+  };
+
+  BasicBlock *AllocaBB = Builder.GetInsertBlock();
+  UnreachableInst *UI = Builder.CreateUnreachable();
+  BasicBlock *BodyBB =
+      Builder.GetInsertBlock()->splitBasicBlock(UI, "alloca.split");
+  OpenMPIRBuilder::LocationDescription Loc(
+      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+  Builder.restoreIP(OMPBuilder.createTask(
+      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()),
+      BodyGenCB));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+  UI->eraseFromParent();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *TaskAllocCall = cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+          ->user_back());
+
+  // Verify the Ident argument
+  GlobalVariable *Ident = cast<GlobalVariable>(TaskAllocCall->getArgOperand(0));
+  EXPECT_NE(Ident, nullptr);
+  EXPECT_TRUE(Ident->hasInitializer());
+  Constant *Initializer = Ident->getInitializer();
+  GlobalVariable *SrcStrGlob =
+      cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
+  EXPECT_NE(SrcStrGlob, nullptr);
+  ConstantDataArray *SrcSrc =
+      dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
+  EXPECT_NE(SrcSrc, nullptr);
+
+  // Verify the gtid argument.
+  CallInst *GTID = dyn_cast<CallInst>(TaskAllocCall->getArgOperand(1));
+  ASSERT_NE(GTID, nullptr);
+  EXPECT_EQ(GTID->arg_size(), 1U);
+  EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num");
+
+  // Verify the flags
+  // TODO: Check for other flags. Currently testing only for tiedness.
+  ConstantInt *Flags = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(2));
+  ASSERT_NE(Flags, nullptr);
+  EXPECT_EQ(Flags->getSExtValue(), 1);
+
+  // Verify the data size
+  ConstantInt *DataSize =
+      dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
+  ASSERT_NE(DataSize, nullptr);
+  EXPECT_EQ(DataSize->getSExtValue(), 24); // 64-bit pointer + 128-bit integer
+
+  // TODO: Verify size of shared clause variables
+
+  // Verify Wrapper function
+  Function *WrapperFunc =
+      dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
+  ASSERT_NE(WrapperFunc, nullptr);
+  EXPECT_FALSE(WrapperFunc->isDeclaration());
+  CallInst *OutlinedFnCall = dyn_cast<CallInst>(WrapperFunc->begin()->begin());
+  ASSERT_NE(OutlinedFnCall, nullptr);
+  EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty());
+  EXPECT_EQ(OutlinedFnCall->getArgOperand(0), WrapperFunc->getArg(1));
+
+  // Verify the presence of `trunc` and `icmp` instructions in Outlined function
+  Function *OutlinedFn = OutlinedFnCall->getCalledFunction();
+  ASSERT_NE(OutlinedFn, nullptr);
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<TruncInst>(&inst); }));
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
+
+  // Verify the execution of the task
+  CallInst *TaskCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task)
+          ->user_back());
+  ASSERT_NE(TaskCall, nullptr);
+  EXPECT_EQ(TaskCall->getArgOperand(0), Ident);
+  EXPECT_EQ(TaskCall->getArgOperand(1), GTID);
+  EXPECT_EQ(TaskCall->getArgOperand(2), TaskAllocCall);
+
+  // Verify that the argument data has been copied
+  for (User *in : TaskAllocCall->users()) {
+    if (MemCpyInst *memCpyInst = dyn_cast<MemCpyInst>(in))
+      EXPECT_EQ(memCpyInst->getDest(), TaskAllocCall);
+  }
+}
+
+TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &ContinuationIP) {};
+
+  auto *AllocaBB = Builder.GetInsertBlock();
+  UnreachableInst *UI = Builder.CreateUnreachable();
+  auto *BodyBB = Builder.GetInsertBlock()->splitBasicBlock(UI, "alloca.split");
+  OpenMPIRBuilder::LocationDescription Loc(
+      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+  Builder.restoreIP(OMPBuilder.createTask(
+      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()),
+      BodyGenCB));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+  UI->eraseFromParent();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
 } // namespace