diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -91,6 +91,9 @@
             OptionalAttr<ClauseProcBind>:$proc_bind_val);
 
   let regions = (region AnyRegion:$region);
+
+  let parser = [{ return parseParallelOp(parser, result); }];
+  let printer = [{ return printParallelOp(p, *this); }];
 }
 
 def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> {
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -12,9 +12,15 @@
 
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
 
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
+#include <array>
 
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
 
@@ -29,6 +35,247 @@
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+/// Parse a list of operands with types.
+///
+/// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
+/// ssa-id-and-type-list ::= ssa-id-and-type |
+///                          ssa-id-and-type `,` ssa-id-and-type-list
+/// ssa-id-and-type ::= ssa-id `:` type
+static ParseResult
+parseOperandAndTypeList(OpAsmParser &parser,
+                        SmallVectorImpl<OpAsmParser::OperandType> &operands,
+                        SmallVectorImpl<Type> &types) {
+  if (parser.parseLParen())
+    return failure();
+
+  do {
+    OpAsmParser::OperandType operand;
+    Type type;
+    if (parser.parseOperand(operand) || parser.parseColonType(type))
+      return failure();
+    operands.push_back(operand);
+    types.push_back(type);
+  } while (succeeded(parser.parseOptionalComma()));
+
+  if (parser.parseRParen())
+    return failure();
+
+  return success();
+}
+
+/// Prints a parallel operation in the custom form read by parseParallelOp.
+static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
+  p << "omp.parallel";
+
+  if (auto ifCond = op.if_expr_var())
+    p << " if(" << ifCond << ")";
+
+  if (auto threads = op.num_threads_var())
+    p << " num_threads(" << threads << " : " << threads.getType() << ")";
+
+  // Print each non-empty data-sharing list as `name(%v : type, ...)`.
+  auto printDataVars = [&p](StringRef name, OperandRange vars) {
+    if (vars.empty())
+      return;
+    p << " " << name << "(";
+    llvm::interleaveComma(
+        vars, p, [&p](Value v) { p << v << " : " << v.getType(); });
+    p << ")";
+  };
+  printDataVars("private", op.private_vars());
+  printDataVars("firstprivate", op.firstprivate_vars());
+  printDataVars("shared", op.shared_vars());
+  printDataVars("copyin", op.copyin_vars());
+
+  // The attribute stores e.g. "defshared"; strip the "def" prefix.
+  if (auto def = op.default_val())
+    p << " default(" << def->drop_front(3) << ")";
+
+  if (auto bind = op.proc_bind_val())
+    p << " proc_bind(" << *bind << ")";
+
+  p.printRegion(op.getRegion());
+}
+
+/// Reports a duplicate `clause` on `operation`; always returns failure.
+static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
+                               llvm::StringRef operation) {
+  return parser.emitError(parser.getNameLoc())
+         << "at most one " << clause << " clause can appear on the "
+         << operation << " operation";
+}
+
+/// Parses a parallel operation.
+///
+/// operation ::= `omp.parallel` clause-list
+/// clause-list ::= clause | clause clause-list
+/// clause ::= if | numThreads | private | firstprivate | shared | copyin |
+///            default | procBind
+/// if ::= `if` `(` ssa-id `)`
+/// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
+/// private ::= `private` operand-and-type-list
+/// firstprivate ::= `firstprivate` operand-and-type-list
+/// shared ::= `shared` operand-and-type-list
+/// copyin ::= `copyin` operand-and-type-list
+/// default ::= `default` `(` (`private` | `firstprivate` | `shared` |
+///             `none`) `)`
+/// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
+///
+/// Note that each clause can only appear once in the clause-list.
+static ParseResult parseParallelOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  OpAsmParser::OperandType ifCond;
+  std::pair<OpAsmParser::OperandType, Type> numThreads;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> privates;
+  llvm::SmallVector<Type, 4> privateTypes;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> firstprivates;
+  llvm::SmallVector<Type, 4> firstprivateTypes;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> shareds;
+  llvm::SmallVector<Type, 4> sharedTypes;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> copyins;
+  llvm::SmallVector<Type, 4> copyinTypes;
+  std::array<int, 6> segments{0, 0, 0, 0, 0, 0};
+  llvm::StringRef keyword;
+  bool defaultVal = false;
+  bool procBind = false;
+
+  const int ifClausePos = 0;
+  const int numThreadsClausePos = 1;
+  const int privateClausePos = 2;
+  const int firstprivateClausePos = 3;
+  const int sharedClausePos = 4;
+  const int copyinClausePos = 5;
+  const llvm::StringRef opName = result.name.getStringRef();
+
+  while (succeeded(parser.parseOptionalKeyword(&keyword))) {
+    if (keyword == "if") {
+      // Fail if there was already another if clause.
+      if (segments[ifClausePos])
+        return allowedOnce(parser, "if", opName);
+      if (parser.parseLParen() || parser.parseOperand(ifCond) ||
+          parser.parseRParen())
+        return failure();
+      segments[ifClausePos] = 1;
+    } else if (keyword == "num_threads") {
+      // Fail if there was already another num_threads clause.
+      if (segments[numThreadsClausePos])
+        return allowedOnce(parser, "num_threads", opName);
+      if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
+          parser.parseColonType(numThreads.second) || parser.parseRParen())
+        return failure();
+      segments[numThreadsClausePos] = 1;
+    } else if (keyword == "private") {
+      // Fail if there was already another private clause.
+      if (segments[privateClausePos])
+        return allowedOnce(parser, "private", opName);
+      if (parseOperandAndTypeList(parser, privates, privateTypes))
+        return failure();
+      segments[privateClausePos] = privates.size();
+    } else if (keyword == "firstprivate") {
+      // Fail if there was already another firstprivate clause.
+      if (segments[firstprivateClausePos])
+        return allowedOnce(parser, "firstprivate", opName);
+      if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
+        return failure();
+      segments[firstprivateClausePos] = firstprivates.size();
+    } else if (keyword == "shared") {
+      // Fail if there was already another shared clause.
+      if (segments[sharedClausePos])
+        return allowedOnce(parser, "shared", opName);
+      if (parseOperandAndTypeList(parser, shareds, sharedTypes))
+        return failure();
+      segments[sharedClausePos] = shareds.size();
+    } else if (keyword == "copyin") {
+      // Fail if there was already another copyin clause.
+      if (segments[copyinClausePos])
+        return allowedOnce(parser, "copyin", opName);
+      if (parseOperandAndTypeList(parser, copyins, copyinTypes))
+        return failure();
+      segments[copyinClausePos] = copyins.size();
+    } else if (keyword == "default") {
+      // Fail if there was already another default clause.
+      if (defaultVal)
+        return allowedOnce(parser, "default", opName);
+      defaultVal = true;
+      llvm::StringRef defval;
+      if (parser.parseLParen() || parser.parseKeyword(&defval) ||
+          parser.parseRParen())
+        return failure();
+      llvm::SmallString<16> attrval;
+      // The attribute uses a "def" prefix since "private" is a C++ keyword.
+      attrval += "def";
+      attrval += defval;
+      auto attr = parser.getBuilder().getStringAttr(attrval);
+      result.addAttribute("default_val", attr);
+    } else if (keyword == "proc_bind") {
+      // Fail if there was already another proc_bind clause.
+      if (procBind)
+        return allowedOnce(parser, "proc_bind", opName);
+      procBind = true;
+      llvm::StringRef bind;
+      if (parser.parseLParen() || parser.parseKeyword(&bind) ||
+          parser.parseRParen())
+        return failure();
+      auto attr = parser.getBuilder().getStringAttr(bind);
+      result.addAttribute("proc_bind_val", attr);
+    } else {
+      return parser.emitError(parser.getNameLoc())
+             << keyword << " is not a valid clause for the " << opName
+             << " operation";
+    }
+  }
+
+  // Add if parameter (resolved as i1); propagate resolution failures.
+  if (segments[ifClausePos] &&
+      parser.resolveOperand(ifCond, parser.getBuilder().getI1Type(),
+                            result.operands))
+    return failure();
+
+  // Add num_threads parameter
+  if (segments[numThreadsClausePos] &&
+      parser.resolveOperand(numThreads.first, numThreads.second,
+                            result.operands))
+    return failure();
+
+  // Add private parameters
+  if (segments[privateClausePos] &&
+      parser.resolveOperands(privates, privateTypes, privates[0].location,
+                             result.operands))
+    return failure();
+
+  // Add firstprivate parameters
+  if (segments[firstprivateClausePos] &&
+      parser.resolveOperands(firstprivates, firstprivateTypes,
+                             firstprivates[0].location, result.operands))
+    return failure();
+
+  // Add shared parameters
+  if (segments[sharedClausePos] &&
+      parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
+                             result.operands))
+    return failure();
+
+  // Add copyin parameters
+  if (segments[copyinClausePos] &&
+      parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
+                             result.operands))
+    return failure();
+
+  result.addAttribute("operand_segment_sizes",
+                      parser.getBuilder().getI32VectorAttr(segments));
+
+  Region *body = result.addRegion();
+  llvm::SmallVector<OpAsmParser::OperandType, 4> regionArgs;
+  llvm::SmallVector<Type, 4> regionArgTypes;
+  if (parser.parseRegion(*body, regionArgs, regionArgTypes))
+    return failure();
+  return success();
+}
+
 namespace mlir {
 namespace omp {
 #define GET_OP_CLASSES
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+func @unknown_clause() {
+  // expected-error@+1 {{invalid is not a valid clause for the omp.parallel operation}}
+  omp.parallel invalid {
+  }
+
+  return
+}
+
+// -----
+
+func @if_once(%n : i1) {
+  // expected-error@+1 {{at most one if clause can appear on the omp.parallel operation}}
+  omp.parallel if(%n) if(%n) {
+  }
+
+  return
+}
+
+// -----
+
+func @num_threads_once(%n : si32) {
+  // expected-error@+1 {{at most one num_threads clause can appear on the omp.parallel operation}}
+  omp.parallel num_threads(%n : si32) num_threads(%n : si32) {
+  }
+
+  return
+}
+
+// -----
+
+func @private_once(%n : memref<i32>) {
+  // expected-error@+1 {{at most one private clause can appear on the omp.parallel operation}}
+  omp.parallel private(%n : memref<i32>) private(%n : memref<i32>) {
+  }
+
+  return
+}
+
+// -----
+
+func @firstprivate_once(%n : memref<i32>) {
+  // expected-error@+1 {{at most one firstprivate clause can appear on the omp.parallel operation}}
+  omp.parallel firstprivate(%n : memref<i32>) firstprivate(%n : memref<i32>) {
+  }
+
+  return
+}
+
+// -----
+
+func @shared_once(%n : memref<i32>) {
+  // expected-error@+1 {{at most one shared clause can appear on the omp.parallel operation}}
+  omp.parallel shared(%n : memref<i32>) shared(%n : memref<i32>) {
+  }
+
+  return
+}
+
+// -----
+
+func @copyin_once(%n : memref<i32>) {
+  // expected-error@+1 {{at most one copyin clause can appear on the omp.parallel operation}}
+  omp.parallel copyin(%n : memref<i32>) copyin(%n : memref<i32>) {
+  }
+
+  return
+}
+
+// -----
+
+func @default_once() {
+  // expected-error@+1 {{at most one default clause can appear on the omp.parallel operation}}
+  omp.parallel default(private) default(firstprivate) {
+  }
+
+  return
+}
+
+// -----
+
+func @proc_bind_once() {
+  // expected-error@+1 {{at most one proc_bind clause can appear on the omp.parallel operation}}
+  omp.parallel proc_bind(close) proc_bind(spread) {
+  }
+
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
 func @omp_barrier() -> () {
   // CHECK: omp.barrier
@@ -51,11 +51,11 @@
 }
 
 func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32) -> () {
-  // CHECK: omp.parallel
+  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
   "omp.parallel" (%if_cond, %num_threads, %data_var, %data_var, %data_var, %data_var) ({
 
   // test without if condition
-  // CHECK: omp.parallel
+  // CHECK: omp.parallel num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
     "omp.parallel"(%num_threads, %data_var, %data_var, %data_var, %data_var) ({
       omp.terminator
     }) {operand_segment_sizes = dense<[0,1,1,1,1,1]>: vector<6xi32>, default_val = "defshared"} : (si32, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
@@ -64,7 +64,7 @@
     omp.barrier
 
     // test without num_threads
-    // CHECK: omp.parallel
+    // CHECK: omp.parallel if(%{{.*}}) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
     "omp.parallel"(%if_cond, %data_var, %data_var, %data_var, %data_var) ({
       omp.terminator
     }) {operand_segment_sizes = dense<[1,0,1,1,1,1]> : vector<6xi32>} : (i1, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
@@ -73,10 +73,44 @@
   }) {operand_segment_sizes = dense<[1,1,1,1,1,1]> : vector<6xi32>, proc_bind_val = "spread"} : (i1, si32, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
 
   // test with multiple parameters for single variadic argument
-  // CHECK: omp.parallel
+  // CHECK: omp.parallel private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var, %data_var, %data_var, %data_var) ({
     omp.terminator
   }) {operand_segment_sizes = dense<[0,0,1,2,1,1]> : vector<6xi32>} : (memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
 
   return
 }
+
+func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32) -> () {
+  // CHECK: omp.parallel
+  omp.parallel {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel num_threads(%{{.*}} : si32)
+  omp.parallel num_threads(%num_threads : si32) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel private(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>)
+  omp.parallel private(%data_var : memref<i32>, %data_var : memref<i32>) firstprivate(%data_var : memref<i32>) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>)
+  omp.parallel shared(%data_var : memref<i32>) copyin(%data_var : memref<i32>, %data_var : memref<i32>) {
+    // CHECK: omp.parallel if(%{{.*}})
+    omp.parallel if(%if_cond) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) proc_bind(close)
+  omp.parallel num_threads(%num_threads : si32) if(%if_cond)
+    private(%data_var : memref<i32>) proc_bind(close) {
+    omp.terminator
+  }
+
+  return
+}