diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -691,11 +691,11 @@
     if (&V == TIDAddr || &V == ZeroAddr)
       return;
 
-    SmallVector Uses;
+    SetVector Uses;
     for (Use &U : V.uses())
       if (auto *UserI = dyn_cast(U.getUser()))
         if (ParallelRegionBlockSet.count(UserI->getParent()))
-          Uses.push_back(&U);
+          Uses.insert(&U);
 
     Value *ReplacementValue = nullptr;
     CallInst *CI = dyn_cast(&V);
@@ -706,10 +706,56 @@
           PrivCB(InnerAllocaIP, Builder.saveIP(), V, ReplacementValue));
       assert(ReplacementValue &&
              "Expected copy/create callback to set replacement value!");
-      if (ReplacementValue == &V)
-        return;
     }
 
+    // __kmpc_fork_call expects extra arguments as pointers. If the input
+    // already has a pointer type, everything is fine, only use the replacement
+    // value inside the function.
+    if (V.getType()->isPointerTy()) {
+      if (ReplacementValue != &V)
+        for (Use *UPtr : Uses)
+          UPtr->set(ReplacementValue);
+
+      return;
+    }
+
+    // Otherwise, store the value onto stack and load it back inside the
+    // to-be-outlined region. This will ensure only the pointer will be passed
+    // to the function.
+    LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
+
+    // Find new uses created by the privatization.
+    SmallVector PrivatizationUses;
+    for (Use &U : V.uses()) {
+      if (Uses.contains(&U))
+        continue;
+      if (auto *UserI = dyn_cast(U.getUser()))
+        if (ParallelRegionBlockSet.count(UserI->getParent()))
+          PrivatizationUses.push_back(&U);
+    }
+
+    // Store to stack at end of the block that currently branches to the entry
+    // block of the to-be-outlined region.
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.SetInsertPoint(ThenBB->getTerminator());
+    Value *Ptr = Builder.CreateAlloca(V.getType());
+    Builder.CreateStore(&V, Ptr);
+
+    // Load back next to allocations in the to-be-outlined region.
+    Builder.restoreIP(InnerAllocaIP);
+    Value *Reloaded = Builder.CreateLoad(Ptr);
+
+    // Any uses of the original value introduced by the privatization callback
+    // should use the loaded-back value instead.
+    for (Use *UPtr : PrivatizationUses)
+      UPtr->set(Reloaded);
+
+    // Replace original uses of the value with the replacement value. If the
+    // callback returned the original value, use the loaded-back value instead
+    // because all uses of the original value in the to-be-outlined region
+    // must disappear.
+    if (ReplacementValue == &V)
+      ReplacementValue = Reloaded;
     for (Use *UPtr : Uses)
       UPtr->set(ReplacementValue);
   };
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -829,6 +829,87 @@
   }
 }
 
+// Check that values defined before the parallel region and used inside it are
+// passed to the outlined function as pointers: the i32 and struct values are
+// forwarded as pointers to their types, pointer values keep their own type.
+TEST_F(OpenMPIRBuilderTest, ParallelForwardAsPointers) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  Type *i32Ty = Type::getInt32Ty(M->getContext());
+  Type *i32PtrTy = Type::getInt32PtrTy(M->getContext());
+  Type *structTy = StructType::get(i32Ty, i32PtrTy);
+  Type *structPtrTy = structTy->getPointerTo();
+  Type *VoidTy = Type::getVoidTy(M->getContext());
+  FunctionCallee RetI32Func = M->getOrInsertFunction("ret_i32", i32Ty);
+  FunctionCallee TakeI32Func =
+      M->getOrInsertFunction("take_i32", VoidTy, i32Ty);
+  FunctionCallee RetI32PtrFunc = M->getOrInsertFunction("ret_i32ptr", i32PtrTy);
+  FunctionCallee TakeI32PtrFunc =
+      M->getOrInsertFunction("take_i32ptr", VoidTy, i32PtrTy);
+  FunctionCallee RetStructFunc = M->getOrInsertFunction("ret_struct", structTy);
+  FunctionCallee TakeStructFunc =
+      M->getOrInsertFunction("take_struct", VoidTy, structTy);
+  FunctionCallee RetStructPtrFunc =
+      M->getOrInsertFunction("ret_structptr", structPtrTy);
+  FunctionCallee TakeStructPtrFunc =
+      M->getOrInsertFunction("take_structPtr", VoidTy, structPtrTy);
+  Value *I32Val = Builder.CreateCall(RetI32Func);
+  Value *I32PtrVal = Builder.CreateCall(RetI32PtrFunc);
+  Value *StructVal = Builder.CreateCall(RetStructFunc);
+  Value *StructPtrVal = Builder.CreateCall(RetStructPtrFunc);
+
+  Instruction *Internal;
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &ContinuationBB) {
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.restoreIP(CodeGenIP);
+    Internal = Builder.CreateCall(TakeI32Func, I32Val);
+    Builder.CreateCall(TakeI32PtrFunc, I32PtrVal);
+    Builder.CreateCall(TakeStructFunc, StructVal);
+    Builder.CreateCall(TakeStructPtrFunc, StructPtrVal);
+    Builder.CreateBr(&ContinuationBB);
+  };
+  // Hand back the original value unchanged, so any forwarding through memory
+  // is done by the builder and not by this callback.
+  auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                    Value &VPtr, Value *&ReplacementValue) {
+    ReplacementValue = &VPtr;
+    return CodeGenIP;
+  };
+  auto FiniCB = [](InsertPointTy) {};
+
+  IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
+                                    F->getEntryBlock().getFirstInsertionPt());
+  IRBuilder<>::InsertPoint AfterIP =
+      OMPBuilder.createParallel(Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
+                                nullptr, nullptr, OMP_PROC_BIND_default, false);
+  Builder.restoreIP(AfterIP);
+  Builder.CreateRetVoid();
+
+  OMPBuilder.finalize();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  Function *OutlinedFn = Internal->getFunction();
+
+  Type *arg2Type = OutlinedFn->getArg(2)->getType();
+  EXPECT_TRUE(arg2Type->isPointerTy());
+  EXPECT_EQ(arg2Type->getPointerElementType(), i32Ty);
+
+  Type *arg3Type = OutlinedFn->getArg(3)->getType();
+  EXPECT_EQ(arg3Type, i32PtrTy);
+
+  Type *arg4Type = OutlinedFn->getArg(4)->getType();
+  EXPECT_TRUE(arg4Type->isPointerTy());
+  EXPECT_EQ(arg4Type->getPointerElementType(), structTy);
+
+  Type *arg5Type = OutlinedFn->getArg(5)->getType();
+  EXPECT_EQ(arg5Type, structPtrTy);
+}
+
 TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);