diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -34,4 +34,26 @@ let printer = [{ p << getOperationName(); }]; } +def TaskwaitOp : OpenMP_Op<"taskwait"> { + let summary = "taskwait construct"; + let description = [{ + The taskwait construct specifies a wait on the completion of child tasks + of the current task. + }]; + + let parser = [{ return success(); }]; + let printer = [{ p << getOperationName(); }]; +} + +def TaskyieldOp : OpenMP_Op<"taskyield"> { + let summary = "taskyield construct"; + let description = [{ + The taskyield construct specifies that the current task can be suspended + in favor of execution of a different task. + }]; + + let parser = [{ return success(); }]; + let printer = [{ p << getOperationName(); }]; +} + #endif // OPENMP_OPS diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -83,6 +83,9 @@ virtual LogicalResult convertOperation(Operation &op, llvm::IRBuilder<> &builder); + /// Translates a single OpenMP dialect operation into LLVM IR. + virtual LogicalResult convertOmpOperation(Operation &op, + llvm::IRBuilder<> &builder); static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m); /// A helper to look up remapped operands in the value remapping table. diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -289,6 +289,31 @@ } ModuleTranslation::~ModuleTranslation() {} +/// Given a single OpenMP MLIR operation, create the corresponding LLVM IR via +/// the lazily-created OpenMPIRBuilder. Unsupported operations emit an error.
+LogicalResult +ModuleTranslation::convertOmpOperation(Operation &opInst, + llvm::IRBuilder<> &builder) { + if (!ompBuilder) { + ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule); + ompBuilder->initialize(); + } + + if (isa<omp::BarrierOp>(opInst)) { + ompBuilder->CreateBarrier(builder.saveIP(), llvm::omp::OMPD_barrier); + return success(); + } + if (isa<omp::TaskwaitOp>(opInst)) { + ompBuilder->CreateTaskwait(builder.saveIP()); + return success(); + } + if (isa<omp::TaskyieldOp>(opInst)) { + ompBuilder->CreateTaskyield(builder.saveIP()); + return success(); + } + return opInst.emitError("unsupported OpenMP operation: ") << opInst.getName(); +} + /// Given a single MLIR operation, create the corresponding LLVM IR operation /// using the `builder`. LLVM IR Builder does not have a generic interface so /// this has to be a long chain of `if`s calling different functions with a @@ -388,17 +413,7 @@ } if (opInst.getDialect() == ompDialect) { - if (!ompBuilder) { - ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule); - ompBuilder->initialize(); - } - - if (isa<omp::BarrierOp>(opInst)) { - ompBuilder->CreateBarrier(builder.saveIP(), llvm::omp::OMPD_barrier); - return success(); - } - return opInst.emitError("unsupported OpenMP operation: ") - << opInst.getName(); + return convertOmpOperation(opInst, builder); } return opInst.emitError("unsupported or non-LLVM operation: ") diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -5,3 +5,15 @@ omp.barrier return } + +func @omp_taskwait() -> () { + // CHECK: omp.taskwait + omp.taskwait + return +} + +func @omp_taskyield() -> () { + // CHECK: omp.taskyield + omp.taskyield + return +} diff --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir --- a/mlir/test/Target/openmp-llvm.mlir +++ b/mlir/test/Target/openmp-llvm.mlir @@ -3,8 +3,14 @@ // CHECK-LABEL: define void @empty() // CHECK: [[OMP_THREAD:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}}) //
CHECK-NEXT: call void @__kmpc_barrier(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD]]) +// CHECK: [[OMP_THREAD1:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}}) +// CHECK-NEXT: {{%.*}} = call i32 @__kmpc_omp_taskwait(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD1]]) +// CHECK: [[OMP_THREAD2:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}}) +// CHECK-NEXT: {{%.*}} = call i32 @__kmpc_omp_taskyield(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD2]], i32 0) // CHECK-NEXT: ret void llvm.func @empty() { omp.barrier + omp.taskwait + omp.taskyield llvm.return }