diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -211,8 +211,7 @@ Variadic:$linear_vars, Variadic:$linear_step_vars, Variadic:$reduction_vars, - OptionalAttr>:$reductions, + OptionalAttr:$reductions, OptionalAttr:$schedule_val, Optional:$schedule_chunk_var, Confined, [IntMinValue<0>]>:$collapse_val, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -67,6 +68,10 @@ state.addAttributes(attributes); } +//===----------------------------------------------------------------------===// +// Parser and printer for Operand and type list +//===----------------------------------------------------------------------===// + /// Parse a list of operands with types. 
/// /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` @@ -89,9 +94,30 @@ }); } +/// Print an operand and type list with parentheses +static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) { + p << "("; + llvm::interleaveComma( + operands, p, [&](const Value &v) { p << v << " : " << v.getType(); }); + p << ") "; +} + +/// Print data variables corresponding to a data-sharing clause `name` +static void printDataVars(OpAsmPrinter &p, OperandRange operands, + StringRef name) { + if (operands.size()) { + p << name; + printOperandAndTypeList(p, operands); + } +} + +//===----------------------------------------------------------------------===// +// Parser and printer for Allocate Clause +//===----------------------------------------------------------------------===// + /// Parse an allocate clause with allocators and a list of operands with types. /// -/// operand-and-type-list ::= `(` allocate-operand-list `)` +/// allocate ::= `allocate` `(` allocate-operand-list `)` /// allocate-operand-list :: = allocate-operand | /// allocator-operand `,` allocate-operand-list /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type @@ -122,6 +148,21 @@ }); } +/// Print allocate clause +static void printAllocateAndAllocator(OpAsmPrinter &p, + OperandRange varsAllocate, + OperandRange varsAllocator) { + if (varsAllocate.empty()) + return; + + p << "allocate("; + for (unsigned i = 0; i < varsAllocate.size(); ++i) { + std::string separator = i == varsAllocate.size() - 1 ? 
") " : ", "; + p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; + p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; + } +} + static LogicalResult verifyParallelOp(ParallelOp op) { if (op.allocate_vars().size() != op.allocators_vars().size()) return op.emitError( @@ -130,250 +171,31 @@ } static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { + p << " "; if (auto ifCond = op.if_expr_var()) - p << " if(" << ifCond << " : " << ifCond.getType() << ")"; + p << "if(" << ifCond << " : " << ifCond.getType() << ") "; if (auto threads = op.num_threads_var()) - p << " num_threads(" << threads << " : " << threads.getType() << ")"; - - // Print private, firstprivate, shared and copyin parameters - auto printDataVars = [&p](StringRef name, OperandRange vars) { - if (vars.size()) { - p << " " << name << "("; - for (unsigned i = 0; i < vars.size(); ++i) { - std::string separator = i == vars.size() - 1 ? ")" : ", "; - p << vars[i] << " : " << vars[i].getType() << separator; - } - } - }; + p << "num_threads(" << threads << " : " << threads.getType() << ") "; - // Print allocator and allocate parameters - auto printAllocateAndAllocator = [&p](OperandRange varsAllocate, - OperandRange varsAllocator) { - if (varsAllocate.empty()) - return; - - p << " allocate("; - for (unsigned i = 0; i < varsAllocate.size(); ++i) { - std::string separator = i == varsAllocate.size() - 1 ? 
")" : ", "; - p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; - p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; - } - }; - - printDataVars("private", op.private_vars()); - printDataVars("firstprivate", op.firstprivate_vars()); - printDataVars("shared", op.shared_vars()); - printDataVars("copyin", op.copyin_vars()); - printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars()); + printDataVars(p, op.private_vars(), "private"); + printDataVars(p, op.firstprivate_vars(), "firstprivate"); + printDataVars(p, op.shared_vars(), "shared"); + printDataVars(p, op.copyin_vars(), "copyin"); + printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); if (auto def = op.default_val()) - p << " default(" << def->drop_front(3) << ")"; + p << "default(" << def->drop_front(3) << ") "; if (auto bind = op.proc_bind_val()) - p << " proc_bind(" << bind << ")"; + p << "proc_bind(" << bind << ") "; p.printRegion(op.getRegion()); } -/// Emit an error if the same clause is present more than once on an operation. -static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause, - StringRef operation) { - return parser.emitError(parser.getNameLoc()) - << " at most one " << clause << " clause can appear on the " - << operation << " operation"; -} - -/// Parses a parallel operation. 
-/// -/// operation ::= `omp.parallel` clause-list -/// clause-list ::= clause | clause clause-list -/// clause ::= if | numThreads | private | firstprivate | shared | copyin | -/// default | procBind -/// if ::= `if` `(` ssa-id `)` -/// numThreads ::= `num_threads` `(` ssa-id-and-type `)` -/// private ::= `private` operand-and-type-list -/// firstprivate ::= `firstprivate` operand-and-type-list -/// shared ::= `shared` operand-and-type-list -/// copyin ::= `copyin` operand-and-type-list -/// allocate ::= `allocate` operand-and-type `->` operand-and-type-list -/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) -/// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` -/// -/// Note that each clause can only appear once in the clase-list. -static ParseResult parseParallelOp(OpAsmParser &parser, - OperationState &result) { - std::pair ifCond; - std::pair numThreads; - SmallVector privates; - SmallVector privateTypes; - SmallVector firstprivates; - SmallVector firstprivateTypes; - SmallVector shareds; - SmallVector sharedTypes; - SmallVector copyins; - SmallVector copyinTypes; - SmallVector allocates; - SmallVector allocateTypes; - SmallVector allocators; - SmallVector allocatorTypes; - std::array segments{0, 0, 0, 0, 0, 0, 0, 0}; - StringRef keyword; - bool defaultVal = false; - bool procBind = false; - - const int ifClausePos = 0; - const int numThreadsClausePos = 1; - const int privateClausePos = 2; - const int firstprivateClausePos = 3; - const int sharedClausePos = 4; - const int copyinClausePos = 5; - const int allocateClausePos = 6; - const int allocatorPos = 7; - const StringRef opName = result.name.getStringRef(); - - while (succeeded(parser.parseOptionalKeyword(&keyword))) { - if (keyword == "if") { - // Fail if there was already another if condition. 
- if (segments[ifClausePos]) - return allowedOnce(parser, "if", opName); - if (parser.parseLParen() || parser.parseOperand(ifCond.first) || - parser.parseColonType(ifCond.second) || parser.parseRParen()) - return failure(); - segments[ifClausePos] = 1; - } else if (keyword == "num_threads") { - // Fail if there was already another num_threads clause. - if (segments[numThreadsClausePos]) - return allowedOnce(parser, "num_threads", opName); - if (parser.parseLParen() || parser.parseOperand(numThreads.first) || - parser.parseColonType(numThreads.second) || parser.parseRParen()) - return failure(); - segments[numThreadsClausePos] = 1; - } else if (keyword == "private") { - // Fail if there was already another private clause. - if (segments[privateClausePos]) - return allowedOnce(parser, "private", opName); - if (parseOperandAndTypeList(parser, privates, privateTypes)) - return failure(); - segments[privateClausePos] = privates.size(); - } else if (keyword == "firstprivate") { - // Fail if there was already another firstprivate clause. - if (segments[firstprivateClausePos]) - return allowedOnce(parser, "firstprivate", opName); - if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) - return failure(); - segments[firstprivateClausePos] = firstprivates.size(); - } else if (keyword == "shared") { - // Fail if there was already another shared clause. - if (segments[sharedClausePos]) - return allowedOnce(parser, "shared", opName); - if (parseOperandAndTypeList(parser, shareds, sharedTypes)) - return failure(); - segments[sharedClausePos] = shareds.size(); - } else if (keyword == "copyin") { - // Fail if there was already another copyin clause. - if (segments[copyinClausePos]) - return allowedOnce(parser, "copyin", opName); - if (parseOperandAndTypeList(parser, copyins, copyinTypes)) - return failure(); - segments[copyinClausePos] = copyins.size(); - } else if (keyword == "allocate") { - // Fail if there was already another allocate clause. 
- if (segments[allocateClausePos]) - return allowedOnce(parser, "allocate", opName); - if (parseAllocateAndAllocator(parser, allocates, allocateTypes, - allocators, allocatorTypes)) - return failure(); - segments[allocateClausePos] = allocates.size(); - segments[allocatorPos] = allocators.size(); - } else if (keyword == "default") { - // Fail if there was already another default clause. - if (defaultVal) - return allowedOnce(parser, "default", opName); - defaultVal = true; - StringRef defval; - if (parser.parseLParen() || parser.parseKeyword(&defval) || - parser.parseRParen()) - return failure(); - // The def prefix is required for the attribute as "private" is a keyword - // in C++. - auto attr = parser.getBuilder().getStringAttr("def" + defval); - result.addAttribute("default_val", attr); - } else if (keyword == "proc_bind") { - // Fail if there was already another proc_bind clause. - if (procBind) - return allowedOnce(parser, "proc_bind", opName); - procBind = true; - StringRef bind; - if (parser.parseLParen() || parser.parseKeyword(&bind) || - parser.parseRParen()) - return failure(); - auto attr = parser.getBuilder().getStringAttr(bind); - result.addAttribute("proc_bind_val", attr); - } else { - return parser.emitError(parser.getNameLoc()) - << keyword << " is not a valid clause for the " << opName - << " operation"; - } - } - - // Add if parameter. - if (segments[ifClausePos] && - parser.resolveOperand(ifCond.first, ifCond.second, result.operands)) - return failure(); - - // Add num_threads parameter. - if (segments[numThreadsClausePos] && - parser.resolveOperand(numThreads.first, numThreads.second, - result.operands)) - return failure(); - - // Add private parameters. - if (segments[privateClausePos] && - parser.resolveOperands(privates, privateTypes, privates[0].location, - result.operands)) - return failure(); - - // Add firstprivate parameters. 
- if (segments[firstprivateClausePos] && - parser.resolveOperands(firstprivates, firstprivateTypes, - firstprivates[0].location, result.operands)) - return failure(); - - // Add shared parameters. - if (segments[sharedClausePos] && - parser.resolveOperands(shareds, sharedTypes, shareds[0].location, - result.operands)) - return failure(); - - // Add copyin parameters. - if (segments[copyinClausePos] && - parser.resolveOperands(copyins, copyinTypes, copyins[0].location, - result.operands)) - return failure(); - - // Add allocate parameters. - if (segments[allocateClausePos] && - parser.resolveOperands(allocates, allocateTypes, allocates[0].location, - result.operands)) - return failure(); - - // Add allocator parameters. - if (segments[allocatorPos] && - parser.resolveOperands(allocators, allocatorTypes, allocators[0].location, - result.operands)) - return failure(); - - result.addAttribute("operand_segment_sizes", - parser.getBuilder().getI32VectorAttr(segments)); - - Region *body = result.addRegion(); - SmallVector regionArgs; - SmallVector regionArgTypes; - if (parser.parseRegion(*body, regionArgs, regionArgTypes)) - return failure(); - return success(); -} +//===----------------------------------------------------------------------===// +// Parser and printer for Linear Clause +//===----------------------------------------------------------------------===// /// linear ::= `linear` `(` linear-list `)` /// linear-list := linear-val | linear-val linear-list @@ -405,6 +227,24 @@ return success(); } +/// Print Linear Clause +static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, + OperandRange linearStepVars) { + size_t linearVarsSize = linearVars.size(); + p << "("; + for (unsigned i = 0; i < linearVarsSize; ++i) { + std::string separator = i == linearVarsSize - 1 ? 
") " : ", "; + p << linearVars[i]; + if (linearStepVars.size() > i) + p << " = " << linearStepVars[i]; + p << " : " << linearVars[i].getType() << separator; + } +} + +//===----------------------------------------------------------------------===// +// Parser and printer for Schedule Clause +//===----------------------------------------------------------------------===// + /// schedule ::= `schedule` `(` sched-list `)` /// sched-list ::= sched-val | sched-val sched-list /// sched-val ::= sched-with-chunk | sched-wo-chunk @@ -442,7 +282,21 @@ return success(); } -/// reduction-init ::= `reduction` `(` reduction-entry-list `)` +/// Print schedule clause +static void printScheduleClause(OpAsmPrinter &p, StringRef &sched, + Value scheduleChunkVar) { + std::string schedLower = sched.lower(); + p << "(" << schedLower; + if (scheduleChunkVar) + p << " = " << scheduleChunkVar; + p << ") "; +} + +//===----------------------------------------------------------------------===// +// Parser, printer and verifier for ReductionVarList +//===----------------------------------------------------------------------===// + +/// reduction ::= `reduction` `(` reduction-entry-list `)` /// reduction-entry-list ::= reduction-entry /// | reduction-entry-list `,` reduction-entry /// reduction-entry ::= symbol-ref `->` ssa-id `:` type @@ -463,209 +317,392 @@ return parser.parseRParen(); } -/// Parses an OpenMP Workshare Loop operation -/// -/// operation ::= `omp.wsloop` loop-control clause-list -/// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds -/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps -/// steps := `step` `(`ssa-id-list`)` -/// clause-list ::= clause | empty | clause-list -/// clause ::= private | firstprivate | lastprivate | linear | schedule | -// collapse | nowait | ordered | order | inclusive -/// private ::= `private` `(` ssa-id-and-type-list `)` -/// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)` -/// lastprivate ::= `lastprivate` 
`(` ssa-id-and-type-list `)` +/// Print Reduction clause +static void printReductionVarList(OpAsmPrinter &p, + Optional reductions, + OperandRange reduction_vars) { + for (unsigned i = 0, e = reductions->size(); i < e; ++i) { + if (i != 0) + p << ", "; + p << (*reductions)[i] << " -> " << reduction_vars[i] << " : " + << reduction_vars[i].getType(); + } + p << ") "; +} + +/// Verifies Reduction Clause +static LogicalResult verifyReductionVarList(Operation *op, + Optional reductions, + OperandRange reduction_vars) { + if (reduction_vars.size() != 0) { + if (!reductions || reductions->size() != reduction_vars.size()) + return op->emitOpError() + << "expected as many reduction symbol references " + "as reduction variables"; + } else { + if (reductions) + return op->emitOpError() << "unexpected reduction symbol references"; + return success(); + } + + DenseSet accumulators; + for (auto args : llvm::zip(reduction_vars, *reductions)) { + Value accum = std::get<0>(args); + + if (!accumulators.insert(accum).second) + return op->emitOpError() << "accumulator variable used more than once"; + + Type varType = accum.getType().cast(); + auto symbolRef = std::get<1>(args).cast(); + auto decl = + SymbolTable::lookupNearestSymbolFrom(op, symbolRef); + if (!decl) + return op->emitOpError() << "expected symbol reference " << symbolRef + << " to point to a reduction declaration"; + + if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) + return op->emitOpError() + << "expected accumulator (" << varType + << ") to be the same type as reduction declaration (" + << decl.getAccumulatorType() << ")"; + } + + return success(); +} + +enum ClauseType { + ifClause, + numThreadsClause, + privateClause, + firstprivateClause, + lastprivateClause, + sharedClause, + copyinClause, + allocateClause, + defaultClause, + procBindClause, + reductionClause, + nowaitClause, + linearClause, + scheduleClause, + collapseClause, + orderClause, + orderedClause, + inclusiveClause, + COUNT +}; + 
+//===----------------------------------------------------------------------===//
+// Parser for Clause List
+//===----------------------------------------------------------------------===//
+
+/// Parse a list of clauses. The clauses can appear in any order, but their
+/// operand segments are laid out in the order of the `clauses` list (which
+/// must match the op's operand order) and appended to `segments`.
+///
+/// clause-list ::= clause clause-list | empty
+/// clause ::= if | num-threads | private | firstprivate | lastprivate |
+///            shared | copyin | allocate | default | proc-bind | reduction |
+///            nowait | linear | schedule | collapse | order | ordered |
+///            inclusive
+/// if ::= `if` `(` ssa-id-and-type `)`
+/// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
+/// private ::= `private` operand-and-type-list
+/// firstprivate ::= `firstprivate` operand-and-type-list
+/// lastprivate ::= `lastprivate` operand-and-type-list
+/// shared ::= `shared` operand-and-type-list
+/// copyin ::= `copyin` operand-and-type-list
+/// allocate ::= `allocate` `(` allocate-operand-list `)`
+/// default ::= `default` `(` (`private`|`firstprivate`|`shared`|`none`) `)`
+/// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
+/// reduction ::= `reduction` `(` reduction-entry-list `)`
+/// nowait ::= `nowait`
 /// linear ::= `linear` `(` linear-list `)`
 /// schedule ::= `schedule` `(` sched-list `)`
 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
-/// nowait ::= `nowait`
-/// ordered ::= `ordered` `(` ssa-id-and-type `)`
 /// order ::= `order` `(` `concurrent` `)`
+/// ordered ::= `ordered` `(` ssa-id-and-type `)`
 /// inclusive ::= `inclusive`
 ///
-static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
-  Type loopVarType;
-  int numIVs;
+/// Note that each clause can only appear once in the clause-list.
+static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, + SmallVectorImpl &clauses, + SmallVectorImpl &segments) { - // Parse an opening `(` followed by induction variables followed by `)` - SmallVector ivs; - if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, - OpAsmParser::Delimiter::Paren)) - return failure(); + // Check done[clause] to see if it has been parsed already + llvm::BitVector done(ClauseType::COUNT, false); - numIVs = static_cast(ivs.size()); + // See pos[clause] to get position of clause in operand segments + SmallVector pos(ClauseType::COUNT, -1); - if (parser.parseColonType(loopVarType)) - return failure(); + // Stores the last parsed clause keyword + StringRef clauseKeyword; + StringRef opName = result.name.getStringRef(); - // Parse loop bounds. - SmallVector lower; - if (parser.parseEqual() || - parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(lower, loopVarType, result.operands)) - return failure(); + // Containers for storing operands, types and attributes for various clauses + std::pair ifCond; + std::pair numThreads; - SmallVector upper; - if (parser.parseKeyword("to") || - parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(upper, loopVarType, result.operands)) - return failure(); + SmallVector privates, firstprivates, lastprivates, + shareds, copyins; + SmallVector privateTypes, firstprivateTypes, lastprivateTypes, + sharedTypes, copyinTypes; - // Parse step values. 
-  SmallVector<OpAsmParser::OperandType> steps;
-  if (parser.parseKeyword("step") ||
-      parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(steps, loopVarType, result.operands))
-    return failure();
+  SmallVector<OpAsmParser::OperandType> allocates, allocators;
+  SmallVector<Type> allocateTypes, allocatorTypes;
 
-  SmallVector<OpAsmParser::OperandType> privates;
-  SmallVector<Type> privateTypes;
-  SmallVector<OpAsmParser::OperandType> firstprivates;
-  SmallVector<Type> firstprivateTypes;
-  SmallVector<OpAsmParser::OperandType> lastprivates;
-  SmallVector<Type> lastprivateTypes;
-  SmallVector<OpAsmParser::OperandType> linears;
-  SmallVector<Type> linearTypes;
-  SmallVector<OpAsmParser::OperandType> linearSteps;
   SmallVector<SymbolRefAttr> reductionSymbols;
   SmallVector<OpAsmParser::OperandType> reductionVars;
   SmallVector<Type> reductionVarTypes;
+
+  SmallVector<OpAsmParser::OperandType> linears;
+  SmallVector<Type> linearTypes;
+  SmallVector<OpAsmParser::OperandType> linearSteps;
+
   SmallString<8> schedule;
   Optional<OpAsmParser::OperandType> scheduleChunkSize;
 
-  const StringRef opName = result.name.getStringRef();
-  StringRef keyword;
+  // Compute the position of clauses in operand segments
+  int currPos = 0;
+  for (ClauseType clause : clauses) {
 
-  enum SegmentPos {
-    lbPos = 0,
-    ubPos,
-    stepPos,
-    privateClausePos,
-    firstprivateClausePos,
-    lastprivateClausePos,
-    linearClausePos,
-    linearStepPos,
-    reductionVarPos,
-    scheduleClausePos,
+    // Skip the following clauses - they do not take any position in operand
+    // segments
+    if (clause == defaultClause || clause == procBindClause ||
+        clause == nowaitClause || clause == collapseClause ||
+        clause == orderClause || clause == orderedClause ||
+        clause == inclusiveClause)
+      continue;
+
+    pos[clause] = currPos++;
+
+    // For the following clauses, two positions are reserved in the operand
+    // segments
+    if (clause == allocateClause || clause == linearClause)
+      currPos++;
+  }
+
+  SmallVector<int> clauseSegments(currPos);
+
+  // Helper function to check if a clause is allowed/repeated or not
+  auto checkAllowed = [&](ClauseType clause,
+                          bool allowRepeat = false) -> ParseResult {
+    if (!llvm::is_contained(clauses, clause))
+      return parser.emitError(parser.getCurrentLocation())
+             << clauseKeyword << " is not a valid clause for the " << opName
+             << " operation";
+    if (done[clause] && !allowRepeat)
+      return parser.emitError(parser.getCurrentLocation())
+             << "at most one " << clauseKeyword << " clause can appear on the "
+             << opName << " operation";
+    done[clause] = true;
+    return success();
   };
-  std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0};
 
-  while (succeeded(parser.parseOptionalKeyword(&keyword))) {
-    if (keyword == "private") {
-      if (segments[privateClausePos])
-        return allowedOnce(parser, "private", opName);
-      if (parseOperandAndTypeList(parser, privates, privateTypes))
+  while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
+    if (clauseKeyword == "if") {
+      if (checkAllowed(ifClause) || parser.parseLParen() ||
+          parser.parseOperand(ifCond.first) ||
+          parser.parseColonType(ifCond.second) || parser.parseRParen())
+        return failure();
+      clauseSegments[pos[ifClause]] = 1;
+    } else if (clauseKeyword == "num_threads") {
+      if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
+          parser.parseOperand(numThreads.first) ||
+          parser.parseColonType(numThreads.second) || parser.parseRParen())
+        return failure();
+      clauseSegments[pos[numThreadsClause]] = 1;
+    } else if (clauseKeyword == "private") {
+      if (checkAllowed(privateClause) ||
+          parseOperandAndTypeList(parser, privates, privateTypes))
+        return failure();
+      clauseSegments[pos[privateClause]] = privates.size();
+    } else if (clauseKeyword == "firstprivate") {
+      if (checkAllowed(firstprivateClause) ||
+          parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
+        return failure();
+      clauseSegments[pos[firstprivateClause]] = firstprivates.size();
+    } else if (clauseKeyword == "lastprivate") {
+      if (checkAllowed(lastprivateClause) ||
+          parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
+        return failure();
+      clauseSegments[pos[lastprivateClause]] = lastprivates.size();
+    } else if (clauseKeyword == "shared") {
+      if (checkAllowed(sharedClause) ||
+          parseOperandAndTypeList(parser, shareds,
sharedTypes)) + return failure(); + clauseSegments[pos[sharedClause]] = shareds.size(); + } else if (clauseKeyword == "copyin") { + if (checkAllowed(copyinClause) || + parseOperandAndTypeList(parser, copyins, copyinTypes)) + return failure(); + clauseSegments[pos[copyinClause]] = copyins.size(); + } else if (clauseKeyword == "allocate") { + if (checkAllowed(allocateClause) || + parseAllocateAndAllocator(parser, allocates, allocateTypes, + allocators, allocatorTypes)) return failure(); - segments[privateClausePos] = privates.size(); - } else if (keyword == "firstprivate") { - // fail if there was already another firstprivate clause - if (segments[firstprivateClausePos]) - return allowedOnce(parser, "firstprivate", opName); - if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) + clauseSegments[pos[allocateClause]] = allocates.size(); + clauseSegments[pos[allocateClause] + 1] = allocators.size(); + } else if (clauseKeyword == "default") { + StringRef defval; + if (checkAllowed(defaultClause) || parser.parseLParen() || + parser.parseKeyword(&defval) || parser.parseRParen()) return failure(); - segments[firstprivateClausePos] = firstprivates.size(); - } else if (keyword == "lastprivate") { - // fail if there was already another shared clause - if (segments[lastprivateClausePos]) - return allowedOnce(parser, "lastprivate", opName); - if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) + // The def prefix is required for the attribute as "private" is a keyword + // in C++. 
+ auto attr = parser.getBuilder().getStringAttr("def" + defval); + result.addAttribute("default_val", attr); + } else if (clauseKeyword == "proc_bind") { + StringRef bind; + if (checkAllowed(procBindClause) || parser.parseLParen() || + parser.parseKeyword(&bind) || parser.parseRParen()) return failure(); - segments[lastprivateClausePos] = lastprivates.size(); - } else if (keyword == "linear") { - // fail if there was already another linear clause - if (segments[linearClausePos]) - return allowedOnce(parser, "linear", opName); - if (parseLinearClause(parser, linears, linearTypes, linearSteps)) + auto attr = parser.getBuilder().getStringAttr(bind); + result.addAttribute("proc_bind_val", attr); + } else if (clauseKeyword == "reduction") { + if (checkAllowed(reductionClause) || + parseReductionVarList(parser, reductionSymbols, reductionVars, + reductionVarTypes)) + return failure(); + clauseSegments[pos[reductionClause]] = reductionVars.size(); + } else if (clauseKeyword == "nowait") { + if (checkAllowed(nowaitClause)) + return failure(); + auto attr = UnitAttr::get(parser.getBuilder().getContext()); + result.addAttribute("nowait", attr); + } else if (clauseKeyword == "linear") { + if (checkAllowed(linearClause) || + parseLinearClause(parser, linears, linearTypes, linearSteps)) return failure(); - segments[linearClausePos] = linears.size(); - segments[linearStepPos] = linearSteps.size(); - } else if (keyword == "schedule") { - if (!schedule.empty()) - return allowedOnce(parser, "schedule", opName); - if (parseScheduleClause(parser, schedule, scheduleChunkSize)) + clauseSegments[pos[linearClause]] = linears.size(); + clauseSegments[pos[linearClause] + 1] = linearSteps.size(); + } else if (clauseKeyword == "schedule") { + if (checkAllowed(scheduleClause) || + parseScheduleClause(parser, schedule, scheduleChunkSize)) return failure(); if (scheduleChunkSize) { - segments[scheduleClausePos] = 1; + clauseSegments[pos[scheduleClause]] = 1; } - } else if (keyword == 
"collapse") { + } else if (clauseKeyword == "collapse") { auto type = parser.getBuilder().getI64Type(); mlir::IntegerAttr attr; - if (parser.parseLParen() || parser.parseAttribute(attr, type) || - parser.parseRParen()) + if (checkAllowed(collapseClause) || parser.parseLParen() || + parser.parseAttribute(attr, type) || parser.parseRParen()) return failure(); result.addAttribute("collapse_val", attr); - } else if (keyword == "nowait") { - auto attr = UnitAttr::get(parser.getContext()); - result.addAttribute("nowait", attr); - } else if (keyword == "ordered") { + } else if (clauseKeyword == "ordered") { mlir::IntegerAttr attr; + if (checkAllowed(orderedClause)) + return failure(); if (succeeded(parser.parseOptionalLParen())) { auto type = parser.getBuilder().getI64Type(); - if (parser.parseAttribute(attr, type)) - return failure(); - if (parser.parseRParen()) + if (parser.parseAttribute(attr, type) || parser.parseRParen()) return failure(); } else { // Use 0 to represent no ordered parameter was specified attr = parser.getBuilder().getI64IntegerAttr(0); } result.addAttribute("ordered_val", attr); - } else if (keyword == "order") { + } else if (clauseKeyword == "order") { StringRef order; - if (parser.parseLParen() || parser.parseKeyword(&order) || - parser.parseRParen()) + if (checkAllowed(orderClause) || parser.parseLParen() || + parser.parseKeyword(&order) || parser.parseRParen()) return failure(); auto attr = parser.getBuilder().getStringAttr(order); result.addAttribute("order", attr); - } else if (keyword == "inclusive") { - auto attr = UnitAttr::get(parser.getContext()); - result.addAttribute("inclusive", attr); - } else if (keyword == "reduction") { - if (segments[reductionVarPos]) - return allowedOnce(parser, "reduction", opName); - if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars, - reductionVarTypes))) + } else if (clauseKeyword == "inclusive") { + if (checkAllowed(inclusiveClause)) return failure(); - segments[reductionVarPos] = 
reductionVars.size(); + auto attr = UnitAttr::get(parser.getBuilder().getContext()); + result.addAttribute("inclusive", attr); + } else { + return parser.emitError(parser.getNameLoc()) + << clauseKeyword << " is not a valid clause"; } } - if (segments[privateClausePos]) { - parser.resolveOperands(privates, privateTypes, privates[0].location, - result.operands); - } + // Add if parameter. + if (done[ifClause] && clauseSegments[pos[ifClause]] && + failed( + parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) + return failure(); - if (segments[firstprivateClausePos]) { - parser.resolveOperands(firstprivates, firstprivateTypes, - firstprivates[0].location, result.operands); - } + // Add num_threads parameter. + if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && + failed(parser.resolveOperand(numThreads.first, numThreads.second, + result.operands))) + return failure(); - if (segments[lastprivateClausePos]) { - parser.resolveOperands(lastprivates, lastprivateTypes, - lastprivates[0].location, result.operands); - } + // Add private parameters. + if (done[privateClause] && clauseSegments[pos[privateClause]] && + failed(parser.resolveOperands(privates, privateTypes, + privates[0].location, result.operands))) + return failure(); - if (segments[linearClausePos]) { - parser.resolveOperands(linears, linearTypes, linears[0].location, - result.operands); - auto linearStepType = parser.getBuilder().getI32Type(); - SmallVector linearStepTypes(linearSteps.size(), linearStepType); - parser.resolveOperands(linearSteps, linearStepTypes, - linearSteps[0].location, result.operands); - } + // Add firstprivate parameters. + if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] && + failed(parser.resolveOperands(firstprivates, firstprivateTypes, + firstprivates[0].location, + result.operands))) + return failure(); + + // Add lastprivate parameters. 
+  if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
+      failed(parser.resolveOperands(lastprivates, lastprivateTypes,
+                                    lastprivates[0].location, result.operands)))
+    return failure();
 
-  if (segments[reductionVarPos]) {
+  // Add shared parameters.
+  if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
+      failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
+                                    result.operands)))
+    return failure();
+
+  // Add copyin parameters.
+  if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
+      failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
+                                    result.operands)))
+    return failure();
+
+  // Add allocate parameters.
+  if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
+      failed(parser.resolveOperands(allocates, allocateTypes,
+                                    allocates[0].location, result.operands)))
+    return failure();
+
+  // Add allocator parameters.
+  if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
+      failed(parser.resolveOperands(allocators, allocatorTypes,
+                                    allocators[0].location, result.operands)))
+    return failure();
+
+  // Add linear parameters (before reductions, matching the segment order)
+  if (done[linearClause] && clauseSegments[pos[linearClause]]) {
+    auto linearStepType = parser.getBuilder().getI32Type();
+    SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
+    if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
+                                      result.operands)) ||
+        failed(parser.resolveOperands(linearSteps, linearStepTypes,
+                                      linearSteps[0].location,
+                                      result.operands)))
+      return failure();
+  }
+
+  // Add reduction parameters and symbols
+  if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
-                                      parser.getNameLoc(), result.operands))) {
+                                      parser.getNameLoc(), result.operands)))
       return failure();
-    }
+
     SmallVector<Attribute> reductions(reductionSymbols.begin(),
                                       reductionSymbols.end());
     result.addAttribute("reductions",
                         parser.getBuilder().getArrayAttr(reductions));
   }
 
-  if (!schedule.empty()) {
+  // Add schedule parameters
+  if (done[scheduleClause] && !schedule.empty()) {
     schedule[0] = llvm::toUpper(schedule[0]);
     auto attr = parser.getBuilder().getStringAttr(schedule);
     result.addAttribute("schedule_val", attr);
@@ -675,6 +712,91 @@
     }
   }
 
+  segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
+
+  return success();
+}
+
+/// Parses a parallel operation.
+///
+/// operation ::= `omp.parallel` clause-list
+/// clause-list ::= clause clause-list | empty
+/// clause ::= if | num-threads | private | firstprivate | shared | copyin |
+///            allocate | default | proc-bind
+///
+static ParseResult parseParallelOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  SmallVector<ClauseType> clauses = {
+      ifClause,           numThreadsClause, privateClause,
+      firstprivateClause, sharedClause,     copyinClause,
+      allocateClause,     defaultClause,    procBindClause};
+
+  SmallVector<int> segments;
+
+  if (failed(parseClauses(parser, result, clauses, segments)))
+    return failure();
+
+  result.addAttribute("operand_segment_sizes",
+                      parser.getBuilder().getI32VectorAttr(segments));
+
+  Region *body = result.addRegion();
+  SmallVector<OpAsmParser::OperandType> regionArgs;
+  SmallVector<Type> regionArgTypes;
+  if (parser.parseRegion(*body, regionArgs, regionArgTypes))
+    return failure();
+  return success();
+}
+
+/// Parses an OpenMP Workshare Loop operation
+///
+/// wsloop ::= `omp.wsloop` loop-control clause-list
+/// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
+/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
+/// steps := `step` `(`ssa-id-list`)`
+/// clause-list ::= clause clause-list | empty
+/// clause ::= private | firstprivate | lastprivate | linear | schedule |
+///            collapse | nowait | ordered | order | inclusive | reduction
+static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
+
+  // Parse an opening `(` followed by induction variables followed by `)`
+  SmallVector<OpAsmParser::OperandType> ivs;
+  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
+                                     OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  int numIVs = static_cast<int>(ivs.size());
+  Type loopVarType;
+  if (parser.parseColonType(loopVarType))
+    return failure();
+
+  // Parse loop bounds.
+  SmallVector<OpAsmParser::OperandType> lower;
+  if (parser.parseEqual() ||
+      parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(lower, loopVarType, result.operands))
+    return failure();
+
+  SmallVector<OpAsmParser::OperandType> upper;
+  if (parser.parseKeyword("to") ||
+      parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(upper, loopVarType, result.operands))
+    return failure();
+
+  // Parse step values.
+  SmallVector<OpAsmParser::OperandType> steps;
+  if (parser.parseKeyword("step") ||
+      parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(steps, loopVarType, result.operands))
+    return failure();
+
+  SmallVector<ClauseType> clauses = {
+      privateClause,   firstprivateClause, lastprivateClause, linearClause,
+      reductionClause, collapseClause,     orderClause,       orderedClause,
+      nowaitClause,    scheduleClause,     inclusiveClause};
+  SmallVector<int> segments{numIVs, numIVs, numIVs};
+  if (failed(parseClauses(parser, result, clauses, segments)))
+    return failure();
+
   result.addAttribute("operand_segment_sizes",
                       parser.getBuilder().getI32VectorAttr(segments));
 
@@ -690,69 +812,38 @@
 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
   auto args = op.getRegion().front().getArguments();
   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
-    << ") to (" << op.upperBound() << ") step (" << op.step() << ")";
+    << ") to (" << op.upperBound() << ") step (" << op.step() << ") ";
 
-  // Print private, firstprivate, shared and copyin parameters
-  auto printDataVars = [&p](StringRef name, OperandRange vars) {
-    if (vars.empty())
-      return;
+  printDataVars(p, op.private_vars(), "private");
+  printDataVars(p, op.firstprivate_vars(), "firstprivate");
+  printDataVars(p,
op.lastprivate_vars(), "lastprivate"); - p << " " << name << "("; - llvm::interleaveComma( - vars, p, [&](const Value &v) { p << v << " : " << v.getType(); }); - p << ")"; - }; - printDataVars("private", op.private_vars()); - printDataVars("firstprivate", op.firstprivate_vars()); - printDataVars("lastprivate", op.lastprivate_vars()); - - auto linearVars = op.linear_vars(); - auto linearVarsSize = linearVars.size(); - if (linearVarsSize) { - p << " " - << "linear" - << "("; - for (unsigned i = 0; i < linearVarsSize; ++i) { - std::string separator = i == linearVarsSize - 1 ? ")" : ", "; - p << linearVars[i]; - if (op.linear_step_vars().size() > i) - p << " = " << op.linear_step_vars()[i]; - p << " : " << linearVars[i].getType() << separator; - } + if (op.linear_vars().size()) { + p << "linear"; + printLinearClause(p, op.linear_vars(), op.linear_step_vars()); } if (auto sched = op.schedule_val()) { - auto schedLower = sched->lower(); - p << " schedule(" << schedLower; - if (auto chunk = op.schedule_chunk_var()) { - p << " = " << chunk; - } - p << ")"; + p << "schedule"; + printScheduleClause(p, sched.getValue(), op.schedule_chunk_var()); } if (auto collapse = op.collapse_val()) - p << " collapse(" << collapse << ")"; + p << "collapse(" << collapse << ") "; if (op.nowait()) - p << " nowait"; + p << "nowait "; - if (auto ordered = op.ordered_val()) { - p << " ordered(" << ordered << ")"; - } + if (auto ordered = op.ordered_val()) + p << "ordered(" << ordered << ") "; if (!op.reduction_vars().empty()) { - p << " reduction("; - for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) { - if (i != 0) - p << ", "; - p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : " - << op.reduction_vars()[i].getType(); - } - p << ")"; + p << "reduction("; + printReductionVarList(p, op.reductions(), op.reduction_vars()); } if (op.inclusive()) { - p << " inclusive"; + p << "inclusive "; } p.printRegion(op.region(), /*printEntryBlockArgs=*/false); @@ -921,42 
+1012,7 @@ } static LogicalResult verifyWsLoopOp(WsLoopOp op) { - if (op.getNumReductionVars() != 0) { - if (!op.reductions() || - op.reductions()->size() != op.getNumReductionVars()) { - return op.emitOpError() << "expected as many reduction symbol references " - "as reduction variables"; - } - } else { - if (op.reductions()) - return op.emitOpError() << "unexpected reduction symbol references"; - return success(); - } - - DenseSet accumulators; - for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) { - Value accum = std::get<0>(args); - if (!accumulators.insert(accum).second) { - return op.emitOpError() << "accumulator variable used more than once"; - } - Type varType = accum.getType().cast(); - auto symbolRef = std::get<1>(args).cast(); - auto decl = - SymbolTable::lookupNearestSymbolFrom(op, symbolRef); - if (!decl) { - return op.emitOpError() << "expected symbol reference " << symbolRef - << " to point to a reduction declaration"; - } - - if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) { - return op.emitOpError() - << "expected accumulator (" << varType - << ") to be the same type as reduction declaration (" - << decl.getAccumulatorType() << ")"; - } - } - - return success(); + return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s func @unknown_clause() { - // expected-error@+1 {{invalid is not a valid clause for the omp.parallel operation}} + // expected-error@+1 {{invalid is not a valid clause}} omp.parallel invalid { }