diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -867,12 +867,14 @@
   /// \param Loc The source location description.
   /// \param BodyGenCB Callback that will generate the region code.
   /// \param FiniCB Callback to finalize variable copies.
+  /// \param IsNowait If false, a barrier is emitted.
   /// \param DidIt Local variable used as a flag to indicate 'single' thread
   ///
   /// \returns The insertion position *after* the single call.
   InsertPointTy createSingle(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB,
-                             FinalizeCallbackTy FiniCB, llvm::Value *DidIt);
+                             FinalizeCallbackTy FiniCB, bool IsNowait,
+                             llvm::Value *DidIt);
 
   /// Generator for '#omp master'
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2739,10 +2739,9 @@
   return Builder.saveIP();
 }
 
-OpenMPIRBuilder::InsertPointTy
-OpenMPIRBuilder::createSingle(const LocationDescription &Loc,
-                              BodyGenCallbackTy BodyGenCB,
-                              FinalizeCallbackTy FiniCB, llvm::Value *DidIt) {
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
+    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
+    FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt) {
   if (!updateToLocation(Loc))
     return Loc.IP;
 
@@ -2770,9 +2769,16 @@
   // .... single region ...
   // __kmpc_end_single
   // }
-
-  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
-                              /*Conditional*/ true, /*hasFinalize*/ true);
+  // __kmpc_barrier
+
+  EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
+                       /*Conditional*/ true,
+                       /*hasFinalize*/ true);
+  if (!IsNowait)
+    createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
+                  omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
+                  /* CheckCancelFlag */ false);
+  return Builder.saveIP();
 }
 
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCritical(
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -2821,8 +2821,8 @@
     EXPECT_NE(IPBB->end(), IP.getPoint());
   };
 
-  Builder.restoreIP(
-      OMPBuilder.createSingle(Builder, BodyGenCB, FiniCB, /*DidIt*/ nullptr));
+  Builder.restoreIP(OMPBuilder.createSingle(
+      Builder, BodyGenCB, FiniCB, /*IsNowait*/ false, /*DidIt*/ nullptr));
   Value *EntryBBTI = EntryBB->getTerminator();
   EXPECT_NE(EntryBBTI, nullptr);
   EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
@@ -2854,6 +2854,106 @@
   EXPECT_EQ(SingleEndCI->arg_size(), 2U);
   EXPECT_TRUE(isa<GlobalVariable>(SingleEndCI->getArgOperand(0)));
   EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
+
+  bool FoundBarrier = false;
+  for (auto &FI : *ExitBB) {
+    Instruction *cur = &FI;
+    if (auto CI = dyn_cast<CallInst>(cur)) {
+      if (CI->getCalledFunction()->getName() == "__kmpc_barrier") {
+        FoundBarrier = true;
+        break;
+      }
+    }
+  }
+  EXPECT_TRUE(FoundBarrier);
+}
+
+TEST_F(OpenMPIRBuilderTest, SingleDirectiveNowait) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  AllocaInst *PrivAI = nullptr;
+
+  BasicBlock *EntryBB = nullptr;
+  BasicBlock *ExitBB = nullptr;
+  BasicBlock *ThenBB = nullptr;
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+                       BasicBlock &FiniBB) {
+    if (AllocaIP.isSet())
+      Builder.restoreIP(AllocaIP);
+    else
+      Builder.SetInsertPoint(&*(F->getEntryBlock().getFirstInsertionPt()));
+    PrivAI = Builder.CreateAlloca(F->arg_begin()->getType());
+    Builder.CreateStore(F->arg_begin(), PrivAI);
+
+    llvm::BasicBlock *CodeGenIPBB = CodeGenIP.getBlock();
+    llvm::Instruction *CodeGenIPInst = &*CodeGenIP.getPoint();
+    EXPECT_EQ(CodeGenIPBB->getTerminator(), CodeGenIPInst);
+
+    Builder.restoreIP(CodeGenIP);
+
+    // collect some info for checks later
+    ExitBB = FiniBB.getUniqueSuccessor();
+    ThenBB = Builder.GetInsertBlock();
+    EntryBB = ThenBB->getUniquePredecessor();
+
+    // simple instructions for body
+    Value *PrivLoad =
+        Builder.CreateLoad(PrivAI->getAllocatedType(), PrivAI, "local.use");
+    Builder.CreateICmpNE(F->arg_begin(), PrivLoad);
+  };
+
+  auto FiniCB = [&](InsertPointTy IP) {
+    BasicBlock *IPBB = IP.getBlock();
+    EXPECT_NE(IPBB->end(), IP.getPoint());
+  };
+
+  Builder.restoreIP(OMPBuilder.createSingle(
+      Builder, BodyGenCB, FiniCB, /*IsNowait*/ true, /*DidIt*/ nullptr));
+  Value *EntryBBTI = EntryBB->getTerminator();
+  EXPECT_NE(EntryBBTI, nullptr);
+  EXPECT_TRUE(isa<BranchInst>(EntryBBTI));
+  BranchInst *EntryBr = cast<BranchInst>(EntryBB->getTerminator());
+  EXPECT_TRUE(EntryBr->isConditional());
+  EXPECT_EQ(EntryBr->getSuccessor(0), ThenBB);
+  EXPECT_EQ(ThenBB->getUniqueSuccessor(), ExitBB);
+  EXPECT_EQ(EntryBr->getSuccessor(1), ExitBB);
+
+  CmpInst *CondInst = cast<CmpInst>(EntryBr->getCondition());
+  EXPECT_TRUE(isa<CallInst>(CondInst->getOperand(0)));
+
+  CallInst *SingleEntryCI = cast<CallInst>(CondInst->getOperand(0));
+  EXPECT_EQ(SingleEntryCI->arg_size(), 2U);
+  EXPECT_EQ(SingleEntryCI->getCalledFunction()->getName(), "__kmpc_single");
+  EXPECT_TRUE(isa<GlobalVariable>(SingleEntryCI->getArgOperand(0)));
+
+  CallInst *SingleEndCI = nullptr;
+  for (auto &FI : *ThenBB) {
+    Instruction *cur = &FI;
+    if (isa<CallInst>(cur)) {
+      SingleEndCI = cast<CallInst>(cur);
+      if (SingleEndCI->getCalledFunction()->getName() == "__kmpc_end_single")
+        break;
+      SingleEndCI = nullptr;
+    }
+  }
+  EXPECT_NE(SingleEndCI, nullptr);
+  EXPECT_EQ(SingleEndCI->arg_size(), 2U);
+  EXPECT_TRUE(isa<GlobalVariable>(SingleEndCI->getArgOperand(0)));
+  EXPECT_EQ(SingleEndCI->getArgOperand(1), SingleEntryCI->getArgOperand(1));
+
+  for (auto &FI : *ExitBB) {
+    Instruction *cur = &FI;
+    if (auto CI = dyn_cast<CallInst>(cur)) {
+      EXPECT_FALSE(CI->getCalledFunction()->getName() == "__kmpc_barrier");
+    }
+  }
 }
 
 TEST_F(OpenMPIRBuilderTest, OMPAtomicReadFlt) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -652,7 +652,7 @@
   };
   auto finiCB = [&](InsertPointTy codeGenIP) {};
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSingle(
-      ompLoc, bodyCB, finiCB, /*DidIt=*/nullptr));
+      ompLoc, bodyCB, finiCB, singleOp.nowait(), /*DidIt=*/nullptr));
   return bodyGenStatus;
 }
 
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -1828,6 +1828,10 @@
 // CHECK-LABEL: @single
 // CHECK-SAME: (i32 %[[x:.*]], i32 %[[y:.*]], i32* %[[zaddr:.*]])
 llvm.func @single(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
+  // CHECK: %[[a:.*]] = sub i32 %[[x]], %[[y]]
+  %a = llvm.sub %x, %y : i32
+  // CHECK: store i32 %[[a]], i32* %[[zaddr]]
+  llvm.store %a, %zaddr : !llvm.ptr<i32>
   // CHECK: call i32 @__kmpc_single
   omp.single {
     // CHECK: %[[z:.*]] = add i32 %[[x]], %[[y]]
@@ -1835,8 +1839,40 @@
     // CHECK: store i32 %[[z]], i32* %[[zaddr]]
     llvm.store %z, %zaddr : !llvm.ptr<i32>
     // CHECK: call void @__kmpc_end_single
+    // CHECK: call void @__kmpc_barrier
     omp.terminator
   }
+  // CHECK: %[[b:.*]] = mul i32 %[[x]], %[[y]]
+  %b = llvm.mul %x, %y : i32
+  // CHECK: store i32 %[[b]], i32* %[[zaddr]]
+  llvm.store %b, %zaddr : !llvm.ptr<i32>
+  // CHECK: ret void
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @single_nowait
+// CHECK-SAME: (i32 %[[x:.*]], i32 %[[y:.*]], i32* %[[zaddr:.*]])
+llvm.func @single_nowait(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
+  // CHECK: %[[a:.*]] = sub i32 %[[x]], %[[y]]
+  %a = llvm.sub %x, %y : i32
+  // CHECK: store i32 %[[a]], i32* %[[zaddr]]
+  llvm.store %a, %zaddr : !llvm.ptr<i32>
+  // CHECK: call i32 @__kmpc_single
+  omp.single nowait {
+    // CHECK: %[[z:.*]] = add i32 %[[x]], %[[y]]
+    %z = llvm.add %x, %y : i32
+    // CHECK: store i32 %[[z]], i32* %[[zaddr]]
+    llvm.store %z, %zaddr : !llvm.ptr<i32>
+    // CHECK: call void @__kmpc_end_single
+    // CHECK-NOT: call void @__kmpc_barrier
+    omp.terminator
+  }
+  // CHECK: %[[t:.*]] = mul i32 %[[x]], %[[y]]
+  %t = llvm.mul %x, %y : i32
+  // CHECK: store i32 %[[t]], i32* %[[zaddr]]
+  llvm.store %t, %zaddr : !llvm.ptr<i32>
   // CHECK: ret void
   llvm.return
 }