diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4293,13 +4293,45 @@
   return Builder.saveIP();
 }
 
+// Copy input from pointer or i64 to the expected argument type.
+//
+// Allocates a stack slot in address space \p AddrSpace (typed as \p Arg's
+// type when that is a pointer, i64 otherwise), casts the slot address to
+// \p Input's type, stores \p Arg through the cast address and returns the
+// value loaded back from it with \p Arg's type.
+//
+// NOTE(review): the cast destination is Input->getType(), which presumably
+// has to be a pointer type -- confirm the behavior for by-value inputs.
+static Value *copyInput(IRBuilderBase &Builder, unsigned AddrSpace,
+                        Value *Input, Argument &Arg) {
+  auto Addr = Builder.CreateAlloca(Arg.getType()->isPointerTy()
+                                       ? Arg.getType()
+                                       : Type::getInt64Ty(Builder.getContext()),
+                                   AddrSpace);
+  auto AddrAscast =
+      Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
+  Builder.CreateStore(&Arg, AddrAscast);
+  auto Copy = Builder.CreateLoad(Arg.getType(), AddrAscast);
+
+  return Copy;
+}
+
 static Function *
 createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                        StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
                        OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
   SmallVector<Type *> ParameterTypes;
-  for (auto &Arg : Inputs)
-    ParameterTypes.push_back(Arg->getType());
+  if (OMPBuilder.Config.isTargetDevice()) {
+    // All parameters to target devices are passed as pointers
+    // or i64. This assumes 64-bit address spaces/pointers.
+    for (auto &Arg : Inputs)
+      ParameterTypes.push_back(Arg->getType()->isPointerTy()
+                                   ? Arg->getType()
+                                   : Type::getInt64Ty(Builder.getContext()));
+  } else {
+    for (auto &Arg : Inputs)
+      ParameterTypes.push_back(Arg->getType());
+  }
 
   auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
                                     /*isVarArg*/ false);
@@ -4317,6 +4349,9 @@
   if (OMPBuilder.Config.isTargetDevice())
     Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
 
+  // Remember the block that is current before the body is generated; the
+  // argument copies are inserted at its start further down.
+  BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
   Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
 
   // Insert target deinit call in the device compilation pass.
@@ -4327,15 +4362,23 @@
   Builder.CreateRetVoid();
 
   // Rewrite uses of input valus to parameters.
+  Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
   for (auto InArg : zip(Inputs, Func->args())) {
     Value *Input = std::get<0>(InArg);
     Argument &Arg = std::get<1>(InArg);
 
+    Value *InputCopy =
+        OMPBuilder.Config.isTargetDevice()
+            ? copyInput(Builder,
+                        OMPBuilder.M.getDataLayout().getAllocaAddrSpace(),
+                        Input, Arg)
+            : &Arg;
+
     // Collect all the instructions
     for (User *User : make_early_inc_range(Input->users()))
       if (auto Instr = dyn_cast<Instruction>(User))
         if (Instr->getFunction() == Func)
-          Instr->replaceUsesOfWith(Input, &Arg);
+          Instr->replaceUsesOfWith(Input, InputCopy);
   }
 
   // Restore insert point.
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5122,16 +5122,18 @@
   IRBuilder<> Builder(BB);
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
 
+  LoadInst *Value = nullptr;
   StoreInst *TargetStore = nullptr;
   llvm::SmallVector<llvm::Value *> CapturedArgs = {
-      Constant::getIntegerValue(Type::getInt32Ty(Ctx), APInt(32, 0)),
+      Constant::getNullValue(Type::getInt32PtrTy(Ctx)),
       Constant::getNullValue(Type::getInt32PtrTy(Ctx))};
 
   auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
                        OpenMPIRBuilder::InsertPointTy CodeGenIP)
       -> OpenMPIRBuilder::InsertPointTy {
     Builder.restoreIP(CodeGenIP);
-    TargetStore = Builder.CreateStore(CapturedArgs[0], CapturedArgs[1]);
+    Value = Builder.CreateLoad(Type::getInt32Ty(Ctx), CapturedArgs[0]);
+    TargetStore = Builder.CreateStore(Value, CapturedArgs[1]);
     return Builder.saveIP();
   };
 
@@ -5155,7 +5157,7 @@
   EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
   EXPECT_EQ(OutlinedFn->arg_size(), 2U);
   EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
-  EXPECT_TRUE(OutlinedFn->getArg(0)->getType()->isIntegerTy(32));
+  EXPECT_TRUE(OutlinedFn->getArg(0)->getType()->isPointerTy());
   EXPECT_TRUE(OutlinedFn->getArg(1)->getType()->isPointerTy());
 
   // Check entry block
@@ -5180,8 +5182,22 @@
   // Check user code block
   auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0);
   EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry");
-  EXPECT_EQ(UserCodeBlock->getFirstNonPHI(), TargetStore);
-
+  auto *Alloca1 = UserCodeBlock->getFirstNonPHI();
+  EXPECT_TRUE(isa<AllocaInst>(Alloca1));
+  auto *Store1 = Alloca1->getNextNode();
+  EXPECT_TRUE(isa<StoreInst>(Store1));
+  auto *Load1 = Store1->getNextNode();
+  EXPECT_TRUE(isa<LoadInst>(Load1));
+  auto *Alloca2 = Load1->getNextNode();
+  EXPECT_TRUE(isa<AllocaInst>(Alloca2));
+  auto *Store2 = Alloca2->getNextNode();
+  EXPECT_TRUE(isa<StoreInst>(Store2));
+  auto *Load2 = Store2->getNextNode();
+  EXPECT_TRUE(isa<LoadInst>(Load2));
+
+  auto *Value1 = Load2->getNextNode();
+  EXPECT_EQ(Value1, Value);
+  EXPECT_EQ(Value1->getNextNode(), TargetStore);
   auto *Deinit = TargetStore->getNextNode();
   EXPECT_NE(Deinit, nullptr);
 
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
--- a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
@@ -30,12 +30,21 @@
 // CHECK-NEXT: %[[CMP:.*]] = icmp eq i32 %3, -1
 // CHECK-NEXT: br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]]
 // CHECK: [[LABEL_ENTRY]]:
+// CHECK: %[[TMP_A:.*]] = alloca ptr, align 8
+// CHECK: store ptr %[[ADDR_A]], ptr %[[TMP_A]], align 8
+// CHECK: %[[PTR_A:.*]] = load ptr, ptr %[[TMP_A]], align 8
+// CHECK: %[[TMP_B:.*]] = alloca ptr, align 8
+// CHECK: store ptr %[[ADDR_B]], ptr %[[TMP_B]], align 8
+// CHECK: %[[PTR_B:.*]] = load ptr, ptr %[[TMP_B]], align 8
+// CHECK: %[[TMP_C:.*]] = alloca ptr, align 8
+// CHECK: store ptr %[[ADDR_C]], ptr %[[TMP_C]], align 8
+// CHECK: %[[PTR_C:.*]] = load ptr, ptr %[[TMP_C]], align 8
 // CHECK-NEXT: br label %[[LABEL_TARGET:.*]]
 // CHECK: [[LABEL_TARGET]]:
-// CHECK: %[[A:.*]] = load i32, ptr %[[ADDR_A]], align 4
-// CHECK: %[[B:.*]] = load i32, ptr %[[ADDR_B]], align 4
+// CHECK: %[[A:.*]] = load i32, ptr %[[PTR_A]], align 4
+// CHECK: %[[B:.*]] = load i32, ptr %[[PTR_B]], align 4
 // CHECK: %[[C:.*]] = add i32 %[[A]], %[[B]]
-// CHECK: store i32 %[[C]], ptr %[[ADDR_C]], align 4
+// CHECK: store i32 %[[C]], ptr %[[PTR_C]], align 4
 // CHECK: br label %[[LABEL_DEINIT:.*]]
 // CHECK: [[LABEL_DEINIT]]:
 // CHECK-NEXT: call void @__kmpc_target_deinit(ptr @[[IDENT]], i8 1)