Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -4375,39 +4375,26 @@
   return Result;
 }
 
-namespace {
-/// Dependence kind for RTL.
-enum RTLDependenceKindTy {
-  DepIn = 0x01,
-  DepInOut = 0x3,
-  DepMutexInOutSet = 0x4,
-  DepInOutSet = 0x8,
-  DepOmpAllMem = 0x80,
-};
-/// Fields ids in kmp_depend_info record.
-enum RTLDependInfoFieldsTy { BaseAddr, Len, Flags };
-} // namespace
-
 /// Translates internal dependency kind into the runtime kind.
 static RTLDependenceKindTy translateDependencyKind(OpenMPDependClauseKind K) {
   RTLDependenceKindTy DepKind;
   switch (K) {
   case OMPC_DEPEND_in:
-    DepKind = DepIn;
+    DepKind = RTLDependenceKindTy::DepIn;
     break;
   // Out and InOut dependencies must use the same code.
   case OMPC_DEPEND_out:
   case OMPC_DEPEND_inout:
-    DepKind = DepInOut;
+    DepKind = RTLDependenceKindTy::DepInOut;
     break;
   case OMPC_DEPEND_mutexinoutset:
-    DepKind = DepMutexInOutSet;
+    DepKind = RTLDependenceKindTy::DepMutexInOutSet;
     break;
   case OMPC_DEPEND_inoutset:
-    DepKind = DepInOutSet;
+    DepKind = RTLDependenceKindTy::DepInOutSet;
     break;
   case OMPC_DEPEND_outallmemory:
-    DepKind = DepOmpAllMem;
+    DepKind = RTLDependenceKindTy::DepOmpAllMem;
     break;
   case OMPC_DEPEND_source:
   case OMPC_DEPEND_sink:
@@ -4455,7 +4442,9 @@
       DepObjAddr, KmpDependInfoTy, Base.getBaseInfo(), Base.getTBAAInfo());
   // NumDeps = deps[i].base_addr;
   LValue BaseAddrLVal = CGF.EmitLValueForField(
-      NumDepsBase, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
+      NumDepsBase,
+      *std::next(KmpDependInfoRD->field_begin(),
+                 static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
   llvm::Value *NumDeps = CGF.EmitLoadOfScalar(BaseAddrLVal, Loc);
   return std::make_pair(NumDeps, Base);
 }
@@ -4501,18 +4490,24 @@
     }
     // deps[i].base_addr = &<Dependencies[i].second>;
     LValue BaseAddrLVal = CGF.EmitLValueForField(
-        Base, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
+        Base,
+        *std::next(KmpDependInfoRD->field_begin(),
+                   static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
     CGF.EmitStoreOfScalar(Addr, BaseAddrLVal);
     // deps[i].len = sizeof(<Dependencies[i].second>);
     LValue LenLVal = CGF.EmitLValueForField(
-        Base, *std::next(KmpDependInfoRD->field_begin(), Len));
+        Base, *std::next(KmpDependInfoRD->field_begin(),
+                         static_cast<unsigned int>(RTLDependInfoFields::Len)));
     CGF.EmitStoreOfScalar(Size, LenLVal);
     // deps[i].flags = <Dependencies[i].first>;
     RTLDependenceKindTy DepKind = translateDependencyKind(Data.DepKind);
     LValue FlagsLVal = CGF.EmitLValueForField(
-        Base, *std::next(KmpDependInfoRD->field_begin(), Flags));
-    CGF.EmitStoreOfScalar(llvm::ConstantInt::get(LLVMFlagsTy, DepKind),
-                          FlagsLVal);
+        Base,
+        *std::next(KmpDependInfoRD->field_begin(),
+                   static_cast<unsigned int>(RTLDependInfoFields::Flags)));
+    CGF.EmitStoreOfScalar(
+        llvm::ConstantInt::get(LLVMFlagsTy, static_cast<unsigned int>(DepKind)),
+        FlagsLVal);
     if (unsigned *P = Pos.dyn_cast<unsigned *>()) {
       ++(*P);
     } else {
@@ -4788,7 +4783,9 @@
   LValue Base = CGF.MakeAddrLValue(DependenciesArray, KmpDependInfoTy);
   // deps[i].base_addr = NumDependencies;
   LValue BaseAddrLVal = CGF.EmitLValueForField(
-      Base, *std::next(KmpDependInfoRD->field_begin(), BaseAddr));
+      Base,
+      *std::next(KmpDependInfoRD->field_begin(),
+                 static_cast<unsigned int>(RTLDependInfoFields::BaseAddr)));
   CGF.EmitStoreOfScalar(NumDepsVal, BaseAddrLVal);
   llvm::PointerUnion<unsigned *, LValue *> Pos;
   unsigned Idx = 1;
@@ -4868,9 +4865,11 @@
   // deps[i].flags = NewDepKind;
   RTLDependenceKindTy DepKind = translateDependencyKind(NewDepKind);
   LValue FlagsLVal = CGF.EmitLValueForField(
-      Base, *std::next(KmpDependInfoRD->field_begin(), Flags));
-  CGF.EmitStoreOfScalar(llvm::ConstantInt::get(LLVMFlagsTy, DepKind),
-                        FlagsLVal);
+      Base, *std::next(KmpDependInfoRD->field_begin(),
+                       static_cast<unsigned int>(RTLDependInfoFields::Flags)));
+  CGF.EmitStoreOfScalar(
+      llvm::ConstantInt::get(LLVMFlagsTy, static_cast<unsigned int>(DepKind)),
+      FlagsLVal);
 
   // Shift the address forward by one element.
   Address ElementNext =
Index: llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -207,6 +207,19 @@
 /// Atomic compare operations. Currently OpenMP only supports ==, >, and <.
 enum class OMPAtomicCompareOp : unsigned { EQ, MIN, MAX };
 
+/// Fields ids in kmp_depend_info record.
+enum class RTLDependInfoFields { BaseAddr, Len, Flags };
+
+/// Dependence kind for RTL.
+enum class RTLDependenceKindTy {
+  DepUnknown = 0x0,
+  DepIn = 0x01,
+  DepInOut = 0x3,
+  DepMutexInOutSet = 0x4,
+  DepInOutSet = 0x8,
+  DepOmpAllMem = 0x80,
+};
+
 } // end namespace omp
 
 } // end namespace llvm
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -645,6 +645,20 @@
   /// \param Loc The location where the taskyield directive was encountered.
   void createTaskyield(const LocationDescription &Loc);
 
+  /// A dependence of a task, as passed to createTask (OpenMP `depend` clause).
+  struct DependData {
+    /// Runtime dependence kind; stored into the `flags` field of kmp_dep_info.
+    omp::RTLDependenceKindTy DepKind = omp::RTLDependenceKindTy::DepUnknown;
+    /// Type of the dependence variable; its store size is stored as `len`.
+    Type *DepValueType = nullptr;
+    /// Pointer to the dependence variable; its address becomes `base_addr`.
+    Value *DepVal = nullptr;
+    explicit DependData() = default;
+    DependData(omp::RTLDependenceKindTy DepKind, Type *DepValueType,
+               Value *DepVal)
+        : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal) {}
+  };
+
   /// Generator for `#omp task`
   ///
   /// \param Loc The location where the task construct was encountered.
@@ -662,7 +676,8 @@
   InsertPointTy createTask(const LocationDescription &Loc,
                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
                            bool Tied = true, Value *Final = nullptr,
-                           Value *IfCondition = nullptr);
+                           Value *IfCondition = nullptr,
+                           SmallVector<DependData> Dependencies = {});
 
   /// Generator for the taskgroup construct
   ///
Index: llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -92,6 +92,7 @@
 __OMP_STRUCT_TYPE(KernelArgs, __tgt_kernel_arguments, Int32, Int32, VoidPtrPtr,
                   VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr, Int64)
 __OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, Int8Ptr)
+__OMP_STRUCT_TYPE(DependInfo, kmp_dep_info, SizeTy, SizeTy, Int8)
 
 #undef __OMP_STRUCT_TYPE
 #undef OMP_STRUCT_TYPE
Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1290,7 +1290,8 @@
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                            bool Tied, Value *Final, Value *IfCondition) {
+                            bool Tied, Value *Final, Value *IfCondition,
+                            SmallVector<DependData> Dependencies) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -1322,8 +1323,8 @@
   OI.EntryBB = TaskAllocaBB;
   OI.OuterAllocaBB = AllocaIP.getBlock();
   OI.ExitBB = TaskExitBB;
-  OI.PostOutlineCB = [this, Ident, Tied, Final,
-                      IfCondition](Function &OutlinedFn) {
+  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition,
+                      Dependencies](Function &OutlinedFn) {
     // The input IR here looks like the following-
     // ```
     // func @current_fn() {
@@ -1433,6 +1434,53 @@
                            TaskSize);
     }
 
+    // Build the kmp_dep_info array passed to __kmpc_omp_task_with_deps. Only
+    // the alloca is placed in the entry block; the element stores are emitted
+    // at the current insertion point so they are dominated by the dependence
+    // values, which are not guaranteed to be available in the entry block.
+    Value *DepArrayPtr = nullptr;
+    if (!Dependencies.empty()) {
+      InsertPointTy OldIP = Builder.saveIP();
+      Builder.SetInsertPoint(
+          &OldIP.getBlock()->getParent()->getEntryBlock().back());
+      Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
+      Value *DepArray =
+          Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
+      Builder.restoreIP(OldIP);
+
+      // Use the declared field types of kmp_dep_info (SizeTy, SizeTy, Int8)
+      // instead of assuming i64, so the stores always match the struct layout.
+      unsigned BaseAddrIdx =
+          static_cast<unsigned int>(RTLDependInfoFields::BaseAddr);
+      unsigned LenIdx = static_cast<unsigned int>(RTLDependInfoFields::Len);
+      unsigned FlagsIdx = static_cast<unsigned int>(RTLDependInfoFields::Flags);
+      Type *BaseAddrTy = DependInfo->getElementType(BaseAddrIdx);
+      Type *LenTy = DependInfo->getElementType(LenIdx);
+      Type *FlagsTy = DependInfo->getElementType(FlagsIdx);
+
+      unsigned P = 0;
+      for (const DependData &Dep : Dependencies) {
+        Value *Base =
+            Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
+        // Store the address of the variable.
+        Value *Addr = Builder.CreateStructGEP(DependInfo, Base, BaseAddrIdx);
+        Value *DepValPtr = Builder.CreatePtrToInt(Dep.DepVal, BaseAddrTy);
+        Builder.CreateStore(DepValPtr, Addr);
+        // Store the size of the variable.
+        Value *Size = Builder.CreateStructGEP(DependInfo, Base, LenIdx);
+        uint64_t DepSize = M.getDataLayout().getTypeStoreSize(Dep.DepValueType);
+        Builder.CreateStore(ConstantInt::get(LenTy, DepSize), Size);
+        // Store the dependency kind.
+        Value *Flags = Builder.CreateStructGEP(DependInfo, Base, FlagsIdx);
+        Builder.CreateStore(
+            ConstantInt::get(FlagsTy, static_cast<unsigned int>(Dep.DepKind)),
+            Flags);
+        ++P;
+      }
+
+      DepArrayPtr = Builder.CreateBitCast(DepArray, Builder.getInt8PtrTy());
+    }
+
     // In the presence of the `if` clause, the following IR is generated:
     //    ...
     //    %data = call @__kmpc_omp_task_alloc(...)
@@ -1471,9 +1519,25 @@
       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
       Builder.SetInsertPoint(ThenTI);
     }
-    // Emit the @__kmpc_omp_task runtime call to spawn the task
-    Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
-    Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+
+    if (!Dependencies.empty()) {
+      // Spawn the task through @__kmpc_omp_task_with_deps, passing the
+      // kmp_dep_info array built above. No noalias dependences are passed
+      // (count 0, null list).
+      // NOTE(review): the `IfCondition` == false path above does not appear
+      // to wait on these dependences (cf. __kmpc_omp_wait_deps); confirm.
+      Function *TaskFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
+      Builder.CreateCall(
+          TaskFn,
+          {Ident, ThreadID, NewTaskData, Builder.getInt32(Dependencies.size()),
+           DepArrayPtr, ConstantInt::get(Builder.getInt32Ty(), 0),
+           ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))});
+    } else {
+      // Emit the @__kmpc_omp_task runtime call to spawn the task
+      Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
+      Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
+    }
 
     StaleCI->eraseFromParent();
 
Index: llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
===================================================================
--- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5092,6 +5092,85 @@
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTaskDepend) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+  BasicBlock *AllocaBB = Builder.GetInsertBlock();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  OpenMPIRBuilder::LocationDescription Loc(
+      InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL);
+  AllocaInst *InDep = Builder.CreateAlloca(Type::getInt32Ty(M->getContext()));
+  SmallVector<OpenMPIRBuilder::DependData> DDS;
+  {
+    // DDIn goes out of scope before finalize() emits the task, so this also
+    // checks that createTask keeps its own copy of the dependences.
+    OpenMPIRBuilder::DependData DDIn(RTLDependenceKindTy::DepIn,
+                                     Type::getInt32Ty(M->getContext()), InDep);
+    DDS.push_back(DDIn);
+  }
+  Builder.restoreIP(OMPBuilder.createTask(
+      Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB,
+      /*Tied=*/false, /*Final*/ nullptr, /*IfCondition*/ nullptr, DDS));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  // Check for the `NumDeps` argument
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder
+          .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps)
+          ->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+  ConstantInt *NumDeps = dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
+  ASSERT_NE(NumDeps, nullptr);
+  EXPECT_EQ(NumDeps->getZExtValue(), 1U);
+
+  // Check for the `DepInfo` array argument. Look through pointer casts: the
+  // runtime call receives the alloca either directly or via a bitcast.
+  Value *DepArrayArg = TaskAllocCall->getArgOperand(4)->stripPointerCasts();
+  AllocaInst *DepArray = dyn_cast<AllocaInst>(DepArrayArg);
+  ASSERT_NE(DepArray, nullptr);
+  // The GEP addressing the first `kmp_dep_info` element of the array.
+  auto DepBaseI = find_if(DepArray->users(),
+                          [](User *U) { return isa<GetElementPtrInst>(U); });
+  ASSERT_TRUE(DepBaseI != DepArray->user_end());
+  Value::user_iterator DepInfoI = DepBaseI->user_begin();
+  // Check for the `DependKind` flag in the `DepInfo` array
+  Value *Flag = findStoredValue<GetElementPtrInst>(*DepInfoI);
+  ASSERT_NE(Flag, nullptr);
+  ConstantInt *FlagInt = dyn_cast<ConstantInt>(Flag);
+  ASSERT_NE(FlagInt, nullptr);
+  EXPECT_EQ(FlagInt->getZExtValue(),
+            static_cast<unsigned int>(RTLDependenceKindTy::DepIn));
+  ++DepInfoI;
+  // Check for the size in the `DepInfo` array
+  Value *Size = findStoredValue<GetElementPtrInst>(*DepInfoI);
+  ASSERT_NE(Size, nullptr);
+  ConstantInt *SizeInt = dyn_cast<ConstantInt>(Size);
+  ASSERT_NE(SizeInt, nullptr);
+  EXPECT_EQ(SizeInt->getZExtValue(), 4U);
+  ++DepInfoI;
+  // Check for the variable address in the `DepInfo` array
+  Value *AddrStored = findStoredValue<GetElementPtrInst>(*DepInfoI);
+  ASSERT_NE(AddrStored, nullptr);
+  PtrToIntInst *AddrInt = dyn_cast<PtrToIntInst>(AddrStored);
+  ASSERT_NE(AddrInt, nullptr);
+  Value *Addr = AddrInt->getPointerOperand();
+  EXPECT_EQ(Addr, InDep);
+
+  ConstantInt *NumDepsNoAlias =
+      dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(5));
+  ASSERT_NE(NumDepsNoAlias, nullptr);
+  EXPECT_EQ(NumDepsNoAlias->getZExtValue(), 0U);
+  EXPECT_EQ(TaskAllocCall->getOperand(6),
+            ConstantPointerNull::get(Type::getInt8PtrTy(M->getContext())));
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);