diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -626,7 +626,9 @@
   /// \param Tied True if the task is tied, false if the task is untied.
+  /// \param Final Optional i1 value; if non-null, the task is created as a
+  ///        final task whenever this value is true at runtime.
   InsertPointTy createTask(const LocationDescription &Loc,
                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                           bool Tied = true);
+                           bool Tied = true, Value *Final = nullptr);
 
   /// Functions used to generate reductions. Such functions take two Values
   /// representing LHS and RHS of the reduction, respectively, and a reference
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1256,7 +1256,7 @@
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                            bool Tied) {
+                            bool Tied, Value *Final) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -1285,7 +1285,7 @@
   OI.EntryBB = TaskAllocaBB;
   OI.OuterAllocaBB = AllocaIP.getBlock();
   OI.ExitBB = TaskExitBB;
-  OI.PostOutlineCB = [this, &Loc, Tied](Function &OutlinedFn) {
+  OI.PostOutlineCB = [this, &Loc, Tied, Final](Function &OutlinedFn) {
     // The input IR here looks like the following-
     // ```
     // func @current_fn() {
@@ -1330,10 +1330,17 @@
     Value *ThreadID = getOrCreateThreadID(Ident);
 
     // Argument - `flags`
-    // If task is tied, then (Flags & 1) == 1.
-    // If task is untied, then (Flags & 1) == 0.
+    // Task is tied iff (Flags & 1) == 1.
+    // Task is untied iff (Flags & 1) == 0.
+    // Task is final iff (Flags & 2) == 2.
+    // Task is not final iff (Flags & 2) == 0.
     // TODO: Handle the other flags.
     Value *Flags = Builder.getInt32(Tied);
+    if (Final) {
+      Value *FinalFlag =
+          Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
+      Flags = Builder.CreateOr(FinalFlag, Flags);
+    }
 
     // Argument - `sizeof_kmp_task_t` (TaskSize)
     // Tasksize refers to the size in bytes of kmp_task_t data structure
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4577,4 +4577,57 @@
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  Builder.SetInsertPoint(BodyBB);
+  Value *Final = Builder.CreateICmp(
+      CmpInst::Predicate::ICMP_EQ, F->getArg(0),
+      ConstantInt::get(Type::getInt32Ty(M->getContext()), 0U));
+  OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
+  Builder.restoreIP(OMPBuilder.createTask(Loc, AllocaIP, BodyGenCB,
+                                          /*Tied=*/false, Final));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  // Check the `flags` argument passed to __kmpc_omp_task_alloc.
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+          ->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+  BinaryOperator *OrInst =
+      dyn_cast<BinaryOperator>(TaskAllocCall->getArgOperand(2));
+  ASSERT_NE(OrInst, nullptr);
+  EXPECT_EQ(OrInst->getOpcode(), BinaryOperator::BinaryOps::Or);
+
+  // One of the arguments to the `or` instruction is the tied flag, which is
+  // zero because the task was created with Tied=false.
+  EXPECT_TRUE(any_of(OrInst->operands(), [](Value *Op) {
+    if (ConstantInt *TiedValue = dyn_cast<ConstantInt>(Op))
+      return TiedValue->getSExtValue() == 0;
+    return false;
+  }));
+
+  // The other argument is `select Final, 2, 0`.
+  EXPECT_TRUE(any_of(OrInst->operands(), [Final](Value *Op) {
+    if (SelectInst *Select = dyn_cast<SelectInst>(Op)) {
+      ConstantInt *TrueValue = dyn_cast<ConstantInt>(Select->getTrueValue());
+      ConstantInt *FalseValue = dyn_cast<ConstantInt>(Select->getFalseValue());
+      if (!TrueValue || !FalseValue)
+        return false;
+      return Select->getCondition() == Final &&
+             TrueValue->getSExtValue() == 2 && FalseValue->getSExtValue() == 0;
+    }
+    return false;
+  }));
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
 } // namespace