diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -701,6 +701,27 @@ return bodyGenStatus; } +/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder. Bails out with an error when the op carries task_reduction or allocate variables, since those clauses are not translated here; otherwise the region is converted inside the body callback given to createTaskgroup and the status held in bodyGenStatus is returned. +static LogicalResult +convertOmpTaskgroupOp(omp::TaskGroupOp tgOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; + LogicalResult bodyGenStatus = success(); + if (!tgOp.task_reduction_vars().empty() || !tgOp.allocate_vars().empty()) { + return tgOp.emitError("unhandled clauses for translation to LLVM IR"); + } + auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) { + builder.restoreIP(codegenIP); + convertOmpOpRegions(tgOp.region(), "omp.taskgroup.region", builder, + moduleTranslation, bodyGenStatus); + }; + InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTaskgroup( + ompLoc, allocaIP, bodyCB)); + return bodyGenStatus; +} + /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder. static LogicalResult convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, @@ -1406,6 +1427,9 @@ .Case([&](omp::TaskOp op) { return convertOmpTaskOp(op, builder, moduleTranslation); }) + .Case([&](omp::TaskGroupOp op) { + return convertOmpTaskgroupOp(op, builder, moduleTranslation); + }) .Case([](auto op) { // `yield` and `terminator` can be just omitted. 
The block structure diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1572,6 +1572,30 @@ return } +// CHECK-LABEL: @omp_taskgroup_clauses +func.func @omp_taskgroup_clauses() -> () { + %testmemref = "test.memref"() : () -> (memref) + %testf32 = "test.f32"() : () -> (!llvm.ptr) + // CHECK: omp.taskgroup task_reduction(@add_f32 -> %{{.+}}: !llvm.ptr) allocate(%{{.+}}: memref -> %{{.+}}: memref) + omp.taskgroup allocate(%testmemref : memref -> %testmemref : memref) task_reduction(@add_f32 -> %testf32 : !llvm.ptr) { + // CHECK: omp.task + omp.task { + "test.foo"() : () -> () + // CHECK: omp.terminator + omp.terminator + } + // CHECK: omp.task + omp.task { + "test.foo"() : () -> () + // CHECK: omp.terminator + omp.terminator + } + // CHECK: omp.terminator + omp.terminator + } + return +} + // CHECK-LABEL: @omp_taskloop func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () { diff --git a/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir b/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm-invalid.mlir @@ -93,3 +93,17 @@ llvm.store %3, %5 : !llvm.ptr llvm.return } + +// ----- + +omp.reduction.declare @add_f32 : f32 +init { +^bb0(%arg: f32): + %0 = llvm.mlir.constant(0.0): f32 + omp.yield (%0 : f32) +} +combiner { +^bb1(%arg0: f32, %arg1: f32): + %1 = llvm.fadd %arg0, %arg1 : f32 + omp.yield (%1 : f32) +} diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -2357,3 +2357,98 @@ // CHECK: call void @[[outlined_fn]](ptr %[[task_data]]) // CHECK: ret i32 0 // CHECK: } + +// ----- + +llvm.func @foo() -> () + +llvm.func @omp_taskgroup(%x: i32, %y: i32, %zaddr: !llvm.ptr) { + omp.taskgroup { + llvm.call @foo() : () -> () + 
omp.terminator + } + llvm.return +} + +// CHECK-LABEL: define void @omp_taskgroup( +// CHECK-SAME: i32 %[[x:.+]], i32 %[[y:.+]], ptr %[[zaddr:.+]]) +// CHECK: br label %[[entry:[^,]+]] +// CHECK: [[entry]]: +// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) +// CHECK: call void @__kmpc_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]]) +// CHECK: br label %[[omp_taskgroup_region:[^,]+]] +// CHECK: [[omp_taskgroup_region]]: +// CHECK: call void @foo() +// CHECK: br label %[[omp_region_cont:[^,]+]] +// CHECK: [[omp_region_cont]]: +// CHECK: br label %[[taskgroup_exit:[^,]+]] +// CHECK: [[taskgroup_exit]]: +// CHECK: call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]]) +// CHECK: ret void + +// ----- + +llvm.func @foo() -> () +llvm.func @bar(i32, i32, !llvm.ptr) -> () + +llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) { + omp.taskgroup { + %c1 = llvm.mlir.constant(1) : i32 + %ptr1 = llvm.alloca %c1 x i8 : (i32) -> !llvm.ptr + omp.task { + llvm.call @foo() : () -> () + omp.terminator + } + omp.task { + llvm.call @bar(%x, %y, %zaddr) : (i32, i32, !llvm.ptr) -> () + omp.terminator + } + llvm.br ^bb1 + ^bb1: + llvm.call @foo() : () -> () + omp.terminator + } + llvm.return +} + +// CHECK-LABEL: define void @omp_taskgroup_task( +// CHECK-SAME: i32 %[[x:.+]], i32 %[[y:.+]], ptr %[[zaddr:.+]]) +// CHECK: %[[structArg:.+]] = alloca { i32, i32, ptr }, align 8 +// CHECK: br label %[[entry:[^,]+]] +// CHECK: [[entry]]: +// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) +// CHECK: call void @__kmpc_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]]) +// CHECK: br label %[[omp_taskgroup_region:[^,]+]] +// CHECK: [[omp_taskgroup_region1:.+]]: +// CHECK: call void @foo() +// CHECK: br label %[[omp_region_cont:[^,]+]] +// CHECK: [[omp_taskgroup_region]]: +// CHECK: %{{.+}} = alloca i8, align 1 +// CHECK: br label %[[codeRepl:[^,]+]] +// CHECK: 
[[codeRepl]]: +// CHECK: %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) +// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 0, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper) +// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]]) +// CHECK: br label %[[task_exit:[^,]+]] +// CHECK: [[task_exit]]: +// CHECK: br label %[[codeRepl9:[^,]+]] +// CHECK: [[codeRepl9]]: +// CHECK: %[[gep1:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 0 +// CHECK: store i32 %[[x]], ptr %[[gep1]], align 4 +// CHECK: %[[gep2:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 1 +// CHECK: store i32 %[[y]], ptr %[[gep2]], align 4 +// CHECK: %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2 +// CHECK: store ptr %[[zaddr]], ptr %[[gep3]], align 8 +// CHECK: %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}}) +// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 16, i64 0, ptr @omp_taskgroup_task..omp_par.1.wrapper) +// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 8 %[[t2_alloc]], ptr align 8 %[[structArg]], i64 16, i1 false) +// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], ptr %[[t2_alloc]]) +// CHECK: br label %[[task_exit3:[^,]+]] +// CHECK: [[task_exit3]]: +// CHECK: br label %[[omp_taskgroup_region1]] +// CHECK: [[omp_region_cont]]: +// CHECK: br label %[[taskgroup_exit:[^,]+]] +// CHECK: [[taskgroup_exit]]: +// CHECK: call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[omp_global_thread_num]]) +// CHECK: ret void +// CHECK: }