diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -188,7 +188,20 @@ let regions = (region SizedRegion<1>:$region); - let hasCustomAssemblyFormat = 1; + let assemblyFormat = [{ + oilist( `reduction` `(` + custom( + $reduction_vars, type($reduction_vars), $reductions + ) `)` + | `allocate` `(` + custom( + $allocate_vars, type($allocate_vars), + $allocators_vars, type($allocators_vars) + ) `)` + | `nowait` + ) $region attr-dict + }]; + let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -304,35 +304,32 @@ /// reduction-entry-list ::= reduction-entry /// | reduction-entry-list `,` reduction-entry /// reduction-entry ::= symbol-ref `->` ssa-id `:` type -static ParseResult -parseReductionVarList(OpAsmParser &parser, - SmallVectorImpl &symbols, - SmallVectorImpl &operands, - SmallVectorImpl &types) { - if (failed(parser.parseLParen())) - return failure(); - +static ParseResult parseReductionVarList( + OpAsmParser &parser, SmallVectorImpl &operands, + SmallVectorImpl &types, ArrayAttr &reductionSymbols) { + SmallVector reductionVec; do { - if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || - parser.parseOperand(operands.emplace_back()) || + if (parser.parseAttribute(reductionVec.emplace_back()) || + parser.parseArrow() || parser.parseOperand(operands.emplace_back()) || parser.parseColonType(types.emplace_back())) return failure(); } while (succeeded(parser.parseOptionalComma())); - return parser.parseRParen(); + SmallVector reductions(reductionVec.begin(), reductionVec.end()); + reductionSymbols = ArrayAttr::get(parser.getContext(), reductions); + return success(); } -/// Print Reduction clause -static void 
printReductionVarList(OpAsmPrinter &p, - Optional reductions, - OperandRange reductionVars) { - p << "reduction("; +/// Print reduction entries; the caller emits the `reduction(`...`)` wrapper +static void printReductionVarList(OpAsmPrinter &p, Operation *op, + OperandRange reductionVars, + TypeRange reductionTypes, + Optional reductions) { for (unsigned i = 0, e = reductions->size(); i < e; ++i) { if (i != 0) p << ", "; p << (*reductions)[i] << " -> " << reductionVars[i] << " : " - << reductionVars[i].getType(); + << reductionTypes[i]; } - p << ") "; } /// Verifies Reduction Clause @@ -552,7 +549,7 @@ SmallVector allocates, allocators; SmallVector allocateTypes, allocatorTypes; - SmallVector reductionSymbols; + ArrayAttr reductions; SmallVector reductionVars; SmallVector reductionVarTypes; @@ -639,9 +636,10 @@ "proc_bind_val", "proc bind")) return failure(); } else if (clauseKeyword == "reduction") { - if (checkAllowed(reductionClause) || - parseReductionVarList(parser, reductionSymbols, reductionVars, - reductionVarTypes)) + if (checkAllowed(reductionClause) || parser.parseLParen() || + parseReductionVarList(parser, reductionVars, reductionVarTypes, + reductions) || + parser.parseRParen()) return failure(); clauseSegments[pos[reductionClause]] = reductionVars.size(); } else if (clauseKeyword == "nowait") { @@ -746,11 +744,7 @@ if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, parser.getNameLoc(), result.operands))) return failure(); - - SmallVector reductions(reductionSymbols.begin(), - reductionSymbols.end()); - result.addAttribute("reductions", - parser.getBuilder().getArrayAttr(reductions)); + result.addAttribute("reductions", reductions); } // Add linear parameters @@ -805,53 +799,9 @@ } //===----------------------------------------------------------------------===// -// Parser, printer and verifier for SectionsOp +// Verifier for SectionsOp //===----------------------------------------------------------------------===// -/// Parses an OpenMP Sections operation 
-/// -/// sections ::= `omp.sections` clause-list -/// clause-list ::= clause 
clause-list | empty -/// clause ::= reduction | allocate | nowait -ParseResult SectionsOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector clauses = {reductionClause, allocateClause, - nowaitClause}; - - SmallVector segments; - - if (failed(parseClauses(parser, result, clauses, segments))) - return failure(); - - result.addAttribute("operand_segment_sizes", - parser.getBuilder().getI32VectorAttr(segments)); - - // Now parse the body. - Region *body = result.addRegion(); - if (parser.parseRegion(*body)) - return failure(); - return success(); -} - -void SectionsOp::print(OpAsmPrinter &p) { - p << " "; - - if (!reduction_vars().empty()) - printReductionVarList(p, reductions(), reduction_vars()); - - if (!allocate_vars().empty()) { - printAllocateAndAllocator(p << "allocate(", *this, allocate_vars(), - allocate_vars().getTypes(), allocators_vars(), - allocators_vars().getTypes()); - p << ")"; - } - - if (nowait()) - p << "nowait"; - - p << ' '; - p.printRegion(region()); -} - LogicalResult SectionsOp::verify() { if (allocate_vars().size() != allocators_vars().size()) return emitError( @@ -960,8 +910,11 @@ if (auto order = order_val()) p << "order(" << stringifyClauseOrderKind(*order) << ") "; - if (!reduction_vars().empty()) - printReductionVarList(p, reductions(), reduction_vars()); + if (!reduction_vars().empty()) { + printReductionVarList(p << "reduction(", *this, reduction_vars(), + reduction_vars().getTypes(), reductions()); + p << ")"; + } p << ' '; p.printRegion(region(), /*printEntryBlockArgs=*/false); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -793,7 +793,7 @@ // ----- func @omp_sections(%cond : i1) { - // expected-error @below {{if is not a valid clause for the omp.sections operation}} + // expected-error @below {{expected '{' to begin a region}} omp.sections if(%cond) { omp.terminator } @@ -803,7 
+803,7 @@ // ----- func @omp_sections() { - // expected-error @below {{num_threads is not a valid clause for the omp.sections operation}} + // expected-error @below {{expected '{' to begin a region}} omp.sections num_threads(10) { omp.terminator } @@ -813,7 +813,7 @@ // ----- func @omp_sections() { - // expected-error @below {{proc_bind is not a valid clause for the omp.sections operation}} + // expected-error @below {{expected '{' to begin a region}} omp.sections proc_bind(close) { omp.terminator } @@ -823,7 +823,7 @@ // ----- func @omp_sections(%data_var : memref, %linear_var : i32) { - // expected-error @below {{linear is not a valid clause for the omp.sections operation}} + // expected-error @below {{expected '{' to begin a region}} omp.sections linear(%data_var = %linear_var : memref) { omp.terminator } @@ -833,7 +833,7 @@ // ----- func @omp_sections() { - // expected-error @below {{schedule is not a valid clause for the omp.sections operation}} + // expected-error @below {{expected '{' to begin a region}} omp.sections schedule(static, none) { omp.terminator } @@ -843,7 +843,7 @@ // ----- func @omp_sections() { - // expected-error @below {{collapse is not a valid clause for the omp.sections operation}} + // expected-error @below {{expected '{' to begin a region}} omp.sections collapse(3) { omp.terminator } @@ -853,7 +853,7 @@ // ----- func @omp_sections() { - // expected-error @below {{ordered is not a valid clause for the omp.sections operation}} + // expected-error @below {{expected '{' to begin a region}} omp.sections ordered(2) { omp.terminator } @@ -863,7 +863,7 @@ // ----- func @omp_sections() { - // expected-error @below {{order is not a valid clause for the omp.sections operation}} + // expected-error @below {{expected '{' to begin a region}} omp.sections order(concurrent) { omp.terminator } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -624,13 +624,13 @@ "omp.sections" (%data_var1, %data_var1) ({ // CHECK: omp.terminator omp.terminator - }) {operand_segment_sizes = dense<[0,1,1]> : vector<3xi32>} : (memref, memref) -> () + }) {allocate, operand_segment_sizes = dense<[0,1,1]> : vector<3xi32>} : (memref, memref) -> () // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr) "omp.sections" (%redn_var) ({ // CHECK: omp.terminator omp.terminator - }) {operand_segment_sizes = dense<[1,0,0]> : vector<3xi32>, reductions=[@add_f32]} : (!llvm.ptr) -> () + }) {reduction, operand_segment_sizes = dense<[1,0,0]> : vector<3xi32>, reductions=[@add_f32]} : (!llvm.ptr) -> () // CHECK: omp.sections nowait { omp.sections nowait {