diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -254,8 +254,10 @@ if (const auto &ifClause = std::get_if(&clause.u)) { auto &expr = std::get(ifClause->v.t); - ifClauseOperand = fir::getBase( + mlir::Value ifVal = fir::getBase( converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); + ifClauseOperand = firOpBuilder.createConvert( + currentLocation, firOpBuilder.getI1Type(), ifVal); } else if (const auto &numThreadsClause = std::get_if( &clause.u)) { diff --git a/flang/test/Lower/OpenMP/parallel.f90 b/flang/test/Lower/OpenMP/parallel.f90 --- a/flang/test/Lower/OpenMP/parallel.f90 +++ b/flang/test/Lower/OpenMP/parallel.f90 @@ -15,8 +15,9 @@ !=============================================================================== !FIRDialect-LABEL: func @_QPparallel_if -subroutine parallel_if(alpha) +subroutine parallel_if(alpha, beta) integer, intent(in) :: alpha + logical, intent(in) :: beta !OMPDialect: omp.parallel if(%{{.*}} : i1) { !$omp parallel if(alpha .le. 0) @@ -46,6 +47,13 @@ !OMPDialect: omp.terminator !$omp end parallel + !OMPDialect: omp.parallel if(%{{.*}} : i1) { + !$omp parallel if(beta) + !FIRDialect: fir.call + call f1() + !OMPDialect: omp.terminator + !$omp end parallel + end subroutine parallel_if !=============================================================================== diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -99,8 +99,8 @@ of the parallel region. }]; - let arguments = (ins Optional:$if_expr_var, - Optional:$num_threads_var, + let arguments = (ins Optional:$if_expr_var, + Optional:$num_threads_var, Variadic:$allocate_vars, Variadic:$allocators_vars, Variadic:$reduction_vars, diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -51,15 +51,15 @@ omp.terminator } -func.func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : si32) -> () { - // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref -> %{{.*}} : memref) +func.func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : i32) -> () { + // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref -> %{{.*}} : memref) "omp.parallel" (%if_cond, %num_threads, %data_var, %data_var) ({ // test without if condition - // CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref -> %{{.*}} : memref) + // CHECK: omp.parallel num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref -> %{{.*}} : memref) "omp.parallel"(%num_threads, %data_var, %data_var) ({ omp.terminator - }) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (si32, memref, memref) -> () + }) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (i32, memref, memref) -> () // CHECK: omp.barrier omp.barrier @@ -71,13 +71,13 @@ }) {operand_segment_sizes = dense<[1,0,1,1,0]> : vector<5xi32>} : (i1, memref, memref) -> () // test without allocate - // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) + // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) "omp.parallel"(%if_cond, %num_threads) ({ omp.terminator - }) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, si32) -> () + }) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, i32) -> () omp.terminator - }) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref, memref) -> () + }) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, i32, memref, memref) -> () // test with multiple parameters for single variadic argument // CHECK: omp.parallel allocate(%{{.*}} : memref -> %{{.*}} : memref) @@ -88,14 +88,14 @@ return } -func.func @omp_parallel_pretty(%data_var : memref, %if_cond : i1, %num_threads : si32, %allocator : si32) -> () { +func.func @omp_parallel_pretty(%data_var : memref, %if_cond : i1, %num_threads : i32, %allocator : si32) -> () { // CHECK: omp.parallel omp.parallel { omp.terminator } - // CHECK: omp.parallel num_threads(%{{.*}} : si32) - omp.parallel num_threads(%num_threads : si32) { + // CHECK: omp.parallel num_threads(%{{.*}} : i32) + omp.parallel num_threads(%num_threads : i32) { omp.terminator } @@ -113,8 +113,8 @@ omp.terminator } - // CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref) proc_bind(close) - omp.parallel num_threads(%num_threads : si32) if(%if_cond: i1) proc_bind(close) { + // CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) private(%{{.*}} : memref) proc_bind(close) + omp.parallel num_threads(%num_threads : i32) if(%if_cond: i1) proc_bind(close) { omp.terminator } @@ -347,14 +347,14 @@ } // CHECK-LABEL: omp_target -func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> () { +func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32) -> () { // Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait. // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait "omp.target"(%if_cond, %device, %num_threads) ({ // CHECK: omp.terminator omp.terminator - }) {nowait, operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>} : ( i1, si32, si32 ) -> () + }) {nowait, operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>} : ( i1, si32, i32 ) -> () // CHECK: omp.barrier omp.barrier @@ -363,14 +363,14 @@ } // CHECK-LABEL: omp_target_pretty -func.func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : si32) -> () { +func.func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : i32) -> () { // CHECK: omp.target if({{.*}}) device({{.*}}) omp.target if(%if_cond) device(%device : si32) { omp.terminator } // CHECK: omp.target if({{.*}}) device({{.*}}) nowait - omp.target if(%if_cond) device(%device : si32) thread_limit(%num_threads : si32) nowait { + omp.target if(%if_cond) device(%device : si32) thread_limit(%num_threads : i32) nowait { omp.terminator }