diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def --- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -202,6 +202,8 @@ __OMP_RTL(__kmpc_global_thread_num, false, Int32, IdentPtr) __OMP_RTL(__kmpc_get_hardware_thread_id_in_block, false, Int32, ) __OMP_RTL(__kmpc_fork_call, true, Void, IdentPtr, Int32, ParallelTaskPtr) +__OMP_RTL(__kmpc_fork_call_if, true, Void, IdentPtr, Int32, ParallelTaskPtr, + Int1) __OMP_RTL(__kmpc_omp_taskwait, false, Int32, IdentPtr, Int32) __OMP_RTL(__kmpc_omp_taskyield, false, Int32, IdentPtr, Int32, /* Int */ Int32) __OMP_RTL(__kmpc_push_num_threads, false, Void, IdentPtr, Int32, diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -914,34 +914,23 @@ AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr"); AllocaInst *ZeroAddr = Builder.CreateAlloca(Int32, nullptr, "zero.addr"); - // If there is an if condition we actually use the TIDAddr and ZeroAddr in the - // program, otherwise we only need them for modeling purposes to get the - // associated arguments in the outlined function. In the former case, - // initialize the allocas properly, in the latter case, delete them later. - if (IfCondition) { - Builder.CreateStore(Constant::getNullValue(Int32), TIDAddr); - Builder.CreateStore(Constant::getNullValue(Int32), ZeroAddr); - } else { - ToBeDeleted.push_back(TIDAddr); - ToBeDeleted.push_back(ZeroAddr); - } + // We only need TIDAddr and ZeroAddr for modeling purposes to get the + // associated arguments in the outlined function, so we delete them later. + ToBeDeleted.push_back(TIDAddr); + ToBeDeleted.push_back(ZeroAddr); // Create an artificial insertion point that will also ensure the blocks we // are about to split are not degenerated. 
auto *UI = new UnreachableInst(Builder.getContext(), InsertBB); - Instruction *ThenTI = UI, *ElseTI = nullptr; - if (IfCondition) - SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI); - - BasicBlock *ThenBB = ThenTI->getParent(); - BasicBlock *PRegEntryBB = ThenBB->splitBasicBlock(ThenTI, "omp.par.entry"); + BasicBlock *BB = UI->getParent(); + BasicBlock *PRegEntryBB = BB->splitBasicBlock(UI, "omp.par.entry"); BasicBlock *PRegBodyBB = - PRegEntryBB->splitBasicBlock(ThenTI, "omp.par.region"); + PRegEntryBB->splitBasicBlock(UI, "omp.par.region"); BasicBlock *PRegPreFiniBB = - PRegBodyBB->splitBasicBlock(ThenTI, "omp.par.pre_finalize"); + PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize"); BasicBlock *PRegExitBB = - PRegPreFiniBB->splitBasicBlock(ThenTI, "omp.par.exit"); + PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit"); auto FiniCBWrapper = [&](InsertPointTy IP) { // Hide "open-ended" blocks from the given FiniCB by setting the right jump @@ -998,8 +987,12 @@ BodyGenCB(InnerAllocaIP, CodeGenIP); LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n"); + FunctionCallee RTLFn; + if (IfCondition) + RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if); + else + RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call); - FunctionCallee RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call); if (auto *F = dyn_cast(RTLFn.getCallee())) { if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) { llvm::LLVMContext &Ctx = F->getContext(); @@ -1041,6 +1034,8 @@ SmallVector RealArgs; RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs)); + if (IfCondition) + RealArgs.push_back(IfCondition); RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end()); Builder.CreateCall(RTLFn, RealArgs); @@ -1055,35 +1050,7 @@ Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin(); Builder.CreateStore(Builder.CreateLoad(Int32, OutlinedAI), PrivTIDAddr); - // If no "if" clause was present we do not need the call 
created during - // outlining, otherwise we reuse it in the serialized parallel region. - if (!ElseTI) { - CI->eraseFromParent(); - } else { - - // If an "if" clause was present we are now generating the serialized - // version into the "else" branch. - Builder.SetInsertPoint(ElseTI); - - // Build calls __kmpc_serialized_parallel(&Ident, GTid); - Value *SerializedParallelCallArgs[] = {Ident, ThreadID}; - Builder.CreateCall( - getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_serialized_parallel), - SerializedParallelCallArgs); - - // OutlinedFn(&gtid, &zero, CapturedStruct); - CI->removeFromParent(); - Builder.Insert(CI); - - // __kmpc_end_serialized_parallel(&Ident, GTid); - Value *EndArgs[] = {Ident, ThreadID}; - Builder.CreateCall( - getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_serialized_parallel), - EndArgs); - - LLVM_DEBUG(dbgs() << "With serialized parallel region: " - << *Builder.GetInsertBlock()->getParent() << "\n"); - } + CI->eraseFromParent(); for (Instruction *I : ToBeDeleted) I->eraseFromParent(); diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -151,33 +151,18 @@ // CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]]) llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () { -// Check that the allocas are emitted by the OpenMPIRBuilder at the top of the -// function, before the condition. Allocas are only emitted by the builder when -// the `if` clause is present. We match specific SSA value names since LLVM -// actually produces those names. 
-// CHECK: %tid.addr{{.*}} = alloca i32 -// CHECK: %zero.addr{{.*}} = alloca i32 - -// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0 %0 = llvm.mlir.constant(0 : index) : i32 %1 = llvm.icmp "slt" %arg0, %0 : i32 +// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0 + // CHECK: %[[GTN_IF_1:.*]] = call i32 @__kmpc_global_thread_num(ptr @[[SI_VAR_IF_1:.*]]) -// CHECK: br i1 %[[IF_COND_VAR_1]], label %[[IF_COND_TRUE_BLOCK_1:.*]], label %[[IF_COND_FALSE_BLOCK_1:.*]] -// CHECK: [[IF_COND_TRUE_BLOCK_1]]: // CHECK: br label %[[OUTLINED_CALL_IF_BLOCK_1:.*]] // CHECK: [[OUTLINED_CALL_IF_BLOCK_1]]: -// CHECK: call void {{.*}} @__kmpc_fork_call(ptr @[[SI_VAR_IF_1]], {{.*}} @[[OMP_OUTLINED_FN_IF_1:.*]]) +// CHECK: call void {{.*}} @__kmpc_fork_call_if(ptr @[[SI_VAR_IF_1]], {{.*}} @[[OMP_OUTLINED_FN_IF_1:.*]], i1 %[[IF_COND_VAR_1]]) // CHECK: br label %[[OUTLINED_EXIT_IF_1:.*]] // CHECK: [[OUTLINED_EXIT_IF_1]]: -// CHECK: br label %[[OUTLINED_EXIT_IF_2:.*]] -// CHECK: [[OUTLINED_EXIT_IF_2]]: // CHECK: br label %[[RETURN_BLOCK_IF_1:.*]] -// CHECK: [[IF_COND_FALSE_BLOCK_1]]: -// CHECK: call void @__kmpc_serialized_parallel(ptr @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]]) -// CHECK: call void @[[OMP_OUTLINED_FN_IF_1]] -// CHECK: call void @__kmpc_end_serialized_parallel(ptr @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]]) -// CHECK: br label %[[RETURN_BLOCK_IF_1]] omp.parallel if(%1 : i1) { omp.barrier omp.terminator @@ -193,58 +178,6 @@ // ----- -// CHECK-LABEL: @test_nested_alloca_ip -llvm.func @test_nested_alloca_ip(%arg0: i32) -> () { - - // Check that the allocas are emitted by the OpenMPIRBuilder at the top of - // the function, before the condition. Allocas are only emitted by the - // builder when the `if` clause is present. We match specific SSA value names - // since LLVM actually produces those names and ensure they come before the - // "icmp" that is the first operation we emit. 
- // CHECK: %tid.addr{{.*}} = alloca i32 - // CHECK: %zero.addr{{.*}} = alloca i32 - // CHECK: icmp slt i32 %{{.*}}, 0 - %0 = llvm.mlir.constant(0 : index) : i32 - %1 = llvm.icmp "slt" %arg0, %0 : i32 - - omp.parallel if(%1 : i1) { - // The "parallel" operation will be outlined, check the the function is - // produced. Inside that function, further allocas should be placed before - // another "icmp". - // CHECK: define - // CHECK: %tid.addr{{.*}} = alloca i32 - // CHECK: %zero.addr{{.*}} = alloca i32 - // CHECK: icmp slt i32 %{{.*}}, 1 - %2 = llvm.mlir.constant(1 : index) : i32 - %3 = llvm.icmp "slt" %arg0, %2 : i32 - - omp.parallel if(%3 : i1) { - // One more nesting level. - // CHECK: define - // CHECK: %tid.addr{{.*}} = alloca i32 - // CHECK: %zero.addr{{.*}} = alloca i32 - // CHECK: icmp slt i32 %{{.*}}, 2 - - %4 = llvm.mlir.constant(2 : index) : i32 - %5 = llvm.icmp "slt" %arg0, %4 : i32 - - omp.parallel if(%5 : i1) { - omp.barrier - omp.terminator - } - - omp.barrier - omp.terminator - } - omp.barrier - omp.terminator - } - - llvm.return -} - -// ----- - // CHECK-LABEL: define void @test_omp_parallel_3() llvm.func @test_omp_parallel_3() -> () { // CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(ptr @{{[0-9]+}}) diff --git a/openmp/runtime/src/kmp.h b/openmp/runtime/src/kmp.h --- a/openmp/runtime/src/kmp.h +++ b/openmp/runtime/src/kmp.h @@ -3901,6 +3901,8 @@ KMP_EXPORT kmp_int32 __kmpc_ok_to_fork(ident_t *); KMP_EXPORT void __kmpc_fork_call(ident_t *, kmp_int32 nargs, kmpc_micro microtask, ...); +KMP_EXPORT void __kmpc_fork_call_if(ident_t *loc, kmp_int32 nargs, + kmpc_micro microtask, bool cond, void *args); KMP_EXPORT void __kmpc_serialized_parallel(ident_t *, kmp_int32 global_tid); KMP_EXPORT void __kmpc_end_serialized_parallel(ident_t *, kmp_int32 global_tid); diff --git a/openmp/runtime/src/kmp_csupport.cpp b/openmp/runtime/src/kmp_csupport.cpp --- a/openmp/runtime/src/kmp_csupport.cpp +++ b/openmp/runtime/src/kmp_csupport.cpp @@ -330,6 
+330,37 @@
 #endif // KMP_STATS_ENABLED
 }
 
+/*!
+@ingroup PARALLEL
+@param loc source location information
+@param argc number of arguments forwarded to __kmpc_fork_call
+@param microtask pointer to callback routine consisting of outlined parallel
+construct
+@param cond condition for running in parallel
+@param args struct of pointers to shared variables that aren't global
+
+Perform a fork only if the condition is true. Otherwise the microtask is run
+directly by the calling thread, between __kmpc_serialized_parallel and
+__kmpc_end_serialized_parallel.
+
+NOTE(review): exactly one trailing argument (args) is forwarded although
+__kmpc_fork_call is variadic; confirm every caller passes exactly one.
+*/
+void __kmpc_fork_call_if(ident_t *loc, kmp_int32 argc, kmpc_micro microtask,
+                         bool cond, void *args) {
+  int gtid = __kmp_entry_gtid();
+  int zero = 0;
+  if (cond)
+    __kmpc_fork_call(loc, argc, microtask, args);
+  else {
+    __kmpc_serialized_parallel(loc, gtid);
+    // Serialized path: call the outlined routine ourselves, handing it the
+    // address of our gtid and of a zero-valued int as its first two arguments.
+    microtask(&gtid, &zero, args);
+    __kmpc_end_serialized_parallel(loc, gtid);
+  }
+}
+
 /*!
 @ingroup PARALLEL
 @param loc source location information