diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -400,16 +400,21 @@ The optional $nowait elliminates the implicit barrier so the parent task can make progress even if the target task is not yet completed. - TODO: private, map, is_device_ptr, firstprivate, depend, defaultmap, in_reduction + TODO: map, is_device_ptr, depend, defaultmap, in_reduction }]; let arguments = (ins Optional:$if_expr, Optional:$device, Optional:$thread_limit, + Variadic:$private_vars, + Variadic:$firstprivate_vars, UnitAttr:$nowait); let regions = (region AnyRegion:$region); + + let parser = [{ return parseTargetOp(parser, result); }]; + let printer = [{ return printTargetOp(p, *this); }]; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -192,6 +192,27 @@ p.printRegion(op.getRegion()); } +static void printTargetOp(OpAsmPrinter &p, TargetOp op) { + p << " "; // Emit a leading space before the clause list. + if (auto ifCond = op.if_expr()) + p << "if(" << ifCond << " : " << ifCond.getType() << ") "; + + if (auto device = op.device()) + p << "device(" << device << " : " << device.getType() << ") "; + + if (auto threads = op.thread_limit()) + p << "thread_limit(" << threads << " : " << threads.getType() << ") "; + + printDataVars(p, op.private_vars(), "private"); + printDataVars(p, op.firstprivate_vars(), "firstprivate"); + + if (op.nowait()) { // nowait is a UnitAttr with no operand: print only the keyword. + p << "nowait "; + } + + p.printRegion(op.getRegion()); +} + //===----------------------------------------------------------------------===// // Parser and printer for Linear Clause //===----------------------------------------------------------------------===// @@ -521,6 +542,8 @@ enum ClauseType { ifClause, numThreadsClause, + deviceClause, + threadLimitClause, privateClause, 
firstprivateClause, lastprivateClause, @@ -591,6 +614,8 @@ // Containers for storing operands, types and attributes for various clauses std::pair ifCond; std::pair numThreads; + std::pair device; + std::pair threadLimit; SmallVector privates, firstprivates, lastprivates, shareds, copyins; @@ -661,6 +686,18 @@ parser.parseColonType(numThreads.second) || parser.parseRParen()) return failure(); clauseSegments[pos[numThreadsClause]] = 1; + } else if (clauseKeyword == "device") { + if (checkAllowed(deviceClause) || parser.parseLParen() || + parser.parseOperand(device.first) || + parser.parseColonType(device.second) || parser.parseRParen()) + return failure(); + clauseSegments[pos[deviceClause]] = 1; + } else if (clauseKeyword == "thread_limit") { + if (checkAllowed(threadLimitClause) || parser.parseLParen() || + parser.parseOperand(threadLimit.first) || + parser.parseColonType(threadLimit.second) || parser.parseRParen()) + return failure(); + clauseSegments[pos[threadLimitClause]] = 1; } else if (clauseKeyword == "private") { if (checkAllowed(privateClause) || parseOperandAndTypeList(parser, privates, privateTypes)) @@ -791,6 +828,18 @@ result.operands))) return failure(); + // Add device parameter. + if (done[deviceClause] && clauseSegments[pos[deviceClause]] && + failed( + parser.resolveOperand(device.first, device.second, result.operands))) + return failure(); + + // Add thread_limit parameter. + if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] && + failed(parser.resolveOperand(threadLimit.first, threadLimit.second, + result.operands))) + return failure(); + // Add private parameters. if (done[privateClause] && clauseSegments[pos[privateClause]] && failed(parser.resolveOperands(privates, privateTypes, @@ -915,6 +964,33 @@ return success(); } +/// Parses a target operation.
+/// +/// operation ::= `omp.target` clause-list region +/// clause-list ::= clause | clause clause-list +/// clause ::= if | device | thread_limit | private | firstprivate | nowait +/// +static ParseResult parseTargetOp(OpAsmParser &parser, OperationState &result) { + SmallVector clauses = {ifClause, deviceClause, + threadLimitClause, privateClause, + firstprivateClause, nowaitClause}; + + SmallVector segments; + + if (failed(parseClauses(parser, result, clauses, segments))) + return failure(); + + result.addAttribute("operand_segment_sizes", + parser.getBuilder().getI32VectorAttr(segments)); + + Region *body = result.addRegion(); + SmallVector regionArgs; + SmallVector regionArgTypes; + if (parser.parseRegion(*body, regionArgs, regionArgTypes)) + return failure(); + return success(); +} + //===----------------------------------------------------------------------===// // Parser, printer and verifier for SectionsOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -329,14 +329,14 @@ } // CHECK-LABEL: omp_target -func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> () { +func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32, %data_var : memref) -> () { - // Test with optional operands; if_expr, device, thread_limit, and nowait. - // CHECK: omp.target - "omp.target"(%if_cond, %device, %num_threads) ({ + // Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
+ // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) private({{.*}}) firstprivate({{.*}}) nowait + "omp.target"(%if_cond, %device, %num_threads, %data_var, %data_var) ({ // CHECK: omp.terminator omp.terminator - }) {operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>, nowait } : ( i1, si32, si32 ) -> () + }) {operand_segment_sizes = dense<[1,1,1,1,1]>: vector<5xi32>, nowait } : ( i1, si32, si32, memref, memref ) -> () // CHECK: omp.barrier omp.barrier @@ -344,6 +344,21 @@ return } +// CHECK-LABEL: omp_target_pretty +func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : si32, %data_var : memref) -> () { + // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) firstprivate({{.*}}) + omp.target if(%if_cond : i1) device(%device : si32) thread_limit(%num_threads : si32) firstprivate(%data_var : memref) { + omp.terminator + } + + // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) private({{.*}}) nowait + omp.target if(%if_cond : i1) device(%device : si32) thread_limit(%num_threads : si32) private(%data_var : memref) nowait { + omp.terminator + } + + return +} + // CHECK: omp.reduction.declare // CHECK-LABEL: @add_f32 // CHECK: : f32