diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -139,6 +139,17 @@ bool ForceSimpleCall = false, bool CheckCancelFlag = true); + /// Generator for '#omp cancel' + /// + /// \param Loc The location where the directive was encountered. + /// \param IfCondition The evaluated 'if' clause expression, if any. + /// \param CanceledDirective The kind of directive that is canceled. + /// + /// \returns The insertion point after the cancel. + InsertPointTy CreateCancel(const LocationDescription &Loc, + Value *IfCondition, + omp::Directive CanceledDirective); + /// Generator for '#omp parallel' /// /// \param Loc The insert and source location description. @@ -183,6 +194,13 @@ Value *getOrCreateIdent(Constant *SrcLocStr, omp::IdentFlag Flags = omp::IdentFlag(0)); + /// Generate control flow and cleanup for cancellation. + /// + /// \param CancelFlag Flag indicating if the cancellation is performed. + /// \param CanceledDirective The kind of directive that is canceled. + void emitCancelationCheckImpl(Value *CancelFlag, + omp::Directive CanceledDirective); + /// Generate a barrier runtime call. /// /// \param Loc The location at which the request originated and is fulfilled. 
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def --- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -164,11 +164,13 @@ OMP_RTL(OMPRTL_##Name, #Name, IsVarArg, ReturnType, __VA_ARGS__) __OMP_RTL(__kmpc_barrier, false, Void, IdentPtr, Int32) +__OMP_RTL(__kmpc_cancel, false, Int32, IdentPtr, Int32, Int32) __OMP_RTL(__kmpc_cancel_barrier, false, Int32, IdentPtr, Int32) __OMP_RTL(__kmpc_global_thread_num, false, Int32, IdentPtr) __OMP_RTL(__kmpc_fork_call, true, Void, IdentPtr, Int32, ParallelTaskPtr) -__OMP_RTL(__kmpc_push_num_threads, false, Void, IdentPtr, Int32, /* Int */Int32) -__OMP_RTL(__kmpc_push_proc_bind, false, Void, IdentPtr, Int32, /* Int */Int32) +__OMP_RTL(__kmpc_push_num_threads, false, Void, IdentPtr, Int32, + /* Int */ Int32) +__OMP_RTL(__kmpc_push_proc_bind, false, Void, IdentPtr, Int32, /* Int */ Int32) __OMP_RTL(__kmpc_serialized_parallel, false, Void, IdentPtr, Int32) __OMP_RTL(__kmpc_end_serialized_parallel, false, Void, IdentPtr, Int32) @@ -240,6 +242,26 @@ ///} +/// KMP cancel kind +/// +///{ + +#ifndef OMP_CANCEL_KIND +#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) +#endif + +#define __OMP_CANCEL_KIND(Name, Value) \ + OMP_CANCEL_KIND(OMP_CANCEL_KIND_##Name, #Name, OMPD_##Name, Value) + +__OMP_CANCEL_KIND(parallel, 1) +__OMP_CANCEL_KIND(for, 2) +__OMP_CANCEL_KIND(sections, 3) +__OMP_CANCEL_KIND(taskgroup, 4) + +#undef __OMP_CANCEL_KIND +#undef OMP_CANCEL_KIND + +///} /// Proc bind kinds /// diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -216,41 +216,90 @@ : OMPRTL___kmpc_barrier), Args); - if (UseCancelBarrier && CheckCancelFlag) { - // For a cancel barrier we create two new blocks. 
- BasicBlock *BB = Builder.GetInsertBlock(); - BasicBlock *NonCancellationBlock; - if (Builder.GetInsertPoint() == BB->end()) { - // TODO: This branch will not be needed once we moved to the - // OpenMPIRBuilder codegen completely. - NonCancellationBlock = BasicBlock::Create( - BB->getContext(), BB->getName() + ".cont", BB->getParent()); - } else { - NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint()); - BB->getTerminator()->eraseFromParent(); - Builder.SetInsertPoint(BB); - } - BasicBlock *CancellationBlock = BasicBlock::Create( - BB->getContext(), BB->getName() + ".cncl", BB->getParent()); - - // Jump to them based on the return value. - Value *Cmp = Builder.CreateIsNull(Result); - Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock, - /* TODO weight */ nullptr, nullptr); - - // From the cancellation block we finalize all variables and go to the - // post finalization block that is known to the FiniCB callback. - Builder.SetInsertPoint(CancellationBlock); - auto &FI = FinalizationStack.back(); - FI.FiniCB(Builder.saveIP()); - - // The continuation block is where code generation continues. - Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin()); + if (UseCancelBarrier && CheckCancelFlag) + emitCancelationCheckImpl(Result, OMPD_parallel); + + return Builder.saveIP(); +} + +OpenMPIRBuilder::InsertPointTy +OpenMPIRBuilder::CreateCancel(const LocationDescription &Loc, + Value *IfCondition, + omp::Directive CanceledDirective) { + if (!updateToLocation(Loc)) + return Loc.IP; + + // LLVM utilities like blocks with terminators. 
+ auto *UI = Builder.CreateUnreachable(); + + Instruction *ThenTI = UI, *ElseTI = nullptr; + if (IfCondition) + SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI); + Builder.SetInsertPoint(ThenTI); + + Value *CancelKind = nullptr; + switch (CanceledDirective) { +#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) \ + case DirectiveEnum: \ + CancelKind = Builder.getInt32(Value); \ + break; +#include "llvm/Frontend/OpenMP/OMPKinds.def" + default: + llvm_unreachable("Unknown cancel kind!"); } + Constant *SrcLocStr = getOrCreateSrcLocStr(Loc); + Value *Ident = getOrCreateIdent(SrcLocStr); + Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind}; + Value *Result = Builder.CreateCall( + getOrCreateRuntimeFunction(OMPRTL___kmpc_cancel), Args); + + // The actual cancel logic is shared with others, e.g., cancel_barriers. + emitCancelationCheckImpl(Result, CanceledDirective); + + // Update the insertion point and remove the terminator we introduced. + Builder.SetInsertPoint(UI->getParent()); + UI->eraseFromParent(); + return Builder.saveIP(); } +void OpenMPIRBuilder::emitCancelationCheckImpl( + Value *CancelFlag, omp::Directive CanceledDirective) { + assert(isLastFinalizationInfoCancellable(CanceledDirective) && + "Unexpected cancellation!"); + + // For the cancellation check we create two new blocks. + BasicBlock *BB = Builder.GetInsertBlock(); + BasicBlock *NonCancellationBlock; + if (Builder.GetInsertPoint() == BB->end()) { + // TODO: This branch will not be needed once we have moved to the + // OpenMPIRBuilder codegen completely. + NonCancellationBlock = BasicBlock::Create( + BB->getContext(), BB->getName() + ".cont", BB->getParent()); + } else { + NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint()); + BB->getTerminator()->eraseFromParent(); + Builder.SetInsertPoint(BB); + } + BasicBlock *CancellationBlock = BasicBlock::Create( + BB->getContext(), BB->getName() + ".cncl", BB->getParent()); + + // Jump to them based on the return value. 
+ Value *Cmp = Builder.CreateIsNull(CancelFlag); + Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock, + /* TODO weight */ nullptr, nullptr); + + // From the cancellation block we finalize all variables and go to the + // post finalization block that is known to the FiniCB callback. + Builder.SetInsertPoint(CancellationBlock); + auto &FI = FinalizationStack.back(); + FI.FiniCB(Builder.saveIP()); + + // The continuation block is where code generation continues. + Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin()); +} + IRBuilder<>::InsertPoint OpenMPIRBuilder::CreateParallel( const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB, FinalizeCallbackTy FiniCB, Value *IfCondition, diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -99,6 +99,122 @@ EXPECT_FALSE(verifyModule(*M)); } +TEST_F(OpenMPIRBuilderTest, CreateCancel) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + + BasicBlock *CBB = BasicBlock::Create(Ctx, "", F); + new UnreachableInst(Ctx, CBB); + auto FiniCB = [&](InsertPointTy IP) { + ASSERT_NE(IP.getBlock(), nullptr); + ASSERT_EQ(IP.getBlock()->end(), IP.getPoint()); + BranchInst::Create(CBB, IP.getBlock()); + }; + OMPBuilder.pushFinalizationCB({FiniCB, OMPD_parallel, true}); + + IRBuilder<> Builder(BB); + + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()}); + auto NewIP = OMPBuilder.CreateCancel(Loc, nullptr, OMPD_parallel); + Builder.restoreIP(NewIP); + EXPECT_FALSE(M->global_empty()); + EXPECT_EQ(M->size(), 3U); + EXPECT_EQ(F->size(), 4U); + EXPECT_EQ(BB->size(), 4U); + + CallInst *GTID = dyn_cast(&BB->front()); + EXPECT_NE(GTID, nullptr); + EXPECT_EQ(GTID->getNumArgOperands(), 1U); + EXPECT_EQ(GTID->getCalledFunction()->getName(), 
"__kmpc_global_thread_num"); + EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory()); + EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory()); + + CallInst *Cancel = dyn_cast(GTID->getNextNode()); + EXPECT_NE(Cancel, nullptr); + EXPECT_EQ(Cancel->getNumArgOperands(), 3U); + EXPECT_EQ(Cancel->getCalledFunction()->getName(), "__kmpc_cancel"); + EXPECT_FALSE(Cancel->getCalledFunction()->doesNotAccessMemory()); + EXPECT_FALSE(Cancel->getCalledFunction()->doesNotFreeMemory()); + EXPECT_EQ(Cancel->getNumUses(), 1U); + Instruction *CancelBBTI = Cancel->getParent()->getTerminator(); + EXPECT_EQ(CancelBBTI->getNumSuccessors(), 2U); + EXPECT_EQ(CancelBBTI->getSuccessor(0), NewIP.getBlock()); + EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 1U); + EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(), + 1U); + EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0), + CBB); + + EXPECT_EQ(cast(Cancel)->getArgOperand(1), GTID); + + OMPBuilder.popFinalizationCB(); + + Builder.CreateUnreachable(); + EXPECT_FALSE(verifyModule(*M)); +} + +TEST_F(OpenMPIRBuilderTest, CreateCancelIfCond) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + + BasicBlock *CBB = BasicBlock::Create(Ctx, "", F); + new UnreachableInst(Ctx, CBB); + auto FiniCB = [&](InsertPointTy IP) { + ASSERT_NE(IP.getBlock(), nullptr); + ASSERT_EQ(IP.getBlock()->end(), IP.getPoint()); + BranchInst::Create(CBB, IP.getBlock()); + }; + OMPBuilder.pushFinalizationCB({FiniCB, OMPD_parallel, true}); + + IRBuilder<> Builder(BB); + + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP()}); + auto NewIP = OMPBuilder.CreateCancel(Loc, Builder.getTrue(), OMPD_parallel); + Builder.restoreIP(NewIP); + EXPECT_FALSE(M->global_empty()); + EXPECT_EQ(M->size(), 3U); + EXPECT_EQ(F->size(), 7U); + EXPECT_EQ(BB->size(), 1U); + ASSERT_TRUE(isa(BB->getTerminator())); + ASSERT_EQ(BB->getTerminator()->getNumSuccessors(), 
2U); + BB = BB->getTerminator()->getSuccessor(0); + EXPECT_EQ(BB->size(), 4U); + + + CallInst *GTID = dyn_cast(&BB->front()); + EXPECT_NE(GTID, nullptr); + EXPECT_EQ(GTID->getNumArgOperands(), 1U); + EXPECT_EQ(GTID->getCalledFunction()->getName(), "__kmpc_global_thread_num"); + EXPECT_FALSE(GTID->getCalledFunction()->doesNotAccessMemory()); + EXPECT_FALSE(GTID->getCalledFunction()->doesNotFreeMemory()); + + CallInst *Cancel = dyn_cast(GTID->getNextNode()); + EXPECT_NE(Cancel, nullptr); + EXPECT_EQ(Cancel->getNumArgOperands(), 3U); + EXPECT_EQ(Cancel->getCalledFunction()->getName(), "__kmpc_cancel"); + EXPECT_FALSE(Cancel->getCalledFunction()->doesNotAccessMemory()); + EXPECT_FALSE(Cancel->getCalledFunction()->doesNotFreeMemory()); + EXPECT_EQ(Cancel->getNumUses(), 1U); + Instruction *CancelBBTI = Cancel->getParent()->getTerminator(); + EXPECT_EQ(CancelBBTI->getNumSuccessors(), 2U); + EXPECT_EQ(CancelBBTI->getSuccessor(0)->size(), 1U); + EXPECT_EQ(CancelBBTI->getSuccessor(0)->getUniqueSuccessor(), NewIP.getBlock()); + EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 1U); + EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(), + 1U); + EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0), + CBB); + + EXPECT_EQ(cast(Cancel)->getArgOperand(1), GTID); + + OMPBuilder.popFinalizationCB(); + + Builder.CreateUnreachable(); + EXPECT_FALSE(verifyModule(*M)); +} + TEST_F(OpenMPIRBuilderTest, CreateCancelBarrier) { using InsertPointTy = OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M);