Index: llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -207,6 +207,12 @@
 /// Atomic compare operations. Currently OpenMP only supports ==, >, and <.
 enum class OMPAtomicCompareOp : unsigned { EQ, MIN, MAX };
 
+/// Dependence kinds in the encoding stored in the flags of a kmp_dep_info.
+enum OMPRTLDependenceKindTy {
+  OMPDepIn = 0x1,    // Used for `in` dependences.
+  OMPDepInOut = 0x3, // Used for both `out` and `inout` dependences.
+};
+
 } // end namespace omp
 
 } // end namespace llvm
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -640,6 +640,21 @@
   /// \param Loc The location where the taskyield directive was encountered.
   void createTaskyield(const LocationDescription &Loc);
 
+  enum OpenMPDependKind {
+    OMP_DEPEND_in,
+    OMP_DEPEND_out,
+    OMP_DEPEND_inout,
+    OMP_DEPEND_unknown, // Default of DependData; rejected when lowering.
+  };
+  struct DependData { // One dependence of a task, as passed to createTask.
+    OpenMPDependKind DepKind = OMP_DEPEND_unknown;
+    Type *DepValueType = nullptr; // Type whose store size is recorded.
+    Value *DepVal = nullptr;      // Pointer to the depended-on variable.
+    explicit DependData() = default;
+    DependData(OpenMPDependKind DepKind, Type *DepValueType, Value *DepVal)
+        : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal) {}
+  };
+
   /// Generator for `#omp task`
   ///
   /// \param Loc The location where the task construct was encountered.
@@ -657,7 +672,8 @@
   InsertPointTy createTask(const LocationDescription &Loc,
                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
                            bool Tied = true, Value *Final = nullptr,
-                           Value *IfCondition = nullptr);
+                           Value *IfCondition = nullptr,
+                           ArrayRef<DependData *> Dependencies = {});
 
   /// Generator for the taskgroup construct
   ///
Index: llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -92,6 +92,7 @@
 __OMP_STRUCT_TYPE(KernelArgs, __tgt_kernel_arguments, Int32, Int32, VoidPtrPtr,
                   VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr, Int64)
 __OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, Int8Ptr)
+__OMP_STRUCT_TYPE(DependInfo, kmp_dep_info, SizeTy, SizeTy, Int8)
 
 #undef __OMP_STRUCT_TYPE
 #undef OMP_STRUCT_TYPE
Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1287,10 +1287,30 @@
   emitTaskyieldImpl(Loc);
 }
 
+/// Translates the builder's dependence kind \p K into the runtime flag value.
+static OMPRTLDependenceKindTy
+translateDependencyKind(llvm::OpenMPIRBuilder::OpenMPDependKind K) {
+  OMPRTLDependenceKindTy DepKind;
+  switch (K) {
+  case llvm::OpenMPIRBuilder::OMP_DEPEND_in:
+    DepKind = OMPDepIn;
+    break;
+  // Out and InOut dependencies must use the same code.
+  case llvm::OpenMPIRBuilder::OMP_DEPEND_out:
+  case llvm::OpenMPIRBuilder::OMP_DEPEND_inout:
+    DepKind = OMPDepInOut;
+    break;
+  default:
+    llvm_unreachable("Unknown task dependence type");
+  }
+  return DepKind;
+}
+
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                            bool Tied, Value *Final, Value *IfCondition) {
+                            bool Tied, Value *Final, Value *IfCondition,
+                            ArrayRef<DependData *> Dependencies) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -1322,8 +1342,11 @@
   OI.EntryBB = TaskAllocaBB;
   OI.OuterAllocaBB = AllocaIP.getBlock();
   OI.ExitBB = TaskExitBB;
-  OI.PostOutlineCB = [this, Ident, Tied, Final,
-                      IfCondition](Function &OutlinedFn) {
+  // OI keeps this callback for later use, so it takes its own copy of the
+  // dependence list instead of a view into the caller's storage. The
+  // DependData objects themselves are still borrowed from the caller.
+  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition,
+                      Dependencies = Dependencies.vec()](Function &OutlinedFn) {
     // The input IR here looks like the following-
     // ```
     // func @current_fn() {
@@ -1471,9 +1494,55 @@
       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
       Builder.SetInsertPoint(ThenTI);
     }
-    // Emit the @__kmpc_omp_task runtime call to spawn the task
-    Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
-    Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+    if (!Dependencies.empty()) {
+      // Build an on-stack array with one kmp_dep_info record per dependence
+      // and hand it to the runtime together with the task.
+      // NOTE(review): the alloca is emitted at the current insert point;
+      // confirm that this can never end up inside a loop body.
+      Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
+      Value *DepArray =
+          Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
+
+      unsigned P = 0;
+      for (DependData *Dep : Dependencies) {
+        Value *Base =
+            Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
+        // Store the pointer to the variable. The field is declared as SizeTy
+        // in kmp_dep_info, so the stored integer uses that width.
+        Value *Addr = Builder.CreateStructGEP(DependInfo, Base, 0);
+        Value *DepValPtr = Builder.CreatePtrToInt(Dep->DepVal, SizeTy);
+        Builder.CreateStore(DepValPtr, Addr);
+        // Store the size of the variable (also a SizeTy field).
+        Value *Size = Builder.CreateStructGEP(DependInfo, Base, 1);
+        Builder.CreateStore(
+            llvm::ConstantInt::get(SizeTy, M.getDataLayout().getTypeStoreSize(
+                                               Dep->DepValueType)),
+            Size);
+        // Store the dependency kind
+        Value *Flags = Builder.CreateStructGEP(DependInfo, Base, 2);
+        Builder.CreateStore(
+            llvm::ConstantInt::get(Builder.getInt8Ty(),
+                                   translateDependencyKind(Dep->DepKind)),
+            Flags);
+        ++P;
+      }
+
+      Value *DepArrayPtr =
+          Builder.CreateBitCast(DepArray, Builder.getInt8PtrTy());
+      Function *TaskFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
+      // The trailing count/pointer pair is always passed as 0 / null; only
+      // the dependences collected in DepArray are reported.
+      Builder.CreateCall(
+          TaskFn,
+          {Ident, ThreadID, NewTaskData, Builder.getInt32(Dependencies.size()),
+           DepArrayPtr, llvm::ConstantInt::get(Builder.getInt32Ty(), 0),
+           llvm::ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))});
+    } else {
+      // Emit the @__kmpc_omp_task runtime call to spawn the task
+      Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
+      Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+    }
 
     StaleCI->eraseFromParent();
 
Index: llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
===================================================================
--- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5025,6 +5025,81 @@
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTaskDepend) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+  BasicBlock *AllocaBB = Builder.GetInsertBlock();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  OpenMPIRBuilder::LocationDescription Loc(
+      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+  AllocaInst *InDep = Builder.CreateAlloca(Type::getInt32Ty(M->getContext()));
+  llvm::OpenMPIRBuilder::DependData DDIn(llvm::OpenMPIRBuilder::OMP_DEPEND_in,
+                                         Type::getInt32Ty(M->getContext()),
+                                         InDep);
+  SmallVector<llvm::OpenMPIRBuilder::DependData *> DDS;
+  DDS.push_back(&DDIn);
+  Builder.restoreIP(OMPBuilder.createTask(
+      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB,
+      /*Tied=*/false, /*Final*/ nullptr, /*IfCondition*/ nullptr, DDS));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  // Check for the `NumDeps` argument
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder
+          .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps)
+          ->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+  ConstantInt *NumDeps = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
+  ASSERT_NE(NumDeps, nullptr);
+  EXPECT_EQ(NumDeps->getZExtValue(), 1U);
+
+  // Check for the `DepInfo` array argument
+  BitCastInst *DepArrayPtr =
+      dyn_cast<BitCastInst>(TaskAllocCall->getOperand(4));
+  ASSERT_NE(DepArrayPtr, nullptr);
+  AllocaInst *DepArray = dyn_cast<AllocaInst>(DepArrayPtr->getOperand(0));
+  ASSERT_NE(DepArray, nullptr);
+  Value::user_iterator DepArrayI = DepArray->user_begin();
+  EXPECT_EQ(*DepArrayI, DepArrayPtr);
+  ++DepArrayI;
+  Value::user_iterator DepInfoI = DepArrayI->user_begin();
+  // Check for the `DependKind` flag in the `DepInfo` array
+  Value *Flag = findStoredValue(*DepInfoI);
+  ASSERT_NE(Flag, nullptr);
+  ConstantInt *FlagInt = dyn_cast<ConstantInt>(Flag);
+  ASSERT_NE(FlagInt, nullptr);
+  EXPECT_EQ(FlagInt->getZExtValue(), OMPDepIn);
+  ++DepInfoI;
+  // Check for the size in the `DepInfo` array
+  Value *Size = findStoredValue(*DepInfoI);
+  ASSERT_NE(Size, nullptr);
+  ConstantInt *SizeInt = dyn_cast<ConstantInt>(Size);
+  ASSERT_NE(SizeInt, nullptr);
+  EXPECT_EQ(SizeInt->getZExtValue(), 4U);
+  ++DepInfoI;
+  // Check for the variable address in the `DepInfo` array
+  Value *AddrStored = findStoredValue(*DepInfoI);
+  ASSERT_NE(AddrStored, nullptr);
+  PtrToIntInst *AddrInt = dyn_cast<PtrToIntInst>(AddrStored);
+  ASSERT_NE(AddrInt, nullptr);
+  Value *Addr = AddrInt->getPointerOperand();
+  EXPECT_EQ(Addr, InDep);
+
+  ConstantInt *NumDepsNoAlias =
+      dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(5));
+  ASSERT_NE(NumDepsNoAlias, nullptr);
+  EXPECT_EQ(NumDepsNoAlias->getZExtValue(), 0U);
+  EXPECT_EQ(TaskAllocCall->getOperand(6),
+            ConstantPointerNull::get(Type::getInt8PtrTy(M->getContext())));
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);