# Patch summary (free-text header; everything before the first "diff --git" marker is not part of any hunk).
#
# llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
#   * createOutlinedFunction() gains a leading OpenMPIRBuilder &OMPBuilder parameter; the
#     GenerateOutlinedFunction lambda captures OMPBuilder and forwards it.
#   * When OMPBuilder.Config.isEmbedded() is true, the outlined function body is bracketed by
#     createTargetInit(Builder, /*IsSPMD*/ false) before the CBFunc body callback and
#     createTargetDeinit(Builder, /*IsSPMD*/ false) after it, ahead of the final CreateRetVoid().
#     The restoreIP() around createTargetInit() moves the builder to the insert point it returns.
#   * createTarget() only calls emitTargetCall() when Config.isEmbedded() is false.
#
# llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
#   * Adds TEST_F(OpenMPIRBuilderTest, TargetRegionDevice): builds a target region whose body is a
#     single store, then checks the outlined function (weak_odr linkage, two arguments, name
#     "__omp_offloading_1_2_parent_l3"), the __kmpc_target_init call (3 arguments) at the start of
#     the entry block, a two-successor terminator leading to "user_code.entry" and "worker.exit",
#     and the __kmpc_target_deinit call (2 arguments) right after the store.
#
# mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
#   * Removes the isDevice early "return success()" so the target op is translated in that case too.
#
# NOTE(review): in the new test, EXPECT_NE(TargetStore, nullptr), EXPECT_NE(InitCall, nullptr),
#   EXPECT_NE(EntryBlockBranch, nullptr), EXPECT_NE(Deinit, nullptr) and EXPECT_NE(DeinitCall, nullptr)
#   are each followed by a dereference of the same pointer. EXPECT_* is non-fatal, so a null value
#   would be dereferenced instead of ending the test; ASSERT_NE looks like the intended guard.
# NOTE(review): the new calls spell the argument comment as "/*IsSPMD*/ false" while the test in this
#   same patch uses the "/*DeviceID=*/1" form; consider "/*IsSPMD=*/false" for consistency.
# NOTE(review): the TODO in emitTargetCall() drops "when device codegen is supported" without naming
#   what still blocks the kernel launch call - consider stating the remaining condition.
# NOTE(review): OpenMPIRBuilderConfig(true, false, false, false) is positional; presumably the first
#   argument is the "embedded"/device flag that isEmbedded() reports - confirm and label the arguments.
# NOTE(review): this copy of the patch appears whitespace-collapsed and template argument lists look
#   stripped (e.g. "SmallVectorImpl &Inputs", "dyn_cast(Init)", "isa(...)"); presumably an extraction
#   artifact - restore line breaks and angle-bracket contents from the original before applying.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -4112,8 +4112,8 @@ } static Function * -createOutlinedFunction(IRBuilderBase &Builder, StringRef FuncName, - SmallVectorImpl &Inputs, +createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, + StringRef FuncName, SmallVectorImpl &Inputs, OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) { SmallVector ParameterTypes; for (auto &Arg : Inputs) @@ -4130,8 +4130,17 @@ // Generate the region into the function. BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func); Builder.SetInsertPoint(EntryBB); + + // Insert target init call in the device compilation pass. + if (OMPBuilder.Config.isEmbedded()) + Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false)); + Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP())); + // Insert target deinit call in the device compilation pass. + if (OMPBuilder.Config.isEmbedded()) + OMPBuilder.createTargetDeinit(Builder, /*IsSPMD*/ false); + // Insert return instruction. Builder.CreateRetVoid(); @@ -4161,8 +4170,9 @@ OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) { OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction = - [&Builder, &Inputs, &CBFunc](StringRef EntryFnName) { - return createOutlinedFunction(Builder, EntryFnName, Inputs, CBFunc); + [&OMPBuilder, &Builder, &Inputs, &CBFunc](StringRef EntryFnName) { + return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs, + CBFunc); }; Constant *OutlinedFnID; @@ -4173,7 +4183,7 @@ static void emitTargetCall(IRBuilderBase &Builder, Function *OutlinedFn, SmallVectorImpl &Args) { - // TODO: Add kernel launch call when device codegen is supported. 
+ // TODO: Add kernel launch call Builder.CreateCall(OutlinedFn, Args); } @@ -4189,7 +4199,8 @@ Function *OutlinedFn; emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn, NumTeams, NumThreads, Args, CBFunc); - emitTargetCall(Builder, OutlinedFn, Args); + if (!Config.isEmbedded()) + emitTargetCall(Builder, OutlinedFn, Args); return Builder.saveIP(); } diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DIBuilder.h" @@ -5175,6 +5176,94 @@ EXPECT_FALSE(verifyModule(*M, &errs())); } +TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.setConfig(OpenMPIRBuilderConfig(true, false, false, false)); + OMPBuilder.initialize(); + + F->setName("func"); + IRBuilder<> Builder(BB); + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + + StoreInst *TargetStore = nullptr; + llvm::SmallVector CapturedArgs = { + Constant::getIntegerValue(Type::getInt32Ty(Ctx), APInt(32, 0)), + Constant::getNullValue(Type::getInt32PtrTy(Ctx))}; + + auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP) + -> OpenMPIRBuilder::InsertPointTy { + Builder.restoreIP(CodeGenIP); + TargetStore = Builder.CreateStore(CapturedArgs[0], CapturedArgs[1]); + return Builder.saveIP(); + }; + + IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(), + F->getEntryBlock().getFirstInsertionPt()); + TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2, + /*Line=*/3, /*Count=*/0); + + Builder.restoreIP( + OMPBuilder.createTarget(Loc, 
EntryIP, EntryInfo, /*NumTeams=*/-1, + /*NumThreads=*/-1, CapturedArgs, BodyGenCB)); + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + + // Check outlined function + EXPECT_FALSE(verifyModule(*M, &errs())); + EXPECT_NE(TargetStore, nullptr); + Function *OutlinedFn = TargetStore->getFunction(); + EXPECT_NE(F, OutlinedFn); + + EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage()); + EXPECT_EQ(OutlinedFn->arg_size(), 2U); + EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3"); + EXPECT_TRUE(OutlinedFn->getArg(0)->getType()->isIntegerTy(32)); + EXPECT_TRUE(OutlinedFn->getArg(1)->getType()->isPointerTy()); + + // Check entry block + auto &EntryBlock = OutlinedFn->getEntryBlock(); + Instruction *Init = EntryBlock.getFirstNonPHI(); + EXPECT_NE(Init, nullptr); + + auto *InitCall = dyn_cast(Init); + EXPECT_NE(InitCall, nullptr); + EXPECT_EQ(InitCall->getCalledFunction()->getName(), "__kmpc_target_init"); + EXPECT_EQ(InitCall->arg_size(), 3U); + EXPECT_TRUE(isa(InitCall->getArgOperand(0))); + EXPECT_EQ(InitCall->getArgOperand(1), + ConstantInt::get(Type::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_GENERIC)); + EXPECT_EQ(InitCall->getArgOperand(2), + ConstantInt::get(Type::getInt1Ty(Ctx), true)); + + auto *EntryBlockBranch = EntryBlock.getTerminator(); + EXPECT_NE(EntryBlockBranch, nullptr); + EXPECT_EQ(EntryBlockBranch->getNumSuccessors(), 2U); + + // Check user code block + auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0); + EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry"); + EXPECT_EQ(UserCodeBlock->getFirstNonPHI(), TargetStore); + + auto *Deinit = TargetStore->getNextNode(); + EXPECT_NE(Deinit, nullptr); + + auto *DeinitCall = dyn_cast(Deinit); + EXPECT_NE(DeinitCall, nullptr); + EXPECT_EQ(DeinitCall->getCalledFunction()->getName(), "__kmpc_target_deinit"); + EXPECT_EQ(DeinitCall->arg_size(), 2U); + EXPECT_TRUE(isa(DeinitCall->getArgOperand(0))); + EXPECT_EQ(DeinitCall->getArgOperand(1), + ConstantInt::get(Type::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_GENERIC)); + 
+ EXPECT_TRUE(isa(DeinitCall->getNextNode())); + + // Check exit block + auto *ExitBlock = EntryBlockBranch->getSuccessor(1); + EXPECT_EQ(ExitBlock->getName(), "worker.exit"); + EXPECT_TRUE(isa(ExitBlock->getFirstNonPHI())); +} + TEST_F(OpenMPIRBuilderTest, CreateTask) { using InsertPointTy = OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -1629,15 +1629,6 @@ if (!targetOpSupported(opInst)) return failure(); - bool isDevice = false; - if (auto offloadMod = dyn_cast( - opInst.getParentOfType().getOperation())) { - isDevice = offloadMod.getIsDevice(); - } - - if (isDevice) // TODO: Implement device codegen. - return success(); - auto targetOp = cast(opInst); auto &targetRegion = targetOp.getRegion();