diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1038,6 +1038,20 @@ SerializedParallelCallArgs); // OutlinedFn(&gtid, &zero, CapturedStruct); + // The extractor only filled the CapturedStruct in the fork-call site, + // do the same here by copying the stores. + Value *CapturedStruct = *(CI->arg_end() - 1); + for (User *Usr : CapturedStruct->users()) { + GetElementPtrInst *GEP = dyn_cast(Usr); + if (!GEP) + continue; + // Find the value stored to the Struct. + StoreInst *StoreToStruct = dyn_cast(*GEP->user_begin()); + Value *StoredAggValue = StoreToStruct->getValueOperand(); + // Clone the GEP and create a Store to the Struct. + Value *CloneGEP = Builder.Insert(GEP->clone()); + Builder.CreateStore(StoredAggValue, CloneGEP); + } CI->removeFromParent(); Builder.Insert(CI); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -260,10 +260,15 @@ } // Returns the value stored in the aggregate argument of an outlined function, -// or nullptr if it is not found. +// or nullptr if it is not found. For parallel regions with an if clause, there +// will be two GEP users (one in the fork region and the other in the serialized +// parallel region) for the AggregateStruct. In these cases, pass the position +// of the user as the last argument. static Value *findStoredValueInAggregateAt(LLVMContext &Ctx, Value *Aggregate, - unsigned Idx) { + unsigned Idx, + int GEPUserNumber = 0) { GetElementPtrInst *GEPAtIdx = nullptr; + int CurGEPUserNumber = 0; // Find GEP instruction at that index. 
for (User *Usr : Aggregate->users()) { GetElementPtrInst *GEP = dyn_cast(Usr); @@ -273,8 +278,17 @@ if (GEP->getOperand(2) != ConstantInt::get(Type::getInt32Ty(Ctx), Idx)) continue; - EXPECT_EQ(GEPAtIdx, nullptr); - GEPAtIdx = GEP; + CurGEPUserNumber++; + + if (GEPUserNumber > 0) { + EXPECT_LE(CurGEPUserNumber, 2); + EXPECT_LE(GEPUserNumber, 2); + if (CurGEPUserNumber == GEPUserNumber) + GEPAtIdx = GEP; + } else { + EXPECT_EQ(GEPAtIdx, nullptr); + GEPAtIdx = GEP; + } } EXPECT_NE(GEPAtIdx, nullptr); @@ -1003,17 +1017,18 @@ EXPECT_TRUE(isa(ForkCI->getArgOperand(0))); EXPECT_EQ(ForkCI->getArgOperand(1), ConstantInt::get(Type::getInt32Ty(Ctx), 1)); - Value *StoredForkArg = - findStoredValueInAggregateAt(Ctx, ForkCI->getArgOperand(3), 0); + Value *StoredForkArg = findStoredValueInAggregateAt( + Ctx, ForkCI->getArgOperand(3), 0, /*GEPUserNumber=*/2); EXPECT_EQ(StoredForkArg, F->arg_begin()); EXPECT_EQ(DirectCI->getCalledFunction(), OutlinedFn); EXPECT_EQ(DirectCI->arg_size(), 3U); EXPECT_TRUE(isa(DirectCI->getArgOperand(0))); EXPECT_TRUE(isa(DirectCI->getArgOperand(1))); - Value *StoredDirectArg = - findStoredValueInAggregateAt(Ctx, DirectCI->getArgOperand(2), 0); + Value *StoredDirectArg = findStoredValueInAggregateAt( + Ctx, DirectCI->getArgOperand(2), 0, /*GEPUserNumber=*/1); EXPECT_EQ(StoredDirectArg, F->arg_begin()); + EXPECT_EQ(StoredForkArg, StoredDirectArg); } TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) {