diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -626,9 +626,16 @@
   /// \param Tied True if the task is tied, false if the task is untied.
   /// \param Final i1 value which is `true` if the task is final, `false` if the
   ///              task is not final.
+  /// \param IfCondition i1 value. If it evaluates to `false`, an undeferred
+  ///                    task is generated, and the encountering thread must
+  ///                    suspend the current task region, for which execution
+  ///                    cannot be resumed until execution of the structured
+  ///                    block that is associated with the generated task is
+  ///                    completed.
   InsertPointTy createTask(const LocationDescription &Loc,
                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                           bool Tied = true, Value *Final = nullptr);
+                           bool Tied = true, Value *Final = nullptr,
+                           Value *IfCondition = nullptr);
 
   /// Generator for the taskgroup construct
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -24,6 +24,7 @@
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/MDBuilder.h"
@@ -1288,7 +1289,7 @@
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                            bool Tied, Value *Final) {
+                            bool Tied, Value *Final, Value *IfCondition) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -1320,7 +1321,8 @@
   OI.EntryBB = TaskAllocaBB;
   OI.OuterAllocaBB = AllocaIP.getBlock();
   OI.ExitBB = TaskExitBB;
-  OI.PostOutlineCB = [this, Ident, Tied, Final](Function &OutlinedFn) {
+  OI.PostOutlineCB = [this, Ident, Tied, Final,
+                      IfCondition](Function &OutlinedFn) {
     // The input IR here looks like the following-
     // ```
     // func @current_fn() {
@@ -1430,6 +1432,41 @@
                                       TaskSize);
     }
 
+    // In the presence of the `if` clause, the following IR is generated:
+    //    ...
+    //    %data = call @__kmpc_omp_task_alloc(...)
+    //    br i1 %if_condition, label %then, label %else
+    //  then:
+    //    call @__kmpc_omp_task(...)
+    //    br label %exit
+    //  else:
+    //    call @__kmpc_omp_task_begin_if0(...)
+    //    call @wrapper_fn(...)
+    //    call @__kmpc_omp_task_complete_if0(...)
+    //    br label %exit
+    //  exit:
+    //    ...
+    if (IfCondition) {
+      // `SplitBlockAndInsertIfThenElse` requires the block to have a
+      // terminator, so insert a temporary one here; it is erased below.
+      auto *UI = Builder.CreateUnreachable();
+      Instruction *ThenTI = UI, *ElseTI = nullptr;
+      Builder.SetInsertPoint(UI);
+      SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
+      Builder.SetInsertPoint(ElseTI);
+      Function *TaskBeginFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
+      Function *TaskCompleteFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
+      Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, NewTaskData});
+      if (HasTaskData)
+        Builder.CreateCall(WrapperFunc, {ThreadID, NewTaskData});
+      else
+        Builder.CreateCall(WrapperFunc, {ThreadID});
+      Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
+      Builder.SetInsertPoint(ThenTI);
+      UI->eraseFromParent();
+    }
     // Emit the @__kmpc_omp_task runtime call to spawn the task
     Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
     Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4920,6 +4920,71 @@
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  Builder.SetInsertPoint(BodyBB);
+  Value *IfCondition = Builder.CreateICmp(
+      CmpInst::Predicate::ICMP_EQ, F->getArg(0),
+      ConstantInt::get(Type::getInt32Ty(M->getContext()), 0U));
+  OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
+  Builder.restoreIP(OMPBuilder.createTask(Loc, AllocaIP, BodyGenCB,
+                                          /*Tied=*/false, /*Final=*/nullptr,
+                                          IfCondition));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+          ->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+
+  // Check the branching is based on the if condition argument.
+  BranchInst *IfConditionBranchInst =
+      dyn_cast<BranchInst>(TaskAllocCall->getParent()->getTerminator());
+  ASSERT_NE(IfConditionBranchInst, nullptr);
+  ASSERT_TRUE(IfConditionBranchInst->isConditional());
+  EXPECT_EQ(IfConditionBranchInst->getCondition(), IfCondition);
+
+  // Check that the `__kmpc_omp_task` executes only in the then branch.
+  CallInst *TaskCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task)
+          ->user_back());
+  ASSERT_NE(TaskCall, nullptr);
+  EXPECT_EQ(TaskCall->getParent(), IfConditionBranchInst->getSuccessor(0));
+
+  // Check that the OpenMP Runtime Functions specific to `if` clause execute
+  // only in the else branch. Also check that the function call is between the
+  // `__kmpc_omp_task_begin_if0` and `__kmpc_omp_task_complete_if0` calls.
+  CallInst *TaskBeginIfCall = dyn_cast<CallInst>(
+      OMPBuilder
+          .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0)
+          ->user_back());
+  CallInst *TaskCompleteCall = dyn_cast<CallInst>(
+      OMPBuilder
+          .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0)
+          ->user_back());
+  ASSERT_NE(TaskBeginIfCall, nullptr);
+  ASSERT_NE(TaskCompleteCall, nullptr);
+  Function *WrapperFunc =
+      dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
+  ASSERT_NE(WrapperFunc, nullptr);
+  CallInst *WrapperFuncCall = dyn_cast<CallInst>(WrapperFunc->user_back());
+  ASSERT_NE(WrapperFuncCall, nullptr);
+  EXPECT_EQ(TaskBeginIfCall->getParent(),
+            IfConditionBranchInst->getSuccessor(1));
+  EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), WrapperFuncCall);
+  EXPECT_EQ(WrapperFuncCall->getNextNonDebugInstruction(), TaskCompleteCall);
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);