diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -91,6 +91,9 @@
              OptionalAttr:$proc_bind_val);
 
   let regions = (region AnyRegion:$region);
+
+  let parser = [{ return parseParallelOp(parser, result); }];
+  let printer = [{ return printParallelOp(p, *this); }];
 }
 
 def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> {
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -12,11 +12,17 @@
 
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/OpImplementation.h"
 
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
+#include <array>
+#include <cstddef>
 
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
+#include "mlir/IR/OperationSupport.h"
 
 using namespace mlir;
 using namespace mlir::omp;
@@ -29,6 +35,257 @@
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+/// Parses a parenthesized, comma-separated list of one or more
+/// `operand : type` pairs, e.g. `(%a : i32, %b : f32)`.
+///
+/// Parsed operands and types are appended, in source order, to `operands` and
+/// `types`. Returns failure as soon as an expected token cannot be parsed.
+static ParseResult
+parseOperandAndTypeList(OpAsmParser &parser,
+                        SmallVectorImpl<OpAsmParser::OperandType> &operands,
+                        SmallVectorImpl<Type> &types) {
+  if (parser.parseLParen())
+    return failure();
+
+  do {
+    OpAsmParser::OperandType operand;
+    Type type;
+    if (parser.parseOperand(operand) || parser.parseColonType(type))
+      return failure();
+    operands.push_back(operand);
+    types.push_back(type);
+  } while (succeeded(parser.parseOptionalComma()));
+
+  if (parser.parseRParen())
+    return failure();
+
+  return success();
+}
+
+/// Prints `op` in its custom assembly form: the operation name followed by
+/// every clause that is present, in the fixed order if, num_threads, private,
+/// firstprivate, shared, copyin, default, proc_bind, and finally the region.
+static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
+  p << "omp.parallel";
+
+  if (auto ifCond = op.if_expr_var())
+    p << " if(" << ifCond << ")";
+
+  if (auto threads = op.num_threads_var())
+    p << " num_threads(" << threads << " : " << threads.getType() << ")";
+
+  // Print private, firstprivate, shared and copyin parameters. A clause with
+  // no variables is omitted entirely.
+  auto printDataVars = [&p](StringRef name, OperandRange vars) {
+    if (vars.empty())
+      return;
+    p << " " << name << "(";
+    for (std::size_t i = 0, e = vars.size(); i < e; ++i) {
+      if (i != 0)
+        p << ", ";
+      p << vars[i] << " : " << vars[i].getType();
+    }
+    p << ")";
+  };
+  printDataVars("private", op.private_vars());
+  printDataVars("firstprivate", op.firstprivate_vars());
+  printDataVars("shared", op.shared_vars());
+  printDataVars("copyin", op.copyin_vars());
+
+  // The parser stores the default clause with a "def" prefix; drop those three
+  // characters so the printed form matches what was parsed.
+  if (auto def = op.default_val())
+    p << " default(" << def->drop_front(3) << ")";
+
+  if (auto bind = op.proc_bind_val())
+    p << " proc_bind(" << bind << ")";
+
+  p.printRegion(op.getRegion());
+}
+
+/// Emits, at the location of the operation name, the diagnostic for a clause
+/// that was given more than once, and returns it as a ParseResult.
+static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
+                               llvm::StringRef operation) {
+  return parser.emitError(parser.getNameLoc())
+         << " at most one " << clause << " clause can appear on the "
+         << operation << " operation";
+}
+
+/// Parses the custom assembly form of omp.parallel: an arbitrary sequence of
+/// clauses followed by the region. The recognized clauses are
+///
+///   if(%cond)                  num_threads(%n : type)
+///   private(%v : type, ...)    firstprivate(%v : type, ...)
+///   shared(%v : type, ...)     copyin(%v : type, ...)
+///   default(keyword)           proc_bind(keyword)
+///
+/// Clauses may be written in any order but each at most once; a repeated or
+/// unknown clause is diagnosed. Operands are added to `result` in the order
+/// if, num_threads, private, firstprivate, shared, copyin, and the matching
+/// `operand_segment_sizes` attribute is attached. The if operand is resolved
+/// as i1. Returns failure if parsing or operand resolution fails.
+static ParseResult parseParallelOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  // Index of each operand group within `operand_segment_sizes`; operands are
+  // added to `result` in this same order.
+  enum {
+    ifClausePos = 0,
+    numThreadsClausePos,
+    privateClausePos,
+    firstprivateClausePos,
+    sharedClausePos,
+    copyinClausePos,
+    numClausePos
+  };
+
+  OpAsmParser::OperandType ifCond;
+  std::pair<OpAsmParser::OperandType, Type> numThreads;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> privates;
+  llvm::SmallVector<Type, 4> privateTypes;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> firstprivates;
+  llvm::SmallVector<Type, 4> firstprivateTypes;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> shareds;
+  llvm::SmallVector<Type, 4> sharedTypes;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> copyins;
+  llvm::SmallVector<Type, 4> copyinTypes;
+  std::array<int, numClausePos> segments{};
+  llvm::StringRef keyword;
+  bool defaultVal = false;
+  bool procBind = false;
+
+  while (succeeded(parser.parseOptionalKeyword(&keyword))) {
+    if (keyword == "if") {
+      // Fail if there was already another if clause.
+      if (segments[ifClausePos])
+        return allowedOnce(parser, "if", "parallel");
+      if (parser.parseLParen() || parser.parseOperand(ifCond) ||
+          parser.parseRParen())
+        return failure();
+      segments[ifClausePos] = 1;
+    } else if (keyword == "num_threads") {
+      // Fail if there was already another num_threads clause.
+      if (segments[numThreadsClausePos])
+        return allowedOnce(parser, "num_threads", "parallel");
+      if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
+          parser.parseColonType(numThreads.second) || parser.parseRParen())
+        return failure();
+      segments[numThreadsClausePos] = 1;
+    } else if (keyword == "private") {
+      // Fail if there was already another private clause.
+      if (segments[privateClausePos])
+        return allowedOnce(parser, "private", "parallel");
+      if (parseOperandAndTypeList(parser, privates, privateTypes))
+        return failure();
+      segments[privateClausePos] = privates.size();
+    } else if (keyword == "firstprivate") {
+      // Fail if there was already another firstprivate clause.
+      if (segments[firstprivateClausePos])
+        return allowedOnce(parser, "firstprivate", "parallel");
+      if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
+        return failure();
+      segments[firstprivateClausePos] = firstprivates.size();
+    } else if (keyword == "shared") {
+      // Fail if there was already another shared clause.
+      if (segments[sharedClausePos])
+        return allowedOnce(parser, "shared", "parallel");
+      if (parseOperandAndTypeList(parser, shareds, sharedTypes))
+        return failure();
+      segments[sharedClausePos] = shareds.size();
+    } else if (keyword == "copyin") {
+      // Fail if there was already another copyin clause.
+      if (segments[copyinClausePos])
+        return allowedOnce(parser, "copyin", "parallel");
+      if (parseOperandAndTypeList(parser, copyins, copyinTypes))
+        return failure();
+      segments[copyinClausePos] = copyins.size();
+    } else if (keyword == "default") {
+      // Fail if there was already another default clause.
+      if (defaultVal)
+        return allowedOnce(parser, "default", "parallel");
+      defaultVal = true;
+      llvm::StringRef defval;
+      if (parser.parseLParen() || parser.parseKeyword(&defval) ||
+          parser.parseRParen())
+        return failure();
+      llvm::SmallString<16> attrval;
+      // The def prefix is required for the attribute as "private" is a keyword
+      // in C++
+      attrval += "def";
+      attrval += defval;
+      auto attr = parser.getBuilder().getStringAttr(attrval);
+      result.addAttribute("default_val", attr);
+    } else if (keyword == "proc_bind") {
+      // Fail if there was already another proc_bind clause.
+      if (procBind)
+        return allowedOnce(parser, "proc_bind", "parallel");
+      procBind = true;
+      llvm::StringRef bind;
+      if (parser.parseLParen() || parser.parseKeyword(&bind) ||
+          parser.parseRParen())
+        return failure();
+      auto attr = parser.getBuilder().getStringAttr(bind);
+      result.addAttribute("proc_bind_val", attr);
+    } else {
+      return parser.emitError(parser.getNameLoc())
+             << keyword << " is not a valid clause for the parallel operation";
+    }
+  }
+
+  // Add the operands in segment order. resolveOperand(s) reports failure
+  // through its return value, so each result is checked instead of discarded.
+
+  // Add if parameter
+  if (segments[ifClausePos] &&
+      parser.resolveOperand(ifCond, parser.getBuilder().getI1Type(),
+                            result.operands))
+    return failure();
+
+  // Add num_threads parameter
+  if (segments[numThreadsClausePos] &&
+      parser.resolveOperand(numThreads.first, numThreads.second,
+                            result.operands))
+    return failure();
+
+  // Add private parameters
+  if (segments[privateClausePos] &&
+      parser.resolveOperands(privates, privateTypes, privates[0].location,
+                             result.operands))
+    return failure();
+
+  // Add firstprivate parameters
+  if (segments[firstprivateClausePos] &&
+      parser.resolveOperands(firstprivates, firstprivateTypes,
+                             firstprivates[0].location, result.operands))
+    return failure();
+
+  // Add shared parameters
+  if (segments[sharedClausePos] &&
+      parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
+                             result.operands))
+    return failure();
+
+  // Add copyin parameters
+  if (segments[copyinClausePos] &&
+      parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
+                             result.operands))
+    return failure();
+
+  result.addAttribute("operand_segment_sizes",
+                      parser.getBuilder().getI32VectorAttr(segments));
+
+  Region *body = result.addRegion();
+  llvm::SmallVector<OpAsmParser::OperandType, 4> regionArgs;
+  llvm::SmallVector<Type, 4> regionArgTypes;
+  if (parser.parseRegion(*body, regionArgs, regionArgTypes))
+    return failure();
+  return success();
+}
+
 namespace mlir {
 namespace omp {
 #define GET_OP_CLASSES
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
new file mode 100644
---
/dev/null
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+// Each case below gives omp.parallel an unknown clause, or the same clause
+// twice, and expects the custom parser to report the matching diagnostic.
+
+func @unknown_clause() {
+  // expected-error@+1 {{invalid is not a valid clause for the parallel operation}}
+  omp.parallel invalid {
+  }
+
+  return
+}
+
+// -----
+
+func @if_once(%n : i1) {
+  // expected-error@+1 {{at most one if clause can appear on the parallel operation}}
+  omp.parallel if(%n) if(%n) {
+  }
+
+  return
+}
+
+// -----
+
+func @num_threads_once(%n : si32) {
+  // expected-error@+1 {{at most one num_threads clause can appear on the parallel operation}}
+  omp.parallel num_threads(%n : si32) num_threads(%n : si32) {
+  }
+
+  return
+}
+
+// -----
+
+func @private_once(%n : memref) {
+  // expected-error@+1 {{at most one private clause can appear on the parallel operation}}
+  omp.parallel private(%n : memref) private(%n : memref) {
+  }
+
+  return
+}
+
+// -----
+
+func @firstprivate_once(%n : memref) {
+  // expected-error@+1 {{at most one firstprivate clause can appear on the parallel operation}}
+  omp.parallel firstprivate(%n : memref) firstprivate(%n : memref) {
+  }
+
+  return
+}
+
+// -----
+
+func @shared_once(%n : memref) {
+  // expected-error@+1 {{at most one shared clause can appear on the parallel operation}}
+  omp.parallel shared(%n : memref) shared(%n : memref) {
+  }
+
+  return
+}
+
+// -----
+
+func @copyin_once(%n : memref) {
+  // expected-error@+1 {{at most one copyin clause can appear on the parallel operation}}
+  omp.parallel copyin(%n : memref) copyin(%n : memref) {
+  }
+
+  return
+}
+
+// -----
+
+func @default_once() {
+  // expected-error@+1 {{at most one default clause can appear on the parallel operation}}
+  omp.parallel default(private) default(firstprivate) {
+  }
+
+  return
+}
+
+// -----
+
+func @proc_bind_once() {
+  // expected-error@+1 {{at most one proc_bind clause can appear on the parallel operation}}
+  omp.parallel proc_bind(close) proc_bind(spread) {
+  }
+
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
 func @omp_barrier() -> () {
   // CHECK: omp.barrier
@@ -51,11 +51,11 @@
 }
 
 func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : si32) -> () {
-  // CHECK: omp.parallel
+  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) shared(%{{.*}} : memref) copyin(%{{.*}} : memref)
   "omp.parallel" (%if_cond, %num_threads, %data_var, %data_var, %data_var, %data_var) ({
 
   // test without if condition
-  // CHECK: omp.parallel
+  // CHECK: omp.parallel num_threads(%{{.*}} : si32) private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) shared(%{{.*}} : memref) copyin(%{{.*}} : memref)
     "omp.parallel"(%num_threads, %data_var, %data_var, %data_var, %data_var) ({
       omp.terminator
     }) {operand_segment_sizes = dense<[0,1,1,1,1,1]>: vector<6xi32>, default_val = "defshared"} : (si32, memref, memref, memref, memref) -> ()
@@ -64,7 +64,7 @@
     omp.barrier
 
   // test without num_threads
-  // CHECK: omp.parallel
+  // CHECK: omp.parallel if(%{{.*}}) private(%{{.*}} : memref) firstprivate(%{{.*}} : memref) shared(%{{.*}} : memref) copyin(%{{.*}} : memref)
     "omp.parallel"(%if_cond, %data_var, %data_var, %data_var, %data_var) ({
       omp.terminator
     }) {operand_segment_sizes = dense<[1,0,1,1,1,1]> : vector<6xi32>} : (i1, memref, memref, memref, memref) -> ()
@@ -73,10 +73,43 @@
   }) {operand_segment_sizes = dense<[1,1,1,1,1,1]> : vector<6xi32>, proc_bind_val = "spread"} : (i1, si32, memref, memref, memref, memref) -> ()
 
   // test with multiple parameters for single variadic argument
-  // CHECK: omp.parallel
+  // CHECK: omp.parallel private(%{{.*}} : memref) firstprivate(%{{.*}} : memref, %{{.*}} : memref) shared(%{{.*}} : memref) copyin(%{{.*}} : memref)
   "omp.parallel" (%data_var, %data_var, %data_var, %data_var, %data_var) ({
     omp.terminator
   }) {operand_segment_sizes = dense<[0,0,1,2,1,1]> : vector<6xi32>} : (memref, memref, memref, memref, memref) -> ()
 
   return
 }
+
+func @omp_parallel_pretty(%data_var : memref, %if_cond : i1, %num_threads : si32) -> () {
+  // CHECK: omp.parallel
+  omp.parallel {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel num_threads(%{{.*}} : si32)
+  omp.parallel num_threads(%num_threads : si32) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel private(%{{.*}} : memref, %{{.*}} : memref) firstprivate(%{{.*}} : memref)
+  omp.parallel private(%data_var : memref, %data_var : memref) firstprivate(%data_var : memref) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel shared(%{{.*}} : memref) copyin(%{{.*}} : memref, %{{.*}} : memref)
+  omp.parallel shared(%data_var : memref) copyin(%data_var : memref, %data_var : memref) {
+    omp.parallel if(%if_cond) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref) proc_bind(close)
+  omp.parallel num_threads(%num_threads : si32) if(%if_cond)
+    private(%data_var : memref) proc_bind(close) {
+    omp.terminator
+  }
+
+  return
+}