diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -522,7 +522,8 @@
 
   // Add some fake uses for OpenMP provided arguments.
   ToBeDeleted.push_back(Builder.CreateLoad(TIDAddr, "tid.addr.use"));
-  ToBeDeleted.push_back(Builder.CreateLoad(ZeroAddr, "zero.addr.use"));
+  Instruction *ZeroAddrUse = Builder.CreateLoad(ZeroAddr, "zero.addr.use");
+  ToBeDeleted.push_back(ZeroAddrUse);
 
   // ThenBB
   //   |
@@ -687,15 +688,43 @@
   FunctionCallee TIDRTLFn =
       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
 
+  // Define the insertion point for loading the values wrapped into pointers for
+  // passing into the to-be-outlined region. Insert them immediately after the
+  // fake use of zero address so that they are available in the generated body
+  // and so that the OpenMP-related values (thread ID and zero address pointers)
+  // remain leading in the argument list.
+  IRBuilder<>::InsertPoint ReloadIP(ZeroAddrUse->getParent(),
+                                    ZeroAddrUse->getNextNode()->getIterator());
+
   auto PrivHelper = [&](Value &V) {
     if (&V == TIDAddr || &V == ZeroAddr)
       return;
 
-    SmallVector<Use *, 8> Uses;
+    SetVector<Use *> Uses;
     for (Use &U : V.uses())
       if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
         if (ParallelRegionBlockSet.count(UserI->getParent()))
-          Uses.push_back(&U);
+          Uses.insert(&U);
+
+    Value *Reloaded = nullptr;
+    if (!V.getType()->isPointerTy()) {
+      IRBuilder<>::InsertPointGuard Guard(Builder);
+
+      // Store to stack at end of the block that currently branches to the entry
+      // block of the to-be-outlined region.
+      // NOTE(review): the alloca is emitted at the end of InsertBB rather than
+      // at a dedicated alloca insertion point; confirm this is intended.
+      Builder.SetInsertPoint(InsertBB,
+                             InsertBB->getTerminator()->getIterator());
+      Value *Ptr =
+          Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");
+      Builder.CreateStore(&V, Ptr);
+
+      // Load back next to allocations in the to-be-outlined region.
+      Builder.restoreIP(ReloadIP);
+      Reloaded = Builder.CreateLoad(V.getType(), Ptr);
+      InnerAllocaIP = Builder.saveIP();
+    }
 
     Value *ReplacementValue = nullptr;
     CallInst *CI = dyn_cast<CallInst>(&V);
@@ -706,10 +735,47 @@
           PrivCB(InnerAllocaIP, Builder.saveIP(), V, ReplacementValue));
       assert(ReplacementValue &&
              "Expected copy/create callback to set replacement value!");
-      if (ReplacementValue == &V)
-        return;
     }
 
+    // __kmpc_fork_call expects extra arguments as pointers. If the input
+    // already has a pointer type, everything is fine, only use the replacement
+    // value inside the function. This also captures the TID case because it is
+    // passed in as a pointer.
+    if (V.getType()->isPointerTy()) {
+      if (ReplacementValue != &V)
+        for (Use *UPtr : Uses)
+          UPtr->set(ReplacementValue);
+
+      return;
+    }
+
+    // Otherwise, store the value onto stack and load it back inside the
+    // to-be-outlined region. This will ensure only the pointer will be passed
+    // to the function.
+    LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
+    assert(Reloaded && "Expected non-pointer argument to be loaded back!");
+
+    // Find new uses created by the privatization.
+    SmallVector<Use *, 8> PrivatizationUses;
+    for (Use &U : V.uses()) {
+      if (Uses.contains(&U))
+        continue;
+      if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
+        if (ParallelRegionBlockSet.count(UserI->getParent()))
+          PrivatizationUses.push_back(&U);
+    }
+
+    // Any uses of the original value introduced by the privatization callback
+    // should use the loaded-back value instead.
+    for (Use *UPtr : PrivatizationUses)
+      UPtr->set(Reloaded);
+
+    // Replace original uses of the value with the replacement value. If the
+    // callback returned the original value, use the loaded-back value instead
+    // because all uses of the original value in the to-be-outlined region must
+    // disappear.
+    if (ReplacementValue == &V)
+      ReplacementValue = Reloaded;
     for (Use *UPtr : Uses)
       UPtr->set(ReplacementValue);
   };
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -60,6 +60,27 @@
   DebugLoc DL;
 };
 
+// Returns the value stored in the given allocation. Returns null if the given
+// value is not a result of an allocation, if no value is stored or if there is
+// more than one store (the latter is also reported as a test failure).
+static Value *findStoredValue(Value *AllocaValue) {
+  auto *Alloca = dyn_cast<AllocaInst>(AllocaValue);
+  if (!Alloca)
+    return nullptr;
+  StoreInst *Store = nullptr;
+  for (Use &U : Alloca->uses()) {
+    if (auto *CandidateStore = dyn_cast<StoreInst>(U.getUser())) {
+      EXPECT_EQ(Store, nullptr);
+      if (Store)
+        return nullptr;
+      Store = CandidateStore;
+    }
+  }
+  if (!Store)
+    return nullptr;
+  return Store->getValueOperand();
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateBarrier) {
   OpenMPIRBuilder OMPBuilder(*M);
   OMPBuilder.initialize();
@@ -401,7 +422,7 @@
   EXPECT_EQ(ForkCI->getArgOperand(1),
             ConstantInt::get(Type::getInt32Ty(Ctx), 1U));
   EXPECT_EQ(ForkCI->getArgOperand(2), Usr);
-  EXPECT_EQ(ForkCI->getArgOperand(3), F->arg_begin());
+  EXPECT_EQ(findStoredValue(ForkCI->getArgOperand(3)), F->arg_begin());
 }
 
 TEST_F(OpenMPIRBuilderTest, ParallelNested) {
@@ -708,13 +729,15 @@
   EXPECT_TRUE(isa<GlobalVariable>(ForkCI->getArgOperand(0)));
   EXPECT_EQ(ForkCI->getArgOperand(1),
             ConstantInt::get(Type::getInt32Ty(Ctx), 1));
-  EXPECT_EQ(ForkCI->getArgOperand(3), F->arg_begin());
+  Value *StoredForkArg = findStoredValue(ForkCI->getArgOperand(3));
+  EXPECT_EQ(StoredForkArg, F->arg_begin());
 
   EXPECT_EQ(DirectCI->getCalledFunction(), OutlinedFn);
   EXPECT_EQ(DirectCI->getNumArgOperands(), 3U);
   EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(0)));
   EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(1)));
-  EXPECT_EQ(DirectCI->getArgOperand(2), F->arg_begin());
+  Value *StoredDirectArg = findStoredValue(DirectCI->getArgOperand(2));
+  EXPECT_EQ(StoredDirectArg, F->arg_begin());
 }
 
 TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) {
@@ -829,6 +852,89 @@
   }
 }
 
+// Checks that non-pointer values used inside the parallel region reach the
+// outlined function as pointer arguments, while pointer values keep their type.
+TEST_F(OpenMPIRBuilderTest, ParallelForwardAsPointers) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+
+  Type *I32Ty = Type::getInt32Ty(M->getContext());
+  Type *I32PtrTy = Type::getInt32PtrTy(M->getContext());
+  Type *StructTy = StructType::get(I32Ty, I32PtrTy);
+  Type *StructPtrTy = StructTy->getPointerTo();
+  Type *VoidTy = Type::getVoidTy(M->getContext());
+  FunctionCallee RetI32Func = M->getOrInsertFunction("ret_i32", I32Ty);
+  FunctionCallee TakeI32Func =
+      M->getOrInsertFunction("take_i32", VoidTy, I32Ty);
+  FunctionCallee RetI32PtrFunc = M->getOrInsertFunction("ret_i32ptr", I32PtrTy);
+  FunctionCallee TakeI32PtrFunc =
+      M->getOrInsertFunction("take_i32ptr", VoidTy, I32PtrTy);
+  FunctionCallee RetStructFunc = M->getOrInsertFunction("ret_struct", StructTy);
+  FunctionCallee TakeStructFunc =
+      M->getOrInsertFunction("take_struct", VoidTy, StructTy);
+  FunctionCallee RetStructPtrFunc =
+      M->getOrInsertFunction("ret_structptr", StructPtrTy);
+  FunctionCallee TakeStructPtrFunc =
+      M->getOrInsertFunction("take_structPtr", VoidTy, StructPtrTy);
+  Value *I32Val = Builder.CreateCall(RetI32Func);
+  Value *I32PtrVal = Builder.CreateCall(RetI32PtrFunc);
+  Value *StructVal = Builder.CreateCall(RetStructFunc);
+  Value *StructPtrVal = Builder.CreateCall(RetStructPtrFunc);
+
+  // Set by BodyGenCB; used afterwards to locate the outlined function.
+  Instruction *Internal = nullptr;
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &ContinuationBB) {
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.restoreIP(CodeGenIP);
+    Internal = Builder.CreateCall(TakeI32Func, I32Val);
+    Builder.CreateCall(TakeI32PtrFunc, I32PtrVal);
+    Builder.CreateCall(TakeStructFunc, StructVal);
+    Builder.CreateCall(TakeStructPtrFunc, StructPtrVal);
+  };
+  auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                    Value &VPtr, Value *&ReplacementValue) {
+    ReplacementValue = &VPtr;
+    return CodeGenIP;
+  };
+  auto FiniCB = [](InsertPointTy) {};
+
+  IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
+                                    F->getEntryBlock().getFirstInsertionPt());
+  IRBuilder<>::InsertPoint AfterIP =
+      OMPBuilder.createParallel(Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
+                                nullptr, nullptr, OMP_PROC_BIND_default, false);
+  Builder.restoreIP(AfterIP);
+  Builder.CreateRetVoid();
+
+  OMPBuilder.finalize();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  ASSERT_NE(Internal, nullptr);
+  Function *OutlinedFn = Internal->getFunction();
+
+  Type *Arg2Type = OutlinedFn->getArg(2)->getType();
+  EXPECT_TRUE(Arg2Type->isPointerTy());
+  EXPECT_EQ(Arg2Type->getPointerElementType(), I32Ty);
+
+  // Arguments that need to be passed through pointers and reloaded will get
+  // used earlier in the functions and therefore will appear first in the
+  // argument list after outlining.
+  Type *Arg3Type = OutlinedFn->getArg(3)->getType();
+  EXPECT_TRUE(Arg3Type->isPointerTy());
+  EXPECT_EQ(Arg3Type->getPointerElementType(), StructTy);
+
+  Type *Arg4Type = OutlinedFn->getArg(4)->getType();
+  EXPECT_EQ(Arg4Type, I32PtrTy);
+
+  Type *Arg5Type = OutlinedFn->getArg(5)->getType();
+  EXPECT_EQ(Arg5Type, StructPtrTy);
+}
+
 TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);