diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -134,6 +134,67 @@ let assemblyFormat = "attr-dict"; } +//===----------------------------------------------------------------------===// +// 2.8.1 Sections Construct +//===----------------------------------------------------------------------===// + +def SectionOp : OpenMP_Op<"section"> { + let summary = "section directive"; + let description = [{ + A section operation encloses a region which represents one section in a + sections construct. A section op should always be surrounded by an + `omp.sections` operation. + }]; + let regions = (region AnyRegion:$region); + let assemblyFormat = "$region attr-dict"; +} + +def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> { + let summary = "sections construct"; + let description = [{ + The sections construct is a non-iterative worksharing construct that + contains `omp.section` operations. The `omp.section` operations are to be + distributed among and executed by the threads in a team. Each `omp.section` + is executed once by one of the threads in the team in the context of its + implicit task. + + `private_vars`, `firstprivate_vars` and `lastprivate_vars` arguments are + variadic lists of operands that specify the data sharing attributes of the + list of values. They are optional. + + Reductions can be performed in a sections construct by specifying reduction + accumulator variables in `reduction_vars` and symbols referring to reduction + declarations in the `reductions` attribute. Each reduction is identified + by the accumulator it uses and accumulators must not be repeated in the same + reduction. The `omp.reduction` operation accepts the accumulator and a + partial value which is considered to be produced by the section for the + given reduction.
If multiple values are produced for the same accumulator, + i.e. there are multiple `omp.reduction`s, the last value is taken. The + reduction declaration specifies how to combine the values from each section + into the final value, which is available in the accumulator after all the + sections complete. + + The $allocators_vars and $allocate_vars parameters are a variadic list of values + that specify the memory allocator to be used to obtain storage for private values. + + The `nowait` attribute, when present, signifies that there should be no + implicit barrier at the end of the construct. + }]; + let arguments = (ins Variadic:$private_vars, + Variadic:$firstprivate_vars, + Variadic:$lastprivate_vars, + Variadic:$reduction_vars, + OptionalAttr:$reductions, + Variadic:$allocate_vars, + Variadic:$allocators_vars, + UnitAttr:$nowait); + + let parser = [{ return parseSectionsOp(parser, result); }]; + let printer = [{ return printSectionsOp(p, *this); }]; + let verifier = [{ return verifySectionsOp(*this); }]; + let regions = (region AnyRegion:$region); +} + //===----------------------------------------------------------------------===// // 2.9.2 Workshare Loop Construct //===----------------------------------------------------------------------===// @@ -209,8 +270,7 @@ Variadic:$linear_vars, Variadic:$linear_step_vars, Variadic:$reduction_vars, - OptionalAttr>:$reductions, + OptionalAttr:$reductions, OptionalAttr:$schedule_val, Optional:$schedule_chunk_var, Confined, [IntMinValue<0>]>:$collapse_val, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -53,20 +54,9 @@ } 
//===----------------------------------------------------------------------===// -// ParallelOp +// Parser and printer for Operand and type list //===----------------------------------------------------------------------===// -void ParallelOp::build(OpBuilder &builder, OperationState &state, - ArrayRef attributes) { - ParallelOp::build( - builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, - /*default_val=*/nullptr, /*private_vars=*/ValueRange(), - /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), - /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), - /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); - state.addAttributes(attributes); -} - /// Parse a list of operands with types. /// /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` @@ -89,9 +79,20 @@ }); } +static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) { + p << "("; + llvm::interleaveComma( + operands, p, [&](const Value &v) { p << v << " : " << v.getType(); }); + p << ")"; +} + +//===----------------------------------------------------------------------===// +// Parser and printer for Allocate Clause +//===----------------------------------------------------------------------===// + /// Parse an allocate clause with allocators and a list of operands with types. 
/// -/// operand-and-type-list ::= `(` allocate-operand-list `)` +/// allocate ::= `allocate` `(` allocate-operand-list `)` /// allocate-operand-list :: = allocate-operand | /// allocator-operand `,` allocate-operand-list /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type @@ -122,246 +123,589 @@ }); } -static LogicalResult verifyParallelOp(ParallelOp op) { - if (op.allocate_vars().size() != op.allocators_vars().size()) - return op.emitError( - "expected equal sizes for allocate and allocator variables"); +static void printAllocateAndAllocator(OpAsmPrinter &p, + OperandRange varsAllocate, + OperandRange varsAllocator) { + if (varsAllocate.empty()) + return; + + p << " allocate("; + for (unsigned i = 0; i < varsAllocate.size(); ++i) { + std::string separator = i == varsAllocate.size() - 1 ? ")" : ", "; + p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; + p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; + } +} + +//===----------------------------------------------------------------------===// +// Parser and printer for Linear Clause +//===----------------------------------------------------------------------===// + +/// linear ::= `linear` `(` linear-list `)` +/// linear-list := linear-val | linear-val linear-list +/// linear-val := ssa-id-and-type `=` ssa-id-and-type +static ParseResult +parseLinearClause(OpAsmParser &parser, + SmallVectorImpl &vars, + SmallVectorImpl &types, + SmallVectorImpl &stepVars) { + if (parser.parseLParen()) + return failure(); + + do { + OpAsmParser::OperandType var; + Type type; + OpAsmParser::OperandType stepVar; + if (parser.parseOperand(var) || parser.parseEqual() || + parser.parseOperand(stepVar) || parser.parseColonType(type)) + return failure(); + + vars.push_back(var); + types.push_back(type); + stepVars.push_back(stepVar); + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRParen()) + return failure(); + return success(); } -static void 
printParallelOp(OpAsmPrinter &p, ParallelOp op) { - if (auto ifCond = op.if_expr_var()) - p << " if(" << ifCond << " : " << ifCond.getType() << ")"; +static void printLinearClause(OpAsmPrinter &p, Operation *op, + OperandRange linearVars, + OperandRange linearStepVars) { + auto linearVarsSize = linearVars.size(); + p << "("; + for (unsigned i = 0; i < linearVarsSize; ++i) { + std::string separator = i == linearVarsSize - 1 ? ")" : ", "; + p << linearVars[i]; + if (linearStepVars.size() > i) + p << " = " << linearStepVars[i]; + p << " : " << linearVars[i].getType() << separator; + } +} - if (auto threads = op.num_threads_var()) - p << " num_threads(" << threads << " : " << threads.getType() << ")"; +//===----------------------------------------------------------------------===// +// Parser and printer for Schedule Clause +//===----------------------------------------------------------------------===// - // Print private, firstprivate, shared and copyin parameters - auto printDataVars = [&p](StringRef name, OperandRange vars) { - if (vars.size()) { - p << " " << name << "("; - for (unsigned i = 0; i < vars.size(); ++i) { - std::string separator = i == vars.size() - 1 ? ")" : ", "; - p << vars[i] << " : " << vars[i].getType() << separator; - } - } - }; +/// schedule ::= `schedule` `(` sched-list `)` +/// sched-list ::= sched-val | sched-val sched-list +/// sched-val ::= sched-with-chunk | sched-wo-chunk +/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 
+/// sched-with-chunk-types ::= `static` | `dynamic` | `guided` +/// sched-wo-chunk ::= `auto` | `runtime` +static ParseResult +parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, + Optional &chunkSize) { + if (parser.parseLParen()) + return failure(); - // Print allocator and allocate parameters - auto printAllocateAndAllocator = [&p](OperandRange varsAllocate, - OperandRange varsAllocator) { - if (varsAllocate.empty()) - return; + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); - p << " allocate("; - for (unsigned i = 0; i < varsAllocate.size(); ++i) { - std::string separator = i == varsAllocate.size() - 1 ? ")" : ", "; - p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; - p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; + schedule = keyword; + if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { + if (succeeded(parser.parseOptionalEqual())) { + chunkSize = OpAsmParser::OperandType{}; + if (parser.parseOperand(*chunkSize)) + return failure(); + } else { + chunkSize = llvm::NoneType::None; } - }; + } else if (keyword == "auto" || keyword == "runtime") { + chunkSize = llvm::NoneType::None; + } else { + return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; + } - printDataVars("private", op.private_vars()); - printDataVars("firstprivate", op.firstprivate_vars()); - printDataVars("shared", op.shared_vars()); - printDataVars("copyin", op.copyin_vars()); - printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars()); + if (parser.parseRParen()) + return failure(); - if (auto def = op.default_val()) - p << " default(" << def->drop_front(3) << ")"; + return success(); +} - if (auto bind = op.proc_bind_val()) - p << " proc_bind(" << bind << ")"; +static void printScheduleClause(OpAsmPrinter &p, Operation *op, + StringRef &sched, Value scheduleChunkVar) { + auto schedLower = sched.lower(); + p << "(" << schedLower; + if (scheduleChunkVar) 
{ + p << " = " << scheduleChunkVar; + } + p << ")"; +} - p.printRegion(op.getRegion()); +//===----------------------------------------------------------------------===// +// Parser, printer and verifier for ReductionVarList +//===----------------------------------------------------------------------===// + +/// reduction ::= `reduction` `(` reduction-entry-list `)` +/// reduction-entry-list ::= reduction-entry +/// | reduction-entry-list `,` reduction-entry +/// reduction-entry ::= symbol-ref `->` ssa-id `:` type +static ParseResult +parseReductionVarList(OpAsmParser &parser, Optional &symbols, + SmallVectorImpl &operands, + SmallVectorImpl &types) { + + mlir::SmallVector symbolsVec; + if (failed(parser.parseLParen())) + return failure(); + + do { + if (parser.parseAttribute(symbolsVec.emplace_back()) || + parser.parseArrow() || parser.parseOperand(operands.emplace_back()) || + parser.parseColonType(types.emplace_back())) + return failure(); + } while (succeeded(parser.parseOptionalComma())); + + if (failed(parser.parseRParen())) { + return failure(); + } + + symbols = parser.getBuilder().getArrayAttr(symbolsVec); + return success(); } -/// Emit an error if the same clause is present more than once on an operation. 
-static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause, - StringRef operation) { - return parser.emitError(parser.getNameLoc()) - << " at most one " << clause << " clause can appear on the " - << operation << " operation"; +static void printReductionVarList(OpAsmPrinter &p, + Optional reductions, + OperandRange reduction_vars) { + if (reduction_vars.empty()) + return; + assert(reductions.hasValue() && reductions->size() == reduction_vars.size() && + "size mismatch: reduction symbols and reduction vars"); + p << " reduction("; + for (unsigned i = 0, e = reductions->size(); i < e; ++i) { + if (i != 0) + p << ", "; + p << (*reductions)[i] << " -> " << reduction_vars[i] << " : " + << reduction_vars[i].getType(); + } + p << ")"; } -/// Parses a parallel operation. -/// -/// operation ::= `omp.parallel` clause-list -/// clause-list ::= clause | clause clause-list -/// clause ::= if | numThreads | private | firstprivate | shared | copyin | -/// default | procBind -/// if ::= `if` `(` ssa-id `)` -/// numThreads ::= `num_threads` `(` ssa-id-and-type `)` +static LogicalResult verifyReductionVarList(Operation *op, + Optional reductions, + OperandRange reduction_vars) { + if (reduction_vars.size() != 0) { + if (!reductions || reductions->size() != reduction_vars.size()) { + return op->emitOpError() + << "expected as many reduction symbol references " + "as reduction variables"; + } + } else { + if (reductions) + return op->emitOpError() << "unexpected reduction symbol references"; + return success(); + } + + DenseSet accumulators; + for (auto args : llvm::zip(reduction_vars, *reductions)) { + Value accum = std::get<0>(args); + if (!accumulators.insert(accum).second) { + return op->emitOpError() << "accumulator variable used more than once"; + } + Type varType = accum.getType().cast(); + auto symbolRef = std::get<1>(args).cast(); + auto decl = + SymbolTable::lookupNearestSymbolFrom(op, symbolRef); + if (!decl) { + return op->emitOpError() << "expected symbol 
reference " << symbolRef + << " to point to a reduction declaration"; + } + + if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) { + return op->emitOpError() + << "expected accumulator (" << varType + << ") to be the same type as reduction declaration (" + << decl.getAccumulatorType() << ")"; + } + } + + return success(); +} + +enum ClauseType { + ifClause, + numThreadsClause, + privateClause, + firstprivateClause, + lastprivateClause, + sharedClause, + copyinClause, + allocateClause, + defaultClause, + procBindClause, + reductionClause, + nowaitClause, + linearClause, + scheduleClause, + collapseClause, + orderClause, + orderedClause, + inclusiveClause, + COUNT +}; + +//===----------------------------------------------------------------------===// +// Parser for Clause List +//===----------------------------------------------------------------------===// + +/// Parse a list of clauses. The clauses can appear in any order, but their +/// operand segment indices are in the same order that they are passed in the +/// `clauses` list. 
The operand segment sizes are appended to `segments`. + +/// clause-list ::= clause clause-list | empty +/// clause ::= if | num-threads | private | firstprivate | lastprivate | +/// shared | copyin | allocate | default | proc-bind | reduction | +/// nowait | linear | schedule | collapse | order | ordered | +/// inclusive +/// if ::= `if` `(` ssa-id-and-type `)` +/// num-threads ::= `num_threads` `(` ssa-id-and-type `)` /// private ::= `private` operand-and-type-list /// firstprivate ::= `firstprivate` operand-and-type-list +/// lastprivate ::= `lastprivate` operand-and-type-list /// shared ::= `shared` operand-and-type-list /// copyin ::= `copyin` operand-and-type-list -/// allocate ::= `allocate` operand-and-type `->` operand-and-type-list +/// allocate ::= `allocate` `(` allocate-operand-list `)` /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) -/// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` +/// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` +/// reduction ::= `reduction` `(` reduction-entry-list `)` +/// nowait ::= `nowait` +/// linear ::= `linear` `(` linear-list `)` +/// schedule ::= `schedule` `(` sched-list `)` +/// collapse ::= `collapse` `(` integer-literal `)` +/// order ::= `order` `(` `concurrent` `)` +/// ordered ::= `ordered` (`(` integer-literal `)`)? +/// inclusive ::= `inclusive` /// /// Note that each clause can only appear once in the clase-list.
-static ParseResult parseParallelOp(OpAsmParser &parser, - OperationState &result) { +static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, + SmallVectorImpl &clauses, + SmallVectorImpl &segments) { + + // Check done[clause] to see if it has been parsed already + llvm::BitVector done(ClauseType::COUNT, false); + + // See pos[clause] to get position of clause in operand segments + SmallVector pos(ClauseType::COUNT, -1); + + // Stores the last parsed clause keyword + StringRef clauseKeyword; + StringRef opName = result.name.getStringRef(); + + // Containers for storing operands, types and attributes for various clauses std::pair ifCond; std::pair numThreads; - SmallVector privates; - SmallVector privateTypes; - SmallVector firstprivates; - SmallVector firstprivateTypes; - SmallVector shareds; - SmallVector sharedTypes; - SmallVector copyins; - SmallVector copyinTypes; - SmallVector allocates; - SmallVector allocateTypes; - SmallVector allocators; - SmallVector allocatorTypes; - std::array segments{0, 0, 0, 0, 0, 0, 0, 0}; - StringRef keyword; - bool defaultVal = false; - bool procBind = false; - - const int ifClausePos = 0; - const int numThreadsClausePos = 1; - const int privateClausePos = 2; - const int firstprivateClausePos = 3; - const int sharedClausePos = 4; - const int copyinClausePos = 5; - const int allocateClausePos = 6; - const int allocatorPos = 7; - const StringRef opName = result.name.getStringRef(); - - while (succeeded(parser.parseOptionalKeyword(&keyword))) { - if (keyword == "if") { - // Fail if there was already another if condition. 
- if (segments[ifClausePos]) - return allowedOnce(parser, "if", opName); - if (parser.parseLParen() || parser.parseOperand(ifCond.first) || + + SmallVector privates, firstprivates, + lastprivates, shareds, copyins; + SmallVector privateTypes, firstprivateTypes, lastprivateTypes, + sharedTypes, copyinTypes; + + SmallVector allocates, allocators; + SmallVector allocateTypes, allocatorTypes; + + Optional reductionSymbols; + SmallVector reductionVars; + SmallVector reductionVarTypes; + + SmallVector linears; + SmallVector linearTypes; + SmallVector linearSteps; + + SmallString<8> schedule; + Optional scheduleChunkSize; + + // Compute the position of clauses in operand segments + int currPos = 0; + for (auto clause : clauses) { + + // Skip the following clauses - they do not take any position in operand + // segments + if (clause == defaultClause || clause == procBindClause || + clause == nowaitClause || clause == collapseClause || + clause == orderClause || clause == orderedClause || + clause == inclusiveClause) + continue; + + pos[clause] = currPos++; + + // For the following clauses, two positions are reserved in the operand + // segments + if (clause == allocateClause || clause == linearClause) + currPos++; + } + + SmallVector clauseSegments(currPos); + + // Helper function to check if a clause is allowed/repeated or not + auto checkAllowed = [&](ClauseType clause, + bool allowRepeat = false) -> LogicalResult { + if (!llvm::is_contained(clauses, clause)) + return parser.emitError(parser.getCurrentLocation()) + << clauseKeyword << " is not a valid clause for the " << opName + << " operation"; + if (done[clause] && !allowRepeat) + return parser.emitError(parser.getCurrentLocation()) + << "at most one " << clauseKeyword << " clause can appear on the " + << opName << " operation"; + done[clause] = true; + return success(); + }; + + while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { + if (clauseKeyword == "if") { + if (failed(checkAllowed(ifClause)) ||
parser.parseLParen() || + parser.parseOperand(ifCond.first) || parser.parseColonType(ifCond.second) || parser.parseRParen()) return failure(); - segments[ifClausePos] = 1; - } else if (keyword == "num_threads") { - // Fail if there was already another num_threads clause. - if (segments[numThreadsClausePos]) - return allowedOnce(parser, "num_threads", opName); - if (parser.parseLParen() || parser.parseOperand(numThreads.first) || + clauseSegments[pos[ifClause]] = 1; + } else if (clauseKeyword == "num_threads") { + if (failed(checkAllowed(numThreadsClause)) || parser.parseLParen() || + parser.parseOperand(numThreads.first) || parser.parseColonType(numThreads.second) || parser.parseRParen()) return failure(); - segments[numThreadsClausePos] = 1; - } else if (keyword == "private") { - // Fail if there was already another private clause. - if (segments[privateClausePos]) - return allowedOnce(parser, "private", opName); - if (parseOperandAndTypeList(parser, privates, privateTypes)) + clauseSegments[pos[numThreadsClause]] = 1; + } else if (clauseKeyword == "private") { + if (failed(checkAllowed(privateClause)) || + failed(parseOperandAndTypeList(parser, privates, privateTypes))) + return failure(); + clauseSegments[pos[privateClause]] = privates.size(); + } else if (clauseKeyword == "firstprivate") { + if (failed(checkAllowed(firstprivateClause)) || + failed(parseOperandAndTypeList(parser, firstprivates, + firstprivateTypes))) return failure(); - segments[privateClausePos] = privates.size(); - } else if (keyword == "firstprivate") { - // Fail if there was already another firstprivate clause. 
- if (segments[firstprivateClausePos]) - return allowedOnce(parser, "firstprivate", opName); - if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) + clauseSegments[pos[firstprivateClause]] = firstprivates.size(); + } else if (clauseKeyword == "lastprivate") { + if (failed(checkAllowed(lastprivateClause)) || + failed( + parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))) return failure(); - segments[firstprivateClausePos] = firstprivates.size(); - } else if (keyword == "shared") { - // Fail if there was already another shared clause. - if (segments[sharedClausePos]) - return allowedOnce(parser, "shared", opName); - if (parseOperandAndTypeList(parser, shareds, sharedTypes)) + clauseSegments[pos[lastprivateClause]] = lastprivates.size(); + } else if (clauseKeyword == "shared") { + if (failed(checkAllowed(sharedClause)) || + failed(parseOperandAndTypeList(parser, shareds, sharedTypes))) return failure(); - segments[sharedClausePos] = shareds.size(); - } else if (keyword == "copyin") { - // Fail if there was already another copyin clause. - if (segments[copyinClausePos]) - return allowedOnce(parser, "copyin", opName); - if (parseOperandAndTypeList(parser, copyins, copyinTypes)) + clauseSegments[pos[sharedClause]] = shareds.size(); + } else if (clauseKeyword == "copyin") { + if (failed(checkAllowed(copyinClause)) || + failed(parseOperandAndTypeList(parser, copyins, copyinTypes))) return failure(); - segments[copyinClausePos] = copyins.size(); - } else if (keyword == "allocate") { - // Fail if there was already another allocate clause. 
- if (segments[allocateClausePos]) - return allowedOnce(parser, "allocate", opName); - if (parseAllocateAndAllocator(parser, allocates, allocateTypes, - allocators, allocatorTypes)) + clauseSegments[pos[copyinClause]] = copyins.size(); + } else if (clauseKeyword == "allocate") { + if (failed(checkAllowed(allocateClause)) || + failed(parseAllocateAndAllocator(parser, allocates, allocateTypes, + allocators, allocatorTypes))) return failure(); - segments[allocateClausePos] = allocates.size(); - segments[allocatorPos] = allocators.size(); - } else if (keyword == "default") { - // Fail if there was already another default clause. - if (defaultVal) - return allowedOnce(parser, "default", opName); - defaultVal = true; + clauseSegments[pos[allocateClause]] = allocates.size(); + clauseSegments[pos[allocateClause] + 1] = allocators.size(); + } else if (clauseKeyword == "default") { StringRef defval; - if (parser.parseLParen() || parser.parseKeyword(&defval) || - parser.parseRParen()) + if (failed(checkAllowed(defaultClause)) || failed(parser.parseLParen()) || + failed(parser.parseKeyword(&defval)) || failed(parser.parseRParen())) return failure(); // The def prefix is required for the attribute as "private" is a keyword // in C++. auto attr = parser.getBuilder().getStringAttr("def" + defval); result.addAttribute("default_val", attr); - } else if (keyword == "proc_bind") { - // Fail if there was already another proc_bind clause. 
- if (procBind) - return allowedOnce(parser, "proc_bind", opName); - procBind = true; + } else if (clauseKeyword == "proc_bind") { StringRef bind; - if (parser.parseLParen() || parser.parseKeyword(&bind) || - parser.parseRParen()) + if (failed(checkAllowed(procBindClause)) || + failed(parser.parseLParen()) || failed(parser.parseKeyword(&bind)) || + failed(parser.parseRParen())) return failure(); auto attr = parser.getBuilder().getStringAttr(bind); result.addAttribute("proc_bind_val", attr); + } else if (clauseKeyword == "reduction") { + if (failed(checkAllowed(reductionClause)) || + failed(parseReductionVarList(parser, reductionSymbols, reductionVars, + reductionVarTypes))) + return failure(); + clauseSegments[pos[reductionClause]] = reductionVars.size(); + } else if (clauseKeyword == "nowait") { + if (failed(checkAllowed(nowaitClause))) + return failure(); + auto attr = UnitAttr::get(parser.getBuilder().getContext()); + result.addAttribute("nowait", attr); + } else if (clauseKeyword == "linear") { + if (failed(checkAllowed(linearClause)) || + failed(parseLinearClause(parser, linears, linearTypes, linearSteps))) + return failure(); + clauseSegments[pos[linearClause]] = linears.size(); + clauseSegments[pos[linearClause] + 1] = linearSteps.size(); + } else if (clauseKeyword == "schedule") { + if (failed(checkAllowed(scheduleClause)) || + parseScheduleClause(parser, schedule, scheduleChunkSize)) + return failure(); + if (scheduleChunkSize) { + clauseSegments[pos[scheduleClause]] = 1; + } + } else if (clauseKeyword == "collapse") { + auto type = parser.getBuilder().getI64Type(); + mlir::IntegerAttr attr; + if (failed(checkAllowed(collapseClause)) || parser.parseLParen() || + parser.parseAttribute(attr, type) || parser.parseRParen()) + return failure(); + result.addAttribute("collapse_val", attr); + } else if (clauseKeyword == "ordered") { + mlir::IntegerAttr attr; + if (failed(checkAllowed(orderedClause))) + return failure(); + if 
(succeeded(parser.parseOptionalLParen())) { + auto type = parser.getBuilder().getI64Type(); + if (failed(parser.parseAttribute(attr, type)) || + failed(parser.parseRParen())) + return failure(); + } else { + // Use 0 to represent no ordered parameter was specified + attr = parser.getBuilder().getI64IntegerAttr(0); + } + result.addAttribute("ordered_val", attr); + } else if (clauseKeyword == "order") { + StringRef order; + if (failed(checkAllowed(orderClause)) || parser.parseLParen() || + parser.parseKeyword(&order) || parser.parseRParen()) + return failure(); + auto attr = parser.getBuilder().getStringAttr(order); + result.addAttribute("order", attr); + } else if (clauseKeyword == "inclusive") { + if (failed(checkAllowed(inclusiveClause))) + return failure(); + auto attr = UnitAttr::get(parser.getBuilder().getContext()); + result.addAttribute("inclusive", attr); } else { return parser.emitError(parser.getNameLoc()) - << keyword << " is not a valid clause for the " << opName - << " operation"; + << clauseKeyword << " is not a valid clause"; } } // Add if parameter. - if (segments[ifClausePos] && - parser.resolveOperand(ifCond.first, ifCond.second, result.operands)) + if (done[ifClause] && clauseSegments[pos[ifClause]] && + failed( + parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) return failure(); // Add num_threads parameter. - if (segments[numThreadsClausePos] && - parser.resolveOperand(numThreads.first, numThreads.second, - result.operands)) + if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && + failed(parser.resolveOperand(numThreads.first, numThreads.second, + result.operands))) return failure(); // Add private parameters. 
- if (segments[privateClausePos] && - parser.resolveOperands(privates, privateTypes, privates[0].location, - result.operands)) + if (done[privateClause] && clauseSegments[pos[privateClause]] && + failed(parser.resolveOperands(privates, privateTypes, + privates[0].location, result.operands))) return failure(); // Add firstprivate parameters. - if (segments[firstprivateClausePos] && - parser.resolveOperands(firstprivates, firstprivateTypes, - firstprivates[0].location, result.operands)) + if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] && + failed(parser.resolveOperands(firstprivates, firstprivateTypes, + firstprivates[0].location, + result.operands))) + return failure(); + + // Add lastprivate parameters. + if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] && + parser.resolveOperands(lastprivates, lastprivateTypes, + lastprivates[0].location, result.operands)) return failure(); // Add shared parameters. - if (segments[sharedClausePos] && - parser.resolveOperands(shareds, sharedTypes, shareds[0].location, - result.operands)) + if (done[sharedClause] && clauseSegments[pos[sharedClause]] && + failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location, + result.operands))) return failure(); // Add copyin parameters. - if (segments[copyinClausePos] && - parser.resolveOperands(copyins, copyinTypes, copyins[0].location, - result.operands)) + if (done[copyinClause] && clauseSegments[pos[copyinClause]] && + failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location, + result.operands))) return failure(); // Add allocate parameters. - if (segments[allocateClausePos] && - parser.resolveOperands(allocates, allocateTypes, allocates[0].location, - result.operands)) + if (done[allocateClause] && clauseSegments[pos[allocateClause]] && + failed(parser.resolveOperands(allocates, allocateTypes, + allocates[0].location, result.operands))) return failure(); // Add allocator parameters. 
- if (segments[allocatorPos] && - parser.resolveOperands(allocators, allocatorTypes, allocators[0].location, - result.operands)) + if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && + failed(parser.resolveOperands(allocators, allocatorTypes, + allocators[0].location, result.operands))) + return failure(); + + // Add reduction parameters and symbols + if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { + if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, + parser.getNameLoc(), result.operands))) { + return failure(); + } + result.addAttribute("reductions", *reductionSymbols); + } + + // Add linear parameters + if (done[linearClause] && clauseSegments[pos[linearClause]]) { + auto linearStepType = parser.getBuilder().getI32Type(); + SmallVector linearStepTypes(linearSteps.size(), linearStepType); + if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, + result.operands)) || + failed(parser.resolveOperands(linearSteps, linearStepTypes, + linearSteps[0].location, + result.operands))) + return failure(); + } + + // Add schedule parameters + if (done[scheduleClause] && !schedule.empty()) { + schedule[0] = llvm::toUpper(schedule[0]); + auto attr = parser.getBuilder().getStringAttr(schedule); + result.addAttribute("schedule_val", attr); + auto chunkSizeType = parser.getBuilder().getI32Type(); + if (scheduleChunkSize && + failed(parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands))) + return failure(); // propagate chunk-size resolution errors instead of dropping them + } + + segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Builder, parser, printer and verifier for ParallelOp +//===----------------------------------------------------------------------===// + +void ParallelOp::build(OpBuilder &builder, OperationState &state, + ArrayRef attributes) { + ParallelOp::build( + builder, state, /*if_expr_var=*/nullptr,
/*num_threads_var=*/nullptr, + /*default_val=*/nullptr, /*private_vars=*/ValueRange(), + /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), + /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), + /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); + state.addAttributes(attributes); +} + +/// Parses a parallel operation. +/// +/// operation ::= `omp.parallel` clause-list +/// clause-list ::= clause | clause clause-list +/// clause ::= if | num-threads | private | firstprivate | shared | copyin | +/// allocate | default | proc-bind +/// +static ParseResult parseParallelOp(OpAsmParser &parser, + OperationState &result) { + SmallVector clauses = { + ifClause, numThreadsClause, privateClause, + firstprivateClause, sharedClause, copyinClause, + allocateClause, defaultClause, procBindClause}; + + SmallVector segments; + + if (failed(parseClauses(parser, result, clauses, segments))) return failure(); result.addAttribute("operand_segment_sizes", @@ -375,114 +719,220 @@ return success(); } -/// linear ::= `linear` `(` linear-list `)` -/// linear-list := linear-val | linear-val linear-list -/// linear-val := ssa-id-and-type `=` ssa-id-and-type -static ParseResult -parseLinearClause(OpAsmParser &parser, - SmallVectorImpl &vars, - SmallVectorImpl &types, - SmallVectorImpl &stepVars) { - if (parser.parseLParen()) - return failure(); +static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { + if (auto ifCond = op.if_expr_var()) + p << " if(" << ifCond << " : " << ifCond.getType() << ")"; - do { - OpAsmParser::OperandType var; - Type type; - OpAsmParser::OperandType stepVar; - if (parser.parseOperand(var) || parser.parseEqual() || - parser.parseOperand(stepVar) || parser.parseColonType(type)) - return failure(); + if (auto threads = op.num_threads_var()) + p << " num_threads(" << threads << " : " << threads.getType() << ")"; - vars.push_back(var); - types.push_back(type); - stepVars.push_back(stepVar); - } while 
(succeeded(parser.parseOptionalComma())); + // Print private, firstprivate, shared and copyin parameters + auto printDataVars = [&p](StringRef name, OperandRange vars) { + if (vars.size()) { + p << " " << name; + printOperandAndTypeList(p, vars); + } + }; - if (parser.parseRParen()) - return failure(); + printDataVars("private", op.private_vars()); + printDataVars("firstprivate", op.firstprivate_vars()); + printDataVars("shared", op.shared_vars()); + printDataVars("copyin", op.copyin_vars()); + printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); + + if (auto def = op.default_val()) + p << " default(" << def->drop_front(3) << ")"; + + if (auto bind = op.proc_bind_val()) + p << " proc_bind(" << bind << ")"; + p.printRegion(op.getRegion()); +} + +static LogicalResult verifyParallelOp(ParallelOp op) { + if (op.allocate_vars().size() != op.allocators_vars().size()) + return op.emitError( + "expected equal sizes for allocate and allocator variables"); return success(); } -/// schedule ::= `schedule` `(` sched-list `)` -/// sched-list ::= sched-val | sched-val sched-list -/// sched-val ::= sched-with-chunk | sched-wo-chunk -/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 
-/// sched-with-chunk-types ::= `static` | `dynamic` | `guided` -/// sched-wo-chunk ::= `auto` | `runtime` -static ParseResult -parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, - Optional &chunkSize) { - if (parser.parseLParen()) +//===----------------------------------------------------------------------===// +// Parser, printer and verifier for SectionsOp +//===----------------------------------------------------------------------===// + +/// Parses an OpenMP Sections operation +/// +/// sections ::= `omp.sections` clause-list +/// clause-list ::= clause clause-list | empty +/// clause ::= private | firstprivate | lastprivate | reduction | allocate | +/// nowait +static ParseResult parseSectionsOp(OpAsmParser &parser, + OperationState &result) { + + SmallVector clauses = {privateClause, firstprivateClause, + lastprivateClause, reductionClause, + allocateClause, nowaitClause}; + + SmallVector segments; + + if (failed(parseClauses(parser, result, clauses, segments))) return failure(); - StringRef keyword; - if (parser.parseKeyword(&keyword)) + result.addAttribute("operand_segment_sizes", + parser.getBuilder().getI32VectorAttr(segments)); + + // Now parse the body. 
+ Region *body = result.addRegion(); + if (parser.parseRegion(*body)) return failure(); + return success(); +} - schedule = keyword; - if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { - if (succeeded(parser.parseOptionalEqual())) { - chunkSize = OpAsmParser::OperandType{}; - if (parser.parseOperand(*chunkSize)) - return failure(); - } else { - chunkSize = llvm::NoneType::None; - } - } else if (keyword == "auto" || keyword == "runtime") { - chunkSize = llvm::NoneType::None; - } else { - return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; +static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) { + // Print private, firstprivate, shared and copyin parameters + auto printDataVars = [&p](StringRef name, OperandRange vars) { + if (vars.empty()) + return; + p << " " << name; + printOperandAndTypeList(p, vars); + }; + + printDataVars("private", op.private_vars()); + printDataVars("firstprivate", op.firstprivate_vars()); + printDataVars("lastprivate", op.lastprivate_vars()); + + printReductionVarList(p, op.reductions(), op.reduction_vars()); + + printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); + + if (op.nowait()) + p << " nowait"; + + p.printRegion(op.region()); +} + +static LogicalResult verifySectionsOp(SectionsOp op) { + + // A list item may not appear in more than one clause on the same directive, + // except that it may be specified in both firstprivate and lastprivate + // clauses. 
+ for (auto var : op.private_vars()) { + if (std::find(op.firstprivate_vars().begin(), op.firstprivate_vars().end(), + var) != op.firstprivate_vars().end()) + return op.emitOpError() + << "operand used in both private and firstprivate clauses"; + if (std::find(op.lastprivate_vars().begin(), op.lastprivate_vars().end(), + var) != op.lastprivate_vars().end()) + return op.emitOpError() + << "operand used in both private and lastprivate clauses"; } - if (parser.parseRParen()) - return failure(); + if (op.allocate_vars().size() != op.allocators_vars().size()) + return op.emitError( + "expected equal sizes for allocate and allocator variables"); - return success(); + if (!op.region().hasOneBlock()) + return op.emitOpError() << "expected exactly one block in region"; + for (auto &inst : *op.region().begin()) { + if (!(isa(inst) || isa(inst))) + op.emitOpError() + << "expected omp.section op or terminator op inside region"; + } + + return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); } -/// reduction-init ::= `reduction` `(` reduction-entry-list `)` -/// reduction-entry-list ::= reduction-entry -/// | reduction-entry-list `,` reduction-entry -/// reduction-entry ::= symbol-ref `->` ssa-id `:` type -static ParseResult -parseReductionVarList(OpAsmParser &parser, - SmallVectorImpl &symbols, - SmallVectorImpl &operands, - SmallVectorImpl &types) { - if (failed(parser.parseLParen())) - return failure(); +//===----------------------------------------------------------------------===// +// Builders, parser, printer and verifier for WsLoopOp +//===----------------------------------------------------------------------===// + +void WsLoopOp::build(OpBuilder &builder, OperationState &state, + ValueRange lowerBound, ValueRange upperBound, + ValueRange step, ArrayRef attributes) { + build(builder, state, TypeRange(), lowerBound, upperBound, step, + /*private_vars=*/ValueRange(), + /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), + 
/*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), + /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, + /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, + /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, + /*inclusive=*/nullptr, /*buildBody=*/false); + state.addAttributes(attributes); +} + +void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, + ValueRange operands, ArrayRef attributes) { + state.addOperands(operands); + state.addAttributes(attributes); + (void)state.addRegion(); + assert(resultTypes.empty() && "mismatched number of return types"); + state.addTypes(resultTypes); +} + +void WsLoopOp::build(OpBuilder &builder, OperationState &result, + TypeRange typeRange, ValueRange lowerBounds, + ValueRange upperBounds, ValueRange steps, + ValueRange privateVars, ValueRange firstprivateVars, + ValueRange lastprivateVars, ValueRange linearVars, + ValueRange linearStepVars, ValueRange reductionVars, + StringAttr scheduleVal, Value scheduleChunkVar, + IntegerAttr collapseVal, UnitAttr nowait, + IntegerAttr orderedVal, StringAttr orderVal, + UnitAttr inclusive, bool buildBody) { + result.addOperands(lowerBounds); + result.addOperands(upperBounds); + result.addOperands(steps); + result.addOperands(privateVars); + result.addOperands(firstprivateVars); + result.addOperands(linearVars); + result.addOperands(linearStepVars); + if (scheduleChunkVar) + result.addOperands(scheduleChunkVar); + + if (scheduleVal) + result.addAttribute("schedule_val", scheduleVal); + if (collapseVal) + result.addAttribute("collapse_val", collapseVal); + if (nowait) + result.addAttribute("nowait", nowait); + if (orderedVal) + result.addAttribute("ordered_val", orderedVal); + if (orderVal) + result.addAttribute("order", orderVal); + if (inclusive) + result.addAttribute("inclusive", inclusive); + result.addAttribute( + WsLoopOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr( + {static_cast(lowerBounds.size()), + 
static_cast(upperBounds.size()), + static_cast(steps.size()), + static_cast(privateVars.size()), + static_cast(firstprivateVars.size()), + static_cast(lastprivateVars.size()), + static_cast(linearVars.size()), + static_cast(linearStepVars.size()), + static_cast(reductionVars.size()), + static_cast(scheduleChunkVar != nullptr ? 1 : 0)})); - do { - if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || - parser.parseOperand(operands.emplace_back()) || - parser.parseColonType(types.emplace_back())) - return failure(); - } while (succeeded(parser.parseOptionalComma())); - return parser.parseRParen(); + Region *bodyRegion = result.addRegion(); + if (buildBody) { + OpBuilder::InsertionGuard guard(builder); + unsigned numIVs = steps.size(); + SmallVector argTypes(numIVs, steps.getType().front()); + builder.createBlock(bodyRegion, {}, argTypes); + } } /// Parses an OpenMP Workshare Loop operation /// -/// operation ::= `omp.wsloop` loop-control clause-list +/// wsloop ::= `omp.wsloop` loop-control clause-list /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps /// steps := `step` `(`ssa-id-list`)` -/// clause-list ::= clause | empty | clause-list +/// clause-list ::= clause clause-list | empty /// clause ::= private | firstprivate | lastprivate | linear | schedule | -// collapse | nowait | ordered | order | inclusive -/// private ::= `private` `(` ssa-id-and-type-list `)` -/// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)` -/// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)` -/// linear ::= `linear` `(` linear-list `)` -/// schedule ::= `schedule` `(` sched-list `)` -/// collapse ::= `collapse` `(` ssa-id-and-type `)` -/// nowait ::= `nowait` -/// ordered ::= `ordered` `(` ssa-id-and-type `)` -/// order ::= `order` `(` `concurrent` `)` -/// inclusive ::= `inclusive` -/// +// collapse | nowait | ordered | order | inclusive | reduction static 
ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { Type loopVarType; int numIVs; @@ -518,162 +968,15 @@ parser.resolveOperands(steps, loopVarType, result.operands)) return failure(); - SmallVector privates; - SmallVector privateTypes; - SmallVector firstprivates; - SmallVector firstprivateTypes; - SmallVector lastprivates; - SmallVector lastprivateTypes; - SmallVector linears; - SmallVector linearTypes; - SmallVector linearSteps; - SmallVector reductionSymbols; - SmallVector reductionVars; - SmallVector reductionVarTypes; - SmallString<8> schedule; - Optional scheduleChunkSize; - - const StringRef opName = result.name.getStringRef(); - StringRef keyword; - - enum SegmentPos { - lbPos = 0, - ubPos, - stepPos, - privateClausePos, - firstprivateClausePos, - lastprivateClausePos, - linearClausePos, - linearStepPos, - reductionVarPos, - scheduleClausePos, - }; - std::array segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0}; - - while (succeeded(parser.parseOptionalKeyword(&keyword))) { - if (keyword == "private") { - if (segments[privateClausePos]) - return allowedOnce(parser, "private", opName); - if (parseOperandAndTypeList(parser, privates, privateTypes)) - return failure(); - segments[privateClausePos] = privates.size(); - } else if (keyword == "firstprivate") { - // fail if there was already another firstprivate clause - if (segments[firstprivateClausePos]) - return allowedOnce(parser, "firstprivate", opName); - if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) - return failure(); - segments[firstprivateClausePos] = firstprivates.size(); - } else if (keyword == "lastprivate") { - // fail if there was already another shared clause - if (segments[lastprivateClausePos]) - return allowedOnce(parser, "lastprivate", opName); - if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) - return failure(); - segments[lastprivateClausePos] = lastprivates.size(); - } else if (keyword == "linear") { - // fail if there 
was already another linear clause - if (segments[linearClausePos]) - return allowedOnce(parser, "linear", opName); - if (parseLinearClause(parser, linears, linearTypes, linearSteps)) - return failure(); - segments[linearClausePos] = linears.size(); - segments[linearStepPos] = linearSteps.size(); - } else if (keyword == "schedule") { - if (!schedule.empty()) - return allowedOnce(parser, "schedule", opName); - if (parseScheduleClause(parser, schedule, scheduleChunkSize)) - return failure(); - if (scheduleChunkSize) { - segments[scheduleClausePos] = 1; - } - } else if (keyword == "collapse") { - auto type = parser.getBuilder().getI64Type(); - mlir::IntegerAttr attr; - if (parser.parseLParen() || parser.parseAttribute(attr, type) || - parser.parseRParen()) - return failure(); - result.addAttribute("collapse_val", attr); - } else if (keyword == "nowait") { - auto attr = UnitAttr::get(parser.getContext()); - result.addAttribute("nowait", attr); - } else if (keyword == "ordered") { - mlir::IntegerAttr attr; - if (succeeded(parser.parseOptionalLParen())) { - auto type = parser.getBuilder().getI64Type(); - if (parser.parseAttribute(attr, type)) - return failure(); - if (parser.parseRParen()) - return failure(); - } else { - // Use 0 to represent no ordered parameter was specified - attr = parser.getBuilder().getI64IntegerAttr(0); - } - result.addAttribute("ordered_val", attr); - } else if (keyword == "order") { - StringRef order; - if (parser.parseLParen() || parser.parseKeyword(&order) || - parser.parseRParen()) - return failure(); - auto attr = parser.getBuilder().getStringAttr(order); - result.addAttribute("order", attr); - } else if (keyword == "inclusive") { - auto attr = UnitAttr::get(parser.getContext()); - result.addAttribute("inclusive", attr); - } else if (keyword == "reduction") { - if (segments[reductionVarPos]) - return allowedOnce(parser, "reduction", opName); - if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars, - reductionVarTypes))) - 
return failure(); - segments[reductionVarPos] = reductionVars.size(); - } - } - - if (segments[privateClausePos]) { - parser.resolveOperands(privates, privateTypes, privates[0].location, - result.operands); - } - - if (segments[firstprivateClausePos]) { - parser.resolveOperands(firstprivates, firstprivateTypes, - firstprivates[0].location, result.operands); - } - - if (segments[lastprivateClausePos]) { - parser.resolveOperands(lastprivates, lastprivateTypes, - lastprivates[0].location, result.operands); - } - - if (segments[linearClausePos]) { - parser.resolveOperands(linears, linearTypes, linears[0].location, - result.operands); - auto linearStepType = parser.getBuilder().getI32Type(); - SmallVector linearStepTypes(linearSteps.size(), linearStepType); - parser.resolveOperands(linearSteps, linearStepTypes, - linearSteps[0].location, result.operands); - } + SmallVector clauses = { + privateClause, firstprivateClause, lastprivateClause, linearClause, + reductionClause, collapseClause, orderClause, orderedClause, + nowaitClause, scheduleClause}; - if (segments[reductionVarPos]) { - if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, - parser.getNameLoc(), result.operands))) { - return failure(); - } - SmallVector reductions(reductionSymbols.begin(), - reductionSymbols.end()); - result.addAttribute("reductions", - parser.getBuilder().getArrayAttr(reductions)); - } + SmallVector segments{numIVs, numIVs, numIVs}; - if (!schedule.empty()) { - schedule[0] = llvm::toUpper(schedule[0]); - auto attr = parser.getBuilder().getStringAttr(schedule); - result.addAttribute("schedule_val", attr); - if (scheduleChunkSize) { - auto chunkSizeType = parser.getBuilder().getI32Type(); - parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); - } - } + if (failed(parseClauses(parser, result, clauses, segments))) + return failure(); result.addAttribute("operand_segment_sizes", parser.getBuilder().getI32VectorAttr(segments)); @@ -697,37 +1000,21 @@ if 
(vars.empty()) return; - p << " " << name << "("; - llvm::interleaveComma( - vars, p, [&](const Value &v) { p << v << " : " << v.getType(); }); - p << ")"; + p << " " << name; + printOperandAndTypeList(p, vars); }; printDataVars("private", op.private_vars()); printDataVars("firstprivate", op.firstprivate_vars()); printDataVars("lastprivate", op.lastprivate_vars()); - auto linearVars = op.linear_vars(); - auto linearVarsSize = linearVars.size(); - if (linearVarsSize) { - p << " " - << "linear" - << "("; - for (unsigned i = 0; i < linearVarsSize; ++i) { - std::string separator = i == linearVarsSize - 1 ? ")" : ", "; - p << linearVars[i]; - if (op.linear_step_vars().size() > i) - p << " = " << op.linear_step_vars()[i]; - p << " : " << linearVars[i].getType() << separator; - } + if (op.linear_vars().size()) { + p << " linear"; + printLinearClause(p, op, op.linear_vars(), op.linear_step_vars()); } if (auto sched = op.schedule_val()) { - auto schedLower = sched->lower(); - p << " schedule(" << schedLower; - if (auto chunk = op.schedule_chunk_var()) { - p << " = " << chunk; - } - p << ")"; + p << " schedule"; + printScheduleClause(p, op, sched.getValue(), op.schedule_chunk_var()); } if (auto collapse = op.collapse_val()) @@ -740,16 +1027,7 @@ p << " ordered(" << ordered << ")"; } - if (!op.reduction_vars().empty()) { - p << " reduction("; - for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) { - if (i != 0) - p << ", "; - p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : " - << op.reduction_vars()[i].getType(); - } - p << ")"; - } + printReductionVarList(p, op.reductions(), op.reduction_vars()); if (op.inclusive()) { p << " inclusive"; @@ -758,6 +1036,10 @@ p.printRegion(op.region(), /*printEntryBlockArgs=*/false); } +static LogicalResult verifyWsLoopOp(WsLoopOp op) { + return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); +} + //===----------------------------------------------------------------------===// // ReductionOp 
//===----------------------------------------------------------------------===// @@ -838,127 +1120,6 @@ return op.emitOpError() << "the accumulator is not used by the parent"; } -//===----------------------------------------------------------------------===// -// WsLoopOp -//===----------------------------------------------------------------------===// - -void WsLoopOp::build(OpBuilder &builder, OperationState &state, - ValueRange lowerBound, ValueRange upperBound, - ValueRange step, ArrayRef attributes) { - build(builder, state, TypeRange(), lowerBound, upperBound, step, - /*private_vars=*/ValueRange(), - /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), - /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), - /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, - /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, - /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, - /*inclusive=*/nullptr, /*buildBody=*/false); - state.addAttributes(attributes); -} - -void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, - ValueRange operands, ArrayRef attributes) { - state.addOperands(operands); - state.addAttributes(attributes); - (void)state.addRegion(); - assert(resultTypes.empty() && "mismatched number of return types"); - state.addTypes(resultTypes); -} - -void WsLoopOp::build(OpBuilder &builder, OperationState &result, - TypeRange typeRange, ValueRange lowerBounds, - ValueRange upperBounds, ValueRange steps, - ValueRange privateVars, ValueRange firstprivateVars, - ValueRange lastprivateVars, ValueRange linearVars, - ValueRange linearStepVars, ValueRange reductionVars, - StringAttr scheduleVal, Value scheduleChunkVar, - IntegerAttr collapseVal, UnitAttr nowait, - IntegerAttr orderedVal, StringAttr orderVal, - UnitAttr inclusive, bool buildBody) { - result.addOperands(lowerBounds); - result.addOperands(upperBounds); - result.addOperands(steps); - result.addOperands(privateVars); - 
result.addOperands(firstprivateVars); - result.addOperands(linearVars); - result.addOperands(linearStepVars); - if (scheduleChunkVar) - result.addOperands(scheduleChunkVar); - - if (scheduleVal) - result.addAttribute("schedule_val", scheduleVal); - if (collapseVal) - result.addAttribute("collapse_val", collapseVal); - if (nowait) - result.addAttribute("nowait", nowait); - if (orderedVal) - result.addAttribute("ordered_val", orderedVal); - if (orderVal) - result.addAttribute("order", orderVal); - if (inclusive) - result.addAttribute("inclusive", inclusive); - result.addAttribute( - WsLoopOp::getOperandSegmentSizeAttr(), - builder.getI32VectorAttr( - {static_cast(lowerBounds.size()), - static_cast(upperBounds.size()), - static_cast(steps.size()), - static_cast(privateVars.size()), - static_cast(firstprivateVars.size()), - static_cast(lastprivateVars.size()), - static_cast(linearVars.size()), - static_cast(linearStepVars.size()), - static_cast(reductionVars.size()), - static_cast(scheduleChunkVar != nullptr ? 
1 : 0)})); - - Region *bodyRegion = result.addRegion(); - if (buildBody) { - OpBuilder::InsertionGuard guard(builder); - unsigned numIVs = steps.size(); - SmallVector argTypes(numIVs, steps.getType().front()); - builder.createBlock(bodyRegion, {}, argTypes); - } -} - -static LogicalResult verifyWsLoopOp(WsLoopOp op) { - if (op.getNumReductionVars() != 0) { - if (!op.reductions() || - op.reductions()->size() != op.getNumReductionVars()) { - return op.emitOpError() << "expected as many reduction symbol references " - "as reduction variables"; - } - } else { - if (op.reductions()) - return op.emitOpError() << "unexpected reduction symbol references"; - return success(); - } - - DenseSet accumulators; - for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) { - Value accum = std::get<0>(args); - if (!accumulators.insert(accum).second) { - return op.emitOpError() << "accumulator variable used more than once"; - } - Type varType = accum.getType().cast(); - auto symbolRef = std::get<1>(args).cast(); - auto decl = - SymbolTable::lookupNearestSymbolFrom(op, symbolRef); - if (!decl) { - return op.emitOpError() << "expected symbol reference " << symbolRef - << " to point to a reduction declaration"; - } - - if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) { - return op.emitOpError() - << "expected accumulator (" << varType - << ") to be the same type as reduction declaration (" - << decl.getAccumulatorType() << ")"; - } - } - - return success(); -} - static LogicalResult verifyCriticalOp(CriticalOp op) { if (!op.name().hasValue() && op.hint().hasValue() && (op.hint().getValue() != SyncHintKind::none)) diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s func @unknown_clause() { - // expected-error@+1 {{invalid is not a valid clause for the 
omp.parallel operation}} + // expected-error@+1 {{invalid is not a valid clause}} omp.parallel invalid { } @@ -313,3 +313,63 @@ } return } + +// ----- + +func @omp_sections(%data_var1 : memref, %data_var2 : memref, %data_var3 : memref) -> () { + // expected-error @below {{operand used in both private and firstprivate clauses}} + omp.sections private(%data_var1 : memref) firstprivate(%data_var1 : memref) { + omp.terminator + } + return +} + +// ----- + +func @omp_sections(%data_var1 : memref, %data_var2 : memref, %data_var3 : memref) -> () { + // expected-error @below {{operand used in both private and lastprivate clauses}} + omp.sections private(%data_var1 : memref) lastprivate(%data_var1 : memref) { + omp.terminator + } + return +} + +// ----- + +func @omp_sections(%data_var1 : memref, %data_var2 : memref, %data_var3 : memref) -> () { + // expected-error @below {{operand used in both private and lastprivate clauses}} + omp.sections private(%data_var1 : memref, %data_var2 : memref) lastprivate(%data_var3 : memref, %data_var2 : memref) { + omp.terminator + } + return +} + +// ----- + +func @omp_sections(%data_var : memref) -> () { + // expected-error @below {{expected equal sizes for allocate and allocator variables}} + "omp.sections" (%data_var) ({ + omp.terminator + }) {operand_segment_sizes = dense<[0,0,0,0,1,0]> : vector<6xi32>} : (memref) -> () + return +} + +// ----- + +func @omp_sections(%data_var : memref) -> () { + // expected-error @below {{expected as many reduction symbol references as reduction variables}} + "omp.sections" (%data_var) ({ + omp.terminator + }) {operand_segment_sizes = dense<[0,0,0,1,0,0]> : vector<6xi32>} : (memref) -> () + return +} + +// ----- + +func @omp_sections(%data_var : memref) -> () { + // expected-error @below {{expected omp.section op or terminator op inside region}} + omp.sections { + "test.payload" () : () -> () + } + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- 
a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -126,6 +126,129 @@ return } +func @omp_sectionsop(%data_var1 : memref, %data_var2 : memref, %data_var3 : memref, %redn_var : !llvm.ptr) { + + // CHECK: omp.sections private(%{{.*}} : memref) { + "omp.sections" (%data_var1) ({ + // CHECK: omp.terminator + omp.terminator + }) {operand_segment_sizes = dense<[1,0,0,0,0,0]> : vector<6xi32>} : (memref) -> () + + // CHECK: omp.sections firstprivate(%{{.*}} : memref) { + "omp.sections" (%data_var1) ({ + // CHECK: omp.terminator + omp.terminator + }) {operand_segment_sizes = dense<[0,1,0,0,0,0]> : vector<6xi32>} : (memref) -> () + + // CHECK: omp.sections lastprivate(%{{.*}} : memref) { + "omp.sections" (%data_var1) ({ + // CHECK: omp.terminator + omp.terminator + }) {operand_segment_sizes = dense<[0,0,1,0,0,0]> : vector<6xi32>} : (memref) -> () + + // CHECK: omp.sections private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) lastprivate(%{{.*}} : memref) { + "omp.sections" (%data_var1, %data_var2, %data_var3) ({ + // CHECK: omp.terminator + omp.terminator + }) {operand_segment_sizes = dense<[1,1,1,0,0,0]> : vector<6xi32>} : (memref, memref, memref) -> () + + // CHECK: omp.sections allocate(%{{.*}} : memref -> %{{.*}} : memref) + "omp.sections" (%data_var1, %data_var1) ({ + // CHECK: omp.terminator + omp.terminator + }) {operand_segment_sizes = dense<[0,0,0,0,1,1]> : vector<6xi32>} : (memref, memref) -> () + + // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr) + "omp.sections" (%redn_var) ({ + // CHECK: omp.terminator + omp.terminator + }) {operand_segment_sizes = dense<[0,0,0,1,0,0]> : vector<6xi32>, reductions=[@add_f32]} : (!llvm.ptr) -> () + + // CHECK: omp.sections private(%{{.*}} : memref) { + omp.sections private(%data_var1 : memref) { + // CHECK: omp.terminator + omp.terminator + } + + // CHECK: omp.sections firstprivate(%{{.*}} : memref) + omp.sections firstprivate(%data_var1 : memref) { + // CHECK: omp.terminator + 
omp.terminator + } + + // CHECK: omp.sections lastprivate(%{{.*}} : memref) + omp.sections lastprivate(%data_var1 : memref) { + // CHECK: omp.terminator + omp.terminator + } + + // CHECK: omp.sections private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) lastprivate(%{{.*}} : memref) { + omp.sections private(%data_var1 : memref) firstprivate(%data_var2 : memref) lastprivate(%data_var3 : memref) { + // CHECK: omp.terminator + omp.terminator + } + + // CHECK: omp.sections private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) lastprivate(%{{.*}} : memref) { + omp.sections lastprivate(%data_var1 : memref) firstprivate(%data_var2 : memref) private(%data_var3 : memref) { + // CHECK: omp.terminator + omp.terminator + } + + // CHECK: omp.sections private(%{{.*}} : memref) nowait { + omp.sections nowait private(%data_var1 : memref) { + // CHECK: omp.terminator + omp.terminator + } + + // CHECK: omp.sections firstprivate(%{{.*}} : memref, %{{.*}} : memref) lastprivate(%{{.*}} : memref) { + omp.sections firstprivate(%data_var1 : memref, %data_var2 : memref) lastprivate(%data_var1 : memref) { + // CHECK: omp.terminator + omp.terminator + } + + %c1 = constant 1 : i32 + %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr + + // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr) { + omp.sections reduction(@add_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.terminator + omp.terminator + } + + // CHECK: omp.sections allocate(%{{.*}} : memref -> %{{.*}} : memref) + omp.sections allocate(%data_var1 : memref -> %data_var1 : memref) { + // CHECK: omp.terminator + omp.terminator + } + + // CHECK: omp.sections nowait + omp.sections nowait { + // CHECK: omp.section + omp.section { + // CHECK: %{{.*}} = "test.payload"() : () -> i32 + %1 = "test.payload"() : () -> i32 + // CHECK: %{{.*}} = "test.payload"() : () -> i32 + %2 = "test.payload"() : () -> i32 + // CHECK: %{{.*}} = "test.payload"(%{{.*}}, %{{.*}}) : (i32, i32) -> i32 + %3 = "test.payload"(%1, %2) : (i32, i32) -> i32 + } + 
// CHECK: omp.section + omp.section { + // CHECK: %{{.*}} = "test.payload"(%{{.*}}) : (!llvm.ptr) -> i32 + %1 = "test.payload"(%0) : (!llvm.ptr) -> i32 + } + // CHECK: omp.section + omp.section { + // CHECK: "test.payload"(%{{.*}}) : (!llvm.ptr) -> () + "test.payload"(%0) : (!llvm.ptr) -> () + } + // CHECK: omp.terminator + omp.terminator + } + + return +} + // CHECK-LABEL: omp_wsloop func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memref, %linear_var : i32, %chunk_var : i32) -> () {