diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -30,8 +30,27 @@ the construct appears. }]; - let parser = [{ return success(); }]; - let printer = [{ p << getOperationName(); }]; + let assemblyFormat = "attr-dict"; +} + +def TaskwaitOp : OpenMP_Op<"taskwait"> { + let summary = "taskwait construct"; + let description = [{ + The taskwait construct specifies a wait on the completion of child tasks + of the current task. + }]; + + let assemblyFormat = "attr-dict"; +} + +def TaskyieldOp : OpenMP_Op<"taskyield"> { + let summary = "taskyield construct"; + let description = [{ + The taskyield construct specifies that the current task can be suspended + in favor of execution of a different task. + }]; + + let assemblyFormat = "attr-dict"; } #endif // OPENMP_OPS diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -83,6 +83,8 @@ virtual LogicalResult convertOperation(Operation &op, llvm::IRBuilder<> &builder); + virtual LogicalResult convertOmpOperation(Operation &op, + llvm::IRBuilder<> &builder); static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m); /// A helper to look up remapped operands in the value remapping table. 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -14,6 +14,7 @@ #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "DebugTranslation.h" +#include "mlir/ADT/TypeSwitch.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Attributes.h" @@ -306,6 +307,34 @@ } ModuleTranslation::~ModuleTranslation() {} +/// Given an OpenMP MLIR operation, create the corresponding LLVM IR +/// (including OpenMP runtime calls). +LogicalResult +ModuleTranslation::convertOmpOperation(Operation &opInst, + llvm::IRBuilder<> &builder) { + if (!ompBuilder) { + ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule); + ompBuilder->initialize(); + } + return mlir::TypeSwitch<Operation *, LogicalResult>(&opInst) + .Case([&](omp::BarrierOp) { + ompBuilder->CreateBarrier(builder.saveIP(), llvm::omp::OMPD_barrier); + return success(); + }) + .Case([&](omp::TaskwaitOp) { + ompBuilder->CreateTaskwait(builder.saveIP()); + return success(); + }) + .Case([&](omp::TaskyieldOp) { + ompBuilder->CreateTaskyield(builder.saveIP()); + return success(); + }) + .Default([&](Operation *inst) { + return inst->emitError("unsupported OpenMP operation: ") + << inst->getName(); + }); +} + /// Given a single MLIR operation, create the corresponding LLVM IR operation /// using the `builder`. 
LLVM IR Builder does not have a generic interface so /// this has to be a long chain of `if`s calling different functions with a @@ -415,17 +444,7 @@ } if (opInst.getDialect() == ompDialect) { - if (!ompBuilder) { - ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule); - ompBuilder->initialize(); - } - - if (isa<omp::BarrierOp>(opInst)) { - ompBuilder->CreateBarrier(builder.saveIP(), llvm::omp::OMPD_barrier); - return success(); - } - return opInst.emitError("unsupported OpenMP operation: ") - << opInst.getName(); + return convertOmpOperation(opInst, builder); } return opInst.emitError("unsupported or non-LLVM operation: ") diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -5,3 +5,15 @@ omp.barrier return } + +func @omp_taskwait() -> () { + // CHECK: omp.taskwait + omp.taskwait + return +} + +func @omp_taskyield() -> () { + // CHECK: omp.taskyield + omp.taskyield + return +} diff --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir --- a/mlir/test/Target/openmp-llvm.mlir +++ b/mlir/test/Target/openmp-llvm.mlir @@ -1,10 +1,19 @@ // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s // CHECK-LABEL: define void @empty() -// CHECK: [[OMP_THREAD:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}}) -// CHECK-NEXT: call void @__kmpc_barrier(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD]]) -// CHECK-NEXT: ret void llvm.func @empty() { + // CHECK: [[OMP_THREAD:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}}) + // CHECK-NEXT: call void @__kmpc_barrier(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD]]) omp.barrier + + // CHECK: [[OMP_THREAD1:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}}) + // CHECK-NEXT: [[RET_VAL:%.*]] = call i32 @__kmpc_omp_taskwait(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD1]]) + omp.taskwait + + // CHECK: [[OMP_THREAD2:%.*]] = call i32 
@__kmpc_global_thread_num(%struct.ident_t* @{{[0-9]+}}) + // CHECK-NEXT: [[RET_VAL:%.*]] = call i32 @__kmpc_omp_taskyield(%struct.ident_t* @{{[0-9]+}}, i32 [[OMP_THREAD2]], i32 0) + omp.taskyield + +// CHECK-NEXT: ret void llvm.return }