diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -97,7 +97,17 @@
   let builders = [
     OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
   ];
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = [{
+    oilist( `if` `(` $if_expr_var `:` type($if_expr_var) `)`
+          | `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
+          | `allocate` `(`
+              custom<AllocateAndAllocator>(
+                $allocate_vars, type($allocate_vars),
+                $allocators_vars, type($allocators_vars)
+              ) `)`
+          | `proc_bind` `(` custom<ProcBindKind>($proc_bind_val) `)`
+    ) $region attr-dict
+  }];
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -89,35 +89,61 @@
     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
     SmallVectorImpl<Type> &typesAllocator) {
 
-  return parser.parseCommaSeparatedList(
-      OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
-        OpAsmParser::OperandType operand;
-        Type type;
-        if (parser.parseOperand(operand) || parser.parseColonType(type))
-          return failure();
-        operandsAllocator.push_back(operand);
-        typesAllocator.push_back(type);
-        if (parser.parseArrow())
-          return failure();
-        if (parser.parseOperand(operand) || parser.parseColonType(type))
-          return failure();
+  return parser.parseCommaSeparatedList([&]() -> ParseResult {
+    OpAsmParser::OperandType operand;
+    Type type;
+    if (parser.parseOperand(operand) || parser.parseColonType(type))
+      return failure();
+    operandsAllocator.push_back(operand);
+    typesAllocator.push_back(type);
+    if (parser.parseArrow())
+      return failure();
+    if (parser.parseOperand(operand) || parser.parseColonType(type))
+      return failure();
 
-        operandsAllocate.push_back(operand);
-        typesAllocate.push_back(type);
-        return success();
-      });
+    operandsAllocate.push_back(operand);
+    typesAllocate.push_back(type);
+    return success();
+  });
 }
 
-/// Print allocate clause
-static void printAllocateAndAllocator(OpAsmPrinter &p,
+/// Print the operand list of an allocate clause as comma-separated
+/// `allocator : type -> allocate : type` pairs. The `allocate` keyword and
+/// the surrounding parentheses are emitted by the caller.
+static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
                                       OperandRange varsAllocate,
-                                      OperandRange varsAllocator) {
-  p << "allocate(";
+                                      TypeRange typesAllocate,
+                                      OperandRange varsAllocator,
+                                      TypeRange typesAllocator) {
   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
-    std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
-    p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
-    p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
+    std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
+    p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
+    p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
+  }
+}
+
+/// Parse the kind of a `proc_bind` clause (e.g. `close` or `spread`) into
+/// `procBindAttr`, emitting an error if the keyword is not a known kind.
+static ParseResult
+parseProcBindKind(OpAsmParser &parser,
+                  omp::ClauseProcBindKindAttr &procBindAttr) {
+  llvm::SMLoc loc = parser.getCurrentLocation();
+  StringRef procBindStr;
+  if (parser.parseKeyword(&procBindStr))
+    return failure();
+  if (auto procBindVal = symbolizeClauseProcBindKind(procBindStr)) {
+    procBindAttr =
+        ClauseProcBindKindAttr::get(parser.getContext(), *procBindVal);
+    return success();
   }
+  return parser.emitError(loc, "invalid proc_bind kind: '")
+         << procBindStr << "'";
+}
+
+/// Print the kind of a `proc_bind` clause (the keyword between the parens).
+static void printProcBindKind(OpAsmPrinter &p, Operation *op,
+                              omp::ClauseProcBindKindAttr procBindAttr) {
+  p << stringifyClauseProcBindKind(procBindAttr.getValue());
 }
 
 LogicalResult ParallelOp::verify() {
@@ -127,24 +153,6 @@
   return success();
 }
 
-void ParallelOp::print(OpAsmPrinter &p) {
-  p << " ";
-  if (auto ifCond = if_expr_var())
-    p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
-
-  if (auto threads = num_threads_var())
-    p << "num_threads(" << threads << " : " << threads.getType() << ") ";
-
-  if (!allocate_vars().empty())
-    printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
-
-  if (auto bind = proc_bind_val())
-    p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") ";
-
-  p << ' ';
-  p.printRegion(getRegion());
-}
-
 //===----------------------------------------------------------------------===//
 // Parser and printer for Linear Clause
 //===----------------------------------------------------------------------===//
@@ -626,9 +634,10 @@
         return failure();
       clauseSegments[pos[threadLimitClause]] = 1;
     } else if (clauseKeyword == "allocate") {
-      if (checkAllowed(allocateClause) ||
+      if (checkAllowed(allocateClause) || parser.parseLParen() ||
           parseAllocateAndAllocator(parser, allocates, allocateTypes,
-                                    allocators, allocatorTypes))
+                                    allocators, allocatorTypes) ||
+          parser.parseRParen())
         return failure();
       clauseSegments[pos[allocateClause]] = allocates.size();
       clauseSegments[pos[allocateClause] + 1] = allocators.size();
@@ -803,32 +812,6 @@
   return success();
 }
 
-/// Parses a parallel operation.
-///
-/// operation ::= `omp.parallel` clause-list
-/// clause-list ::= clause | clause clause-list
-/// clause ::= if | num-threads | allocate | proc-bind
-///
-ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<ClauseType> clauses = {ifClause, numThreadsClause, allocateClause,
-                                     procBindClause};
-
-  SmallVector<int> segments;
-
-  if (failed(parseClauses(parser, result, clauses, segments)))
-    return failure();
-
-  result.addAttribute("operand_segment_sizes",
-                      parser.getBuilder().getI32VectorAttr(segments));
-
-  Region *body = result.addRegion();
-  SmallVector<OpAsmParser::OperandType> regionArgs;
-  SmallVector<Type> regionArgTypes;
-  if (parser.parseRegion(*body, regionArgs, regionArgTypes))
-    return failure();
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Parser, printer and verifier for SectionsOp
 //===----------------------------------------------------------------------===//
@@ -863,8 +846,16 @@
   if (!reduction_vars().empty())
     printReductionVarList(p, reductions(), reduction_vars());
 
-  if (!allocate_vars().empty())
-    printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
+  if (!allocate_vars().empty()) {
+    // printAllocateAndAllocator prints only the operand list; the clause
+    // keyword and parentheses are printed here. The trailing space keeps the
+    // next clause (e.g. `nowait`) from being glued to the `)`.
+    p << "allocate(";
+    printAllocateAndAllocator(p, *this, allocate_vars(),
+                              allocate_vars().getTypes(), allocators_vars(),
+                              allocators_vars().getTypes());
+    p << ") ";
+  }
 
   if (nowait())
     p << "nowait";
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
 
 func @unknown_clause() {
-  // expected-error@+1 {{invalid is not a valid clause}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel invalid {
   }
 
@@ -11,7 +11,7 @@
 // -----
 
 func @if_once(%n : i1) {
-  // expected-error@+1 {{at most one if clause can appear on the omp.parallel operation}}
+  // expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}}
   omp.parallel if(%n : i1) if(%n : i1) {
   }
 
@@ -21,7 +21,7 @@
 // -----
 
 func @num_threads_once(%n : si32) {
-  // expected-error@+1 {{at most one num_threads clause can appear on the omp.parallel operation}}
+  // expected-error@+1 {{`num_threads` clause can appear at most once in the expansion of the oilist directive}}
   omp.parallel num_threads(%n : si32) num_threads(%n : si32) {
   }
 
@@ -31,7 +31,7 @@
 // -----
 
 func @nowait_not_allowed(%n : memref<i32>) {
-  // expected-error@+1 {{nowait is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel nowait {}
   return
 }
@@ -39,7 +39,7 @@
 // -----
 
 func @linear_not_allowed(%data_var : memref<i32>, %linear_var : i32) {
-  // expected-error@+1 {{linear is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel linear(%data_var = %linear_var : memref<i32>) {}
   return
 }
@@ -47,7 +47,7 @@
 // -----
 
 func @schedule_not_allowed() {
-  // expected-error@+1 {{schedule is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel schedule(static) {}
   return
 }
@@ -55,7 +55,7 @@
 // -----
 
 func @collapse_not_allowed() {
-  // expected-error@+1 {{collapse is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel collapse(3) {}
   return
 }
@@ -63,7 +63,7 @@
 // -----
 
 func @order_not_allowed() {
-  // expected-error@+1 {{order is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel order(concurrent) {}
   return
 }
@@ -71,14 +71,23 @@
 // -----
 
 func @ordered_not_allowed() {
-  // expected-error@+1 {{ordered is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel ordered(2) {}
 }
 
+// -----
+
+func @proc_bind_invalid_kind() {
+  // expected-error@+1 {{invalid proc_bind kind: 'unknown'}}
+  omp.parallel proc_bind(unknown) {
+  }
+  return
+}
+
 // -----
 
 func @proc_bind_once() {
-  // expected-error@+1 {{at most one proc_bind clause can appear on the omp.parallel operation}}
+  // expected-error@+1 {{`proc_bind` clause can appear at most once in the expansion of the oilist directive}}
   omp.parallel proc_bind(close) proc_bind(spread) {
   }
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -59,7 +59,7 @@
   // CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
     "omp.parallel"(%num_threads, %data_var, %data_var) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
+    }) {num_threads, allocate, operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -68,22 +68,22 @@
   // CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
     "omp.parallel"(%if_cond, %data_var, %data_var) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
+    }) {if, allocate, operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
+    }) {if, num_threads, operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
 
     omp.terminator
-  }) {operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
+  }) {if, num_threads, allocate, operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
+  }) {allocate, operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
 
   return
 }