diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -418,6 +418,9 @@
                        UnitAttr:$nowait);
 
   let regions = (region AnyRegion:$region);
+
+  let parser = [{ return parseTargetOp(parser, result); }];
+  let printer = [{ return printTargetOp(p, *this); }];
 }
 
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -198,6 +198,24 @@
   p.printRegion(op.getRegion());
 }
 
+/// Prints omp.target: optional if/device/thread_limit/nowait clauses, region.
+static void printTargetOp(OpAsmPrinter &p, TargetOp op) {
+  p << " ";
+  if (auto ifCond = op.if_expr())
+    p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
+
+  if (auto device = op.device())
+    p << "device(" << device << " : " << device.getType() << ") ";
+
+  if (auto threads = op.thread_limit())
+    p << "thread_limit(" << threads << " : " << threads.getType() << ") ";
+
+  if (op.nowait())
+    p << "nowait ";
+
+  p.printRegion(op.getRegion());
+}
+
 //===----------------------------------------------------------------------===//
 // Parser and printer for Linear Clause
 //===----------------------------------------------------------------------===//
@@ -523,6 +541,8 @@
 enum ClauseType {
   ifClause,
   numThreadsClause,
+  deviceClause,
+  threadLimitClause,
   privateClause,
   firstprivateClause,
   lastprivateClause,
@@ -611,6 +631,8 @@
   // Containers for storing operands, types and attributes for various clauses
   std::pair<OpAsmParser::OperandType, Type> ifCond;
   std::pair<OpAsmParser::OperandType, Type> numThreads;
+  std::pair<OpAsmParser::OperandType, Type> device;
+  std::pair<OpAsmParser::OperandType, Type> threadLimit;
 
   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
       shareds, copyins;
@@ -681,6 +703,18 @@
           parser.parseColonType(numThreads.second) || parser.parseRParen())
         return failure();
       clauseSegments[pos[numThreadsClause]] = 1;
+    } else if (clauseKeyword == "device") {
+      if (checkAllowed(deviceClause) || parser.parseLParen() ||
+          parser.parseOperand(device.first) ||
+          parser.parseColonType(device.second) || parser.parseRParen())
+        return failure();
+      clauseSegments[pos[deviceClause]] = 1;
+    } else if (clauseKeyword == "thread_limit") {
+      if (checkAllowed(threadLimitClause) || parser.parseLParen() ||
+          parser.parseOperand(threadLimit.first) ||
+          parser.parseColonType(threadLimit.second) || parser.parseRParen())
+        return failure();
+      clauseSegments[pos[threadLimitClause]] = 1;
     } else if (clauseKeyword == "private") {
       if (checkAllowed(privateClause) ||
           parseOperandAndTypeList(parser, privates, privateTypes))
@@ -812,6 +846,18 @@
                                    result.operands)))
     return failure();
 
+  // Add device parameter.
+  if (done[deviceClause] && clauseSegments[pos[deviceClause]] &&
+      failed(
+          parser.resolveOperand(device.first, device.second, result.operands)))
+    return failure();
+
+  // Add thread_limit parameter.
+  if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] &&
+      failed(parser.resolveOperand(threadLimit.first, threadLimit.second,
+                                   result.operands)))
+    return failure();
+
   // Add private parameters.
   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
       failed(parser.resolveOperands(privates, privateTypes,
@@ -948,6 +994,33 @@
   return success();
 }
 
+/// Parses a target operation.
+///
+/// operation ::= `omp.target` clause-list region
+/// clause-list ::= /*empty*/ | clause clause-list
+/// clause ::= if | device | thread_limit | nowait
+///
+static ParseResult parseTargetOp(OpAsmParser &parser, OperationState &result) {
+  SmallVector<ClauseType> clauses = {ifClause, deviceClause, threadLimitClause,
+                                     nowaitClause};
+
+  SmallVector<int> segments;
+
+  if (failed(parseClauses(parser, result, clauses, segments)))
+    return failure();
+
+  result.addAttribute(
+      TargetOp::AttrSizedOperandSegments::getOperandSegmentSizeAttr(),
+      parser.getBuilder().getI32VectorAttr(segments));
+
+  Region *body = result.addRegion();
+  SmallVector<OpAsmParser::OperandType> regionArgs;
+  SmallVector<Type> regionArgTypes;
+  if (parser.parseRegion(*body, regionArgs, regionArgTypes))
+    return failure();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Parser, printer and verifier for SectionsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -330,12 +330,12 @@
 // CHECK-LABEL: omp_target
 func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
 
-  // Test with optional operands; if_expr, device, thread_limit, and nowait.
-  // CHECK: omp.target
+  // Generic form with all optional operands: if_expr, device, thread_limit, nowait.
+  // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
   "omp.target"(%if_cond, %device, %num_threads) ({
     // CHECK: omp.terminator
     omp.terminator
-  }) {operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>, nowait } : ( i1, si32, si32 ) -> () 
+  }) {operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>, nowait } : ( i1, si32, si32 ) -> ()
 
   // CHECK: omp.barrier
   omp.barrier
@@ -343,6 +343,21 @@
   return
 }
 
+// CHECK-LABEL: omp_target_pretty
+func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
+  // CHECK: omp.target if({{.*}}) device({{.*}})
+  omp.target if(%if_cond : i1) device(%device : si32) {
+    omp.terminator
+  }
+
+  // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
+  omp.target if(%if_cond : i1) device(%device : si32) thread_limit(%num_threads : si32) nowait {
+    omp.terminator
+  }
+
+  return
+}
+
 // CHECK: omp.reduction.declare
 // CHECK-LABEL: @add_f32
 // CHECK: : f32