diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -69,7 +69,7 @@
   p << "omp.parallel";
 
   if (auto ifCond = op.if_expr_var())
-    p << " if(" << ifCond << ")";
+    p << " if(" << ifCond << " : " << ifCond.getType() << ")";
 
   if (auto threads = op.num_threads_var())
     p << " num_threads(" << threads << " : " << threads.getType() << ")";
@@ -124,7 +124,7 @@
 /// Note that each clause can only appear once in the clase-list.
 static ParseResult parseParallelOp(OpAsmParser &parser,
                                    OperationState &result) {
-  OpAsmParser::OperandType ifCond;
+  std::pair<OpAsmParser::OperandType, Type> ifCond;
   std::pair<OpAsmParser::OperandType, Type> numThreads;
   llvm::SmallVector<OpAsmParser::OperandType, 4> privates;
   llvm::SmallVector<Type, 4> privateTypes;
@@ -152,8 +152,8 @@
       // Fail if there was already another if condition
       if (segments[ifClausePos])
         return allowedOnce(parser, "if", opName);
-      if (parser.parseLParen() || parser.parseOperand(ifCond) ||
-          parser.parseRParen())
+      if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
+          parser.parseColonType(ifCond.second) || parser.parseRParen())
         return failure();
       segments[ifClausePos] = 1;
     } else if (keyword == "num_threads") {
@@ -209,7 +209,7 @@
       auto attr = parser.getBuilder().getStringAttr(attrval);
       result.addAttribute("default_val", attr);
     } else if (keyword == "proc_bind") {
-      // fail if there was already another default clause
+      // fail if there was already another proc_bind clause
       if (procBind)
         return allowedOnce(parser, "proc_bind", opName);
       procBind = true;
@@ -228,8 +228,8 @@
 
   // Add if parameter
   if (segments[ifClausePos]) {
-    parser.resolveOperand(ifCond, parser.getBuilder().getI1Type(),
-                          result.operands);
+    if (parser.resolveOperand(ifCond.first, ifCond.second, result.operands))
+      return failure();
   }
 
   // Add num_threads parameter
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -454,7 +454,12 @@
-  // TODO: The various operands of parallel operation are not handled.
+  // TODO: Operands other than if and num_threads are not handled yet.
   // Parallel operation is created with some default options for now.
   llvm::Value *ifCond = nullptr;
+  if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
+    ifCond = valueMapping.lookup(ifExprVar);
   llvm::Value *numThreads = nullptr;
+  if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
+    numThreads = valueMapping.lookup(numThreadsVar);
+  // TODO: Is the Parallel construct cancellable?
   bool isCancellable = false;
   // TODO: Determine the actual alloca insertion point, e.g., the function
   // entry or the alloca insertion point as provided by the body callback
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -12,7 +12,7 @@
 
 func @if_once(%n : i1) {
   // expected-error@+1 {{at most one if clause can appear on the omp.parallel operation}}
-  omp.parallel if(%n) if(%n) {
+  omp.parallel if(%n : i1) if(%n : i1) {
   }
 
   return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -99,14 +99,14 @@
 
   // CHECK omp.parallel shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>)
   omp.parallel shared(%data_var : memref<i32>) copyin(%data_var : memref<i32>, %data_var : memref<i32>) {
-    omp.parallel if(%if_cond) {
+    omp.parallel if(%if_cond : i1) {
       omp.terminator
     }
     omp.terminator
   }
 
-  // CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) proc_bind(close)
-  omp.parallel num_threads(%num_threads : si32) if(%if_cond)
+  // CHECK omp.parallel if(%{{.*}} : i1) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) proc_bind(close)
+  omp.parallel num_threads(%num_threads : si32) if(%if_cond : i1)
                private(%data_var : memref<i32>) proc_bind(close) {
     omp.terminator
   }
diff --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir
--- a/mlir/test/Target/openmp-llvm.mlir
+++ b/mlir/test/Target/openmp-llvm.mlir
@@ -78,3 +78,100 @@
 // CHECK-LABEL: omp.par.region2:
 // CHECK: call void @body(i64 43)
 // CHECK: br label %omp.par.pre_finalize
+
+// CHECK: define void @test_omp_parallel_num_threads_1(i32 %[[NUM_THREADS_VAR_1:.*]])
+llvm.func @test_omp_parallel_num_threads_1(%arg0: !llvm.i32) -> () {
+  // CHECK: %[[GTN_NUM_THREADS_VAR_1:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_1:.*]])
+  // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_1]], i32 %[[GTN_NUM_THREADS_VAR_1]], i32 %[[NUM_THREADS_VAR_1]])
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_1:.*]] to {{.*}}
+  omp.parallel num_threads(%arg0 : !llvm.i32) {
+    omp.barrier
+    omp.terminator
+  }
+
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_1]]
+  // CHECK: call void @__kmpc_barrier
+
+// CHECK: define void @test_omp_parallel_num_threads_2()
+llvm.func @test_omp_parallel_num_threads_2() -> () {
+  %0 = llvm.mlir.constant(4 : index) : !llvm.i32
+  // CHECK: %[[GTN_NUM_THREADS_VAR_2:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_2:.*]])
+  // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_2]], i32 %[[GTN_NUM_THREADS_VAR_2]], i32 4)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_2:.*]] to {{.*}}
+  omp.parallel num_threads(%0 : !llvm.i32) {
+    omp.barrier
+    omp.terminator
+  }
+
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_2]]
+  // CHECK: call void @__kmpc_barrier
+
+// CHECK: define void @test_omp_parallel_num_threads_3()
+llvm.func @test_omp_parallel_num_threads_3() -> () {
+  %0 = llvm.mlir.constant(4 : index) : !llvm.i32
+  // CHECK: %[[GTN_NUM_THREADS_VAR_3_1:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_3_1:.*]])
+  // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_3_1]], i32 %[[GTN_NUM_THREADS_VAR_3_1]], i32 4)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_3_1:.*]] to {{.*}}
+  omp.parallel num_threads(%0 : !llvm.i32) {
+    omp.barrier
+    omp.terminator
+  }
+  %1 = llvm.mlir.constant(8 : index) : !llvm.i32
+  // CHECK: %[[GTN_NUM_THREADS_VAR_3_2:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_3_2:.*]])
+  // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_3_2]], i32 %[[GTN_NUM_THREADS_VAR_3_2]], i32 8)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_3_2:.*]] to {{.*}}
+  omp.parallel num_threads(%1 : !llvm.i32) {
+    omp.barrier
+    omp.terminator
+  }
+
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_3_2]]
+  // CHECK: call void @__kmpc_barrier
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_3_1]]
+  // CHECK: call void @__kmpc_barrier
+
+// CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]])
+llvm.func @test_omp_parallel_if_1(%arg0: !llvm.i32) -> () {
+
+// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0
+  %0 = llvm.mlir.constant(0 : index) : !llvm.i32
+  %1 = llvm.icmp "slt" %arg0, %0 : !llvm.i32
+
+// CHECK: %[[GTN_IF_1:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[SI_VAR_IF_1:.*]])
+// CHECK: br i1 %[[IF_COND_VAR_1]], label %[[IF_COND_TRUE_BLOCK_1:.*]], label %[[IF_COND_FALSE_BLOCK_1:.*]]
+// CHECK: [[IF_COND_TRUE_BLOCK_1]]:
+// CHECK: br label %[[OUTLINED_CALL_IF_BLOCK_1:.*]]
+// CHECK: [[OUTLINED_CALL_IF_BLOCK_1]]:
+// CHECK: call void {{.*}} @__kmpc_fork_call(%struct.ident_t* @[[SI_VAR_IF_1]], {{.*}} @[[OMP_OUTLINED_FN_IF_1:.*]] to void
+// CHECK: br label %[[OUTLINED_EXIT_IF_1:.*]]
+// CHECK: [[OUTLINED_EXIT_IF_1]]:
+// CHECK: br label %[[OUTLINED_EXIT_IF_2:.*]]
+// CHECK: [[OUTLINED_EXIT_IF_2]]:
+// CHECK: br label %[[RETURN_BLOCK_IF_1:.*]]
+// CHECK: [[IF_COND_FALSE_BLOCK_1]]:
+// CHECK: call void @__kmpc_serialized_parallel(%struct.ident_t* @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]])
+// CHECK: call void @[[OMP_OUTLINED_FN_IF_1]]
+// CHECK: call void @__kmpc_end_serialized_parallel(%struct.ident_t* @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]])
+// CHECK: br label %[[RETURN_BLOCK_IF_1]]
+  omp.parallel if(%1 : !llvm.i1) {
+    omp.barrier
+    omp.terminator
+  }
+
+// CHECK: [[RETURN_BLOCK_IF_1]]:
+// CHECK: ret void
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_IF_1]]
+  // CHECK: call void @__kmpc_barrier