Index: mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
===================================================================
--- mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -127,6 +127,10 @@
   /// OpenMP dialect hasn't been loaded (it is always loaded if there are OpenMP
   /// operations in the module though).
   const Dialect *ompDialect;
+  /// Stack which stores the target block to which a branch must be added when
+  /// an omp.terminator is seen. A stack is required to handle nested OpenMP
+  /// parallel regions.
+  SmallVector<llvm::BasicBlock *, 4> ompContinuationIPStack;
 
   /// Mappings between llvm.mlir.global definitions and corresponding globals.
   DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;
Index: mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
===================================================================
--- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -397,8 +397,8 @@
 
     llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock();
     llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator();
+    ompContinuationIPStack.push_back(&continuationIP);
 
-    builder.SetInsertPoint(codeGenIPBB);
     // ParallelOp has only `1` region associated with it.
     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
     for (auto &bb : region) {
@@ -413,22 +413,27 @@
     for (auto indexedBB : llvm::enumerate(blocks)) {
       Block *bb = indexedBB.value();
       llvm::BasicBlock *curLLVMBB = blockMapping[bb];
-      if (bb->isEntryBlock())
+      if (bb->isEntryBlock()) {
+        assert(codeGenIPBBTI->getNumSuccessors() == 1 &&
+               "OpenMPIRBuilder provided entry block has multiple successors");
+        assert(codeGenIPBBTI->getSuccessor(0) == &continuationIP &&
+               "ContinuationIP is not the successor of OpenMPIRBuilder "
+               "provided entry block");
         codeGenIPBBTI->setSuccessor(0, curLLVMBB);
+      }
 
       // TODO: Error not returned up the hierarchy
-      if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
+      if (failed(
+              convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0))) {
+        // Pop before bailing out so an enclosing omp.terminator does not
+        // branch to this region's continuation block.
+        ompContinuationIPStack.pop_back();
         return;
-
-      // If this block has the terminator then add a jump to
-      // continuation bb
-      for (auto &op : *bb) {
-        if (isa<omp::TerminatorOp>(op)) {
-          builder.SetInsertPoint(curLLVMBB);
-          builder.CreateBr(&continuationIP);
-        }
-      }
+      }
     }
+
+    ompContinuationIPStack.pop_back();
+
     // Finally, after all blocks have been traversed and values mapped,
     // connect the PHI nodes to the results of preceding blocks.
     connectPHINodes(region, valueMapping, blockMapping);
@@ -504,7 +509,13 @@
         ompBuilder->CreateFlush(builder.saveIP());
         return success();
       })
-      .Case([&](omp::TerminatorOp) { return success(); })
+      .Case([&](omp::TerminatorOp) -> LogicalResult {
+        if (ompContinuationIPStack.empty())
+          return opInst.emitError(
+              "omp.terminator without an enclosing omp.parallel region");
+        builder.CreateBr(ompContinuationIPStack.back());
+        return success();
+      })
       .Case(
           [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
       .Default([&](Operation *inst) {
Index: mlir/test/Target/openmp-llvm.mlir
===================================================================
--- mlir/test/Target/openmp-llvm.mlir
+++ mlir/test/Target/openmp-llvm.mlir
@@ -214,3 +214,54 @@
 // CHECK: define internal void @[[OMP_OUTLINED_FN_3_3]]
 // CHECK: define internal void @[[OMP_OUTLINED_FN_3_2]]
 // CHECK: define internal void @[[OMP_OUTLINED_FN_3_1]]
+
+// CHECK-LABEL: define void @test_omp_parallel_4()
+llvm.func @test_omp_parallel_4() -> () {
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_4_1:.*]] to
+// CHECK: define internal void @[[OMP_OUTLINED_FN_4_1]]
+// CHECK: call void @__kmpc_barrier
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_4_1_1:.*]] to
+// CHECK: call void @__kmpc_barrier
+  omp.parallel {
+    omp.barrier
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_4_1_1]]
+// CHECK: call void @__kmpc_barrier
+    omp.parallel {
+      omp.barrier
+      omp.terminator
+    }
+
+    omp.barrier
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: define void @test_omp_parallel_5()
+llvm.func @test_omp_parallel_5() -> () {
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_5_1:.*]] to
+// CHECK: define internal void @[[OMP_OUTLINED_FN_5_1]]
+// CHECK: call void @__kmpc_barrier
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_5_1_1:.*]] to
+// CHECK: call void @__kmpc_barrier
+  omp.parallel {
+    omp.barrier
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_5_1_1]]
+    omp.parallel {
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_5_1_1_1:.*]] to
+// CHECK: define internal void @[[OMP_OUTLINED_FN_5_1_1_1]]
+// CHECK: call void @__kmpc_barrier
+      omp.parallel {
+        omp.barrier
+        omp.terminator
+      }
+      omp.terminator
+    }
+
+    omp.barrier
+    omp.terminator
+  }
+  llvm.return
+}