diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -689,9 +689,9 @@
                  LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   LogicalResult bodyGenStatus = success();
-  if (taskOp.getIfExpr() || taskOp.getFinalExpr() || taskOp.getUntiedAttr() ||
-      taskOp.getMergeableAttr() || taskOp.getInReductions() ||
-      taskOp.getPriority() || !taskOp.getAllocateVars().empty()) {
+  if (taskOp.getUntiedAttr() || taskOp.getMergeableAttr() ||
+      taskOp.getInReductions() || taskOp.getPriority() ||
+      !taskOp.getAllocateVars().empty()) {
     return taskOp.emitError("unhandled clauses for translation to LLVM IR");
   }
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
@@ -733,8 +733,9 @@
       findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTask(
-      ompLoc, allocaIP, bodyCB, !taskOp.getUntied(), /*Final*/ nullptr,
-      /*IfCondition*/ nullptr, dds));
+      ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
+      /*Final=*/moduleTranslation.lookupValue(taskOp.getFinalExpr()),
+      /*IfCondition=*/moduleTranslation.lookupValue(taskOp.getIfExpr()), dds));
   return bodyGenStatus;
 }
 
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2589,3 +2589,77 @@
     llvm.return
   }
 }
+
+// -----
+
+llvm.func external @foo_before() -> ()
+llvm.func external @foo() -> ()
+llvm.func external @foo_after() -> ()
+
+llvm.func @omp_task_final(%boolexpr: i1) {
+  llvm.call @foo_before() : () -> ()
+  omp.task final(%boolexpr) {
+    llvm.call @foo() : () -> ()
+    omp.terminator
+  }
+  llvm.call @foo_after() : () -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: define void @omp_task_final(
+// CHECK-SAME:    i1 %[[boolexpr:.+]]) {
+// CHECK:         call void @foo_before()
+// CHECK:         br label %[[entry:[^,]+]]
+// CHECK:       [[entry]]:
+// CHECK:         br label %[[codeRepl:[^,]+]]
+// CHECK:       [[codeRepl]]:  ; preds = %[[entry]]
+// CHECK:         %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
+// CHECK:         %[[final_flag:.+]] = select i1 %[[boolexpr]], i32 2, i32 0
+// CHECK:         %[[task_flags:.+]] = or i32 %[[final_flag]], 1
+// CHECK:         %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 0, i64 0, ptr @omp_task_final..omp_par.wrapper)
+// CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
+// CHECK:         br label %[[task_exit:[^,]+]]
+// CHECK:       [[task_exit]]:
+// CHECK:         call void @foo_after()
+// CHECK:         ret void
+
+// -----
+
+llvm.func external @foo_before() -> ()
+llvm.func external @foo() -> ()
+llvm.func external @foo_after() -> ()
+
+llvm.func @omp_task_if(%boolexpr: i1) {
+  llvm.call @foo_before() : () -> ()
+  omp.task if(%boolexpr) {
+    llvm.call @foo() : () -> ()
+    omp.terminator
+  }
+  llvm.call @foo_after() : () -> ()
+  llvm.return
+}
+
+// CHECK-LABEL: define void @omp_task_if(
+// CHECK-SAME:    i1 %[[boolexpr:.+]]) {
+// CHECK:         call void @foo_before()
+// CHECK:         br label %[[entry:[^,]+]]
+// CHECK:       [[entry]]:
+// CHECK:         br label %[[codeRepl:[^,]+]]
+// CHECK:       [[codeRepl]]:
+// CHECK:         %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
+// CHECK:         %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0, i64 0, ptr @omp_task_if..omp_par.wrapper)
+// CHECK:         br i1 %[[boolexpr]], label %[[true_label:[^,]+]], label %[[false_label:[^,]+]]
+// CHECK:       [[true_label]]:
+// CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
+// CHECK:         br label %[[if_else_exit:[^,]+]]
+// CHECK:       [[false_label]]:  ; preds = %[[codeRepl]]
+// CHECK:         call void @__kmpc_omp_task_begin_if0(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
+// CHECK:         %{{.+}} = call i32 @omp_task_if..omp_par.wrapper(i32 %[[omp_global_thread_num]])
+// CHECK:         call void @__kmpc_omp_task_complete_if0(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
+// CHECK:         br label %[[if_else_exit]]
+// CHECK:       [[if_else_exit]]:
+// CHECK:         br label %[[task_exit:[^,]+]]
+// CHECK:       [[task_exit]]:
+// CHECK:         call void @foo_after()
+// CHECK:         ret void
+