diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -4148,8 +4148,8 @@ } static Function * -createOutlinedFunction(IRBuilderBase &Builder, StringRef FuncName, - SmallVectorImpl &Inputs, +createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, + StringRef FuncName, SmallVectorImpl &Inputs, OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) { SmallVector ParameterTypes; for (auto &Arg : Inputs) @@ -4166,8 +4166,17 @@ // Generate the region into the function. BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func); Builder.SetInsertPoint(EntryBB); + + // Insert target init call in the device compilation pass. + if (OMPBuilder.Config.isEmbedded()) + Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false)); + Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP())); + // Insert target deinit call in the device compilation pass. + if (OMPBuilder.Config.isEmbedded()) + OMPBuilder.createTargetDeinit(Builder, /*IsSPMD*/ false); + // Insert return instruction. Builder.CreateRetVoid(); @@ -4197,8 +4206,9 @@ OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) { OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction = - [&Builder, &Inputs, &CBFunc](StringRef EntryFnName) { - return createOutlinedFunction(Builder, EntryFnName, Inputs, CBFunc); + [&OMPBuilder, &Builder, &Inputs, &CBFunc](StringRef EntryFnName) { + return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs, + CBFunc); }; Constant *OutlinedFnID; @@ -4209,7 +4219,7 @@ static void emitTargetCall(IRBuilderBase &Builder, Function *OutlinedFn, SmallVectorImpl &Args) { - // TODO: Add kernel launch call when device codegen is supported. 
+ // TODO: Emit a kernel launch call; for now the outlined function is called directly. Builder.CreateCall(OutlinedFn, Args); } @@ -4225,7 +4235,8 @@ Function *OutlinedFn; emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn, NumTeams, NumThreads, Args, CBFunc); - emitTargetCall(Builder, OutlinedFn, Args); + if (!Config.isEmbedded()) + emitTargetCall(Builder, OutlinedFn, Args); return Builder.saveIP(); } diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DIBuilder.h" @@ -5175,6 +5176,94 @@ EXPECT_FALSE(verifyModule(*M, &errs())); } +TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.setConfig(OpenMPIRBuilderConfig(true, false, false, false)); + OMPBuilder.initialize(); + + F->setName("func"); + IRBuilder<> Builder(BB); + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + + StoreInst *TargetStore = nullptr; + llvm::SmallVector CapturedArgs = { + Constant::getIntegerValue(Type::getInt32Ty(Ctx), APInt(32, 0)), + Constant::getNullValue(Type::getInt32PtrTy(Ctx))}; + + auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::InsertPointTy CodeGenIP) + -> OpenMPIRBuilder::InsertPointTy { + Builder.restoreIP(CodeGenIP); + TargetStore = Builder.CreateStore(CapturedArgs[0], CapturedArgs[1]); + return Builder.saveIP(); + }; + + IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(), + F->getEntryBlock().getFirstInsertionPt()); + TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2, + /*Line=*/3, /*Count=*/0); + + Builder.restoreIP( + OMPBuilder.createTarget(Loc, 
EntryIP, EntryInfo, /*NumTeams=*/-1, + /*NumThreads=*/-1, CapturedArgs, BodyGenCB)); + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + + // Check outlined function + EXPECT_FALSE(verifyModule(*M, &errs())); + ASSERT_NE(TargetStore, nullptr); + Function *OutlinedFn = TargetStore->getFunction(); + EXPECT_NE(F, OutlinedFn); + + EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage()); + ASSERT_EQ(OutlinedFn->arg_size(), 2U); + EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3"); + EXPECT_TRUE(OutlinedFn->getArg(0)->getType()->isIntegerTy(32)); + EXPECT_TRUE(OutlinedFn->getArg(1)->getType()->isPointerTy()); + + // Check entry block + auto &EntryBlock = OutlinedFn->getEntryBlock(); + Instruction *Init = EntryBlock.getFirstNonPHI(); + ASSERT_NE(Init, nullptr); + + auto *InitCall = dyn_cast(Init); + ASSERT_NE(InitCall, nullptr); + EXPECT_EQ(InitCall->getCalledFunction()->getName(), "__kmpc_target_init"); + ASSERT_EQ(InitCall->arg_size(), 3U); + EXPECT_TRUE(isa(InitCall->getArgOperand(0))); + EXPECT_EQ(InitCall->getArgOperand(1), + ConstantInt::get(Type::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_GENERIC)); + EXPECT_EQ(InitCall->getArgOperand(2), + ConstantInt::get(Type::getInt1Ty(Ctx), true)); + + auto *EntryBlockBranch = EntryBlock.getTerminator(); + ASSERT_NE(EntryBlockBranch, nullptr); + ASSERT_EQ(EntryBlockBranch->getNumSuccessors(), 2U); + + // Check user code block + auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0); + EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry"); + EXPECT_EQ(UserCodeBlock->getFirstNonPHI(), TargetStore); + + auto *Deinit = TargetStore->getNextNode(); + ASSERT_NE(Deinit, nullptr); + + auto *DeinitCall = dyn_cast(Deinit); + ASSERT_NE(DeinitCall, nullptr); + EXPECT_EQ(DeinitCall->getCalledFunction()->getName(), "__kmpc_target_deinit"); + ASSERT_EQ(DeinitCall->arg_size(), 2U); + EXPECT_TRUE(isa(DeinitCall->getArgOperand(0))); + EXPECT_EQ(DeinitCall->getArgOperand(1), + ConstantInt::get(Type::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_GENERIC)); + 
+ EXPECT_TRUE(isa(DeinitCall->getNextNode())); + + // Check exit block + auto *ExitBlock = EntryBlockBranch->getSuccessor(1); + EXPECT_EQ(ExitBlock->getName(), "worker.exit"); + EXPECT_TRUE(isa(ExitBlock->getFirstNonPHI())); +} + TEST_F(OpenMPIRBuilderTest, CreateTask) { using InsertPointTy = OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -1629,15 +1629,6 @@ if (!targetOpSupported(opInst)) return failure(); - bool isDevice = false; - if (auto offloadMod = dyn_cast( - opInst.getParentOfType().getOperation())) { - isDevice = offloadMod.getIsDevice(); - } - - if (isDevice) // TODO: Implement device codegen. - return success(); - auto targetOp = cast(opInst); auto &targetRegion = targetOp.getRegion(); diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir @@ -0,0 +1,44 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {omp.is_device = #omp.isdevice} { + llvm.func @omp_target_region_() { + %0 = llvm.mlir.constant(20 : i32) : i32 + %1 = llvm.mlir.constant(10 : i32) : i32 + %2 = llvm.mlir.constant(1 : i64) : i64 + %3 = llvm.alloca %2 x i32 {bindc_name = "a", in_type = i32, operand_segment_sizes = array, uniq_name = "_QFomp_target_regionEa"} : (i64) -> !llvm.ptr + %4 = llvm.mlir.constant(1 : i64) : i64 + %5 = llvm.alloca %4 x i32 {bindc_name = "b", in_type = i32, operand_segment_sizes = array, uniq_name = "_QFomp_target_regionEb"} : (i64) -> !llvm.ptr + %6 = llvm.mlir.constant(1 : i64) : i64 + %7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, 
operand_segment_sizes = array, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr + llvm.store %1, %3 : !llvm.ptr + llvm.store %0, %5 : !llvm.ptr + omp.target { + %8 = llvm.load %3 : !llvm.ptr + %9 = llvm.load %5 : !llvm.ptr + %10 = llvm.add %8, %9 : i32 + llvm.store %10, %7 : !llvm.ptr + omp.terminator + } + llvm.return + } +} + +// CHECK: @[[SRC_LOC:.*]] = private unnamed_addr constant [23 x i8] c"{{[^"]*}}", align 1 +// CHECK: @[[IDENT:.*]] = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @[[SRC_LOC]] }, align 8 +// CHECK: define weak_odr protected void @__omp_offloading_{{[^_]+}}_{{[^_]+}}_omp_target_region__l{{[0-9]+}}(ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]]) +// CHECK: %[[INIT:.*]] = call i32 @__kmpc_target_init(ptr @[[IDENT]], i8 1, i1 true) +// CHECK-NEXT: %[[CMP:.*]] = icmp eq i32 %[[INIT]], -1 +// CHECK-NEXT: br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]] +// CHECK: [[LABEL_ENTRY]]: +// CHECK-NEXT: br label %[[LABEL_TARGET:.*]] +// CHECK: [[LABEL_TARGET]]: +// CHECK: %[[A:.*]] = load i32, ptr %[[ADDR_A]], align 4 +// CHECK: %[[B:.*]] = load i32, ptr %[[ADDR_B]], align 4 +// CHECK: %[[C:.*]] = add i32 %[[A]], %[[B]] +// CHECK: store i32 %[[C]], ptr %[[ADDR_C]], align 4 +// CHECK: br label %[[LABEL_DEINIT:.*]] +// CHECK: [[LABEL_DEINIT]]: +// CHECK-NEXT: call void @__kmpc_target_deinit(ptr @[[IDENT]], i8 1) +// CHECK-NEXT: ret void +// CHECK: [[LABEL_EXIT]]: +// CHECK-NEXT: ret void