diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -91,6 +91,9 @@
              OptionalAttr<ClauseProcBind>:$proc_bind_val);
 
   let regions = (region AnyRegion:$region);
+
+  let parser = [{ return parseParallelOp(parser, result); }];
+  let printer = [{ return printParallelOp(p, *this); }];
 }
 
 def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> {
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -12,8 +12,13 @@
 
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringSwitch.h"
+#include <array>
 
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
 
@@ -29,6 +34,207 @@
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+/// Parses a parenthesized, non-empty, comma-separated list of `operand : type`
+/// pairs such as `(%a : memref<i32>, %b : i32)`, appending each operand to
+/// `operands` and its type to `types`.
+static ParseResult
+parseOperandAndTypeList(OpAsmParser &parser,
+                        SmallVectorImpl<OpAsmParser::OperandType> &operands,
+                        SmallVectorImpl<Type> &types) {
+  if (parser.parseLParen())
+    return failure();
+
+  do {
+    OpAsmParser::OperandType operand;
+    Type type;
+    if (parser.parseOperand(operand) || parser.parseColonType(type))
+      return failure();
+    operands.push_back(operand);
+    types.push_back(type);
+  } while (succeeded(parser.parseOptionalComma()));
+
+  return parser.parseRParen();
+}
+
+/// Prints ` name(%v0 : type0, %v1 : type1, ...)` for the values in `vars`, or
+/// nothing when `vars` is empty so that absent clauses are omitted.
+static void printDataVars(OpAsmPrinter &p, OperandRange vars, StringRef name) {
+  if (vars.empty())
+    return;
+  p << " " << name << "(";
+  llvm::interleaveComma(vars, p,
+                        [&](Value v) { p << v << " : " << v.getType(); });
+  p << ")";
+}
+
+/// Prints an omp.parallel operation in its custom form:
+///
+///   omp.parallel [if(%cond)] [num_threads(%n : type)] [private(...)]
+///                [firstprivate(...)] [shared(...)] [copyin(...)]
+///                [default(kind)] [proc_bind(kind)] region
+///
+/// `operand_segment_sizes` is implied by the clauses and is not printed.
+static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
+  p << "omp.parallel";
+
+  if (auto ifCond = op.if_expr_var())
+    p << " if(" << ifCond << ")";
+
+  if (auto threads = op.num_threads_var())
+    p << " num_threads(" << threads << " : " << threads.getType() << ")";
+
+  printDataVars(p, op.private_vars(), "private");
+  printDataVars(p, op.firstprivate_vars(), "firstprivate");
+  printDataVars(p, op.shared_vars(), "shared");
+  printDataVars(p, op.copyin_vars(), "copyin");
+
+  // The stored value carries a "def" prefix (added by parseParallelOp); strip
+  // it so that the printed keyword round-trips through the parser.
+  if (auto def = op.default_val())
+    p << " default(" << def->drop_front(3) << ")";
+
+  if (auto bind = op.proc_bind_val())
+    p << " proc_bind(" << *bind << ")";
+
+  p.printRegion(op.getRegion());
+}
+
+/// Emits an error at `loc` reporting that `clause` was given more than once.
+static ParseResult emitRepeatedClauseError(OpAsmParser &parser,
+                                           llvm::SMLoc loc, StringRef clause) {
+  return parser.emitError(loc)
+         << "at most one '" << clause << "' clause can appear on omp.parallel";
+}
+
+/// Parses the custom form of omp.parallel produced by printParallelOp. The
+/// clauses may appear in any order, but each one at most once.
+static ParseResult parseParallelOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  OpAsmParser::OperandType ifCond;
+  std::pair<OpAsmParser::OperandType, Type> numThreads;
+  SmallVector<OpAsmParser::OperandType, 4> privates;
+  SmallVector<Type, 4> privateTypes;
+  SmallVector<OpAsmParser::OperandType, 4> firstprivates;
+  SmallVector<Type, 4> firstprivateTypes;
+  SmallVector<OpAsmParser::OperandType, 4> shareds;
+  SmallVector<Type, 4> sharedTypes;
+  SmallVector<OpAsmParser::OperandType, 4> copyins;
+  SmallVector<Type, 4> copyinTypes;
+  // Operand counts of the if, num_threads, private, firstprivate, shared and
+  // copyin clauses, in that order; stored as `operand_segment_sizes`. Every
+  // clause that was parsed has a non-zero entry, which detects repetitions.
+  std::array<int32_t, 6> segments{0, 0, 0, 0, 0, 0};
+
+  while (true) {
+    llvm::SMLoc keywordLoc = parser.getCurrentLocation();
+    StringRef keyword;
+    if (failed(parser.parseOptionalKeyword(&keyword)))
+      break;
+
+    if (keyword == "if") {
+      if (segments[0])
+        return emitRepeatedClauseError(parser, keywordLoc, keyword);
+      if (parser.parseLParen() || parser.parseOperand(ifCond) ||
+          parser.parseRParen())
+        return failure();
+      segments[0] = 1;
+    } else if (keyword == "num_threads") {
+      if (segments[1])
+        return emitRepeatedClauseError(parser, keywordLoc, keyword);
+      if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
+          parser.parseColonType(numThreads.second) || parser.parseRParen())
+        return failure();
+      segments[1] = 1;
+    } else if (keyword == "private") {
+      if (segments[2])
+        return emitRepeatedClauseError(parser, keywordLoc, keyword);
+      if (parseOperandAndTypeList(parser, privates, privateTypes))
+        return failure();
+      segments[2] = privates.size();
+    } else if (keyword == "firstprivate") {
+      if (segments[3])
+        return emitRepeatedClauseError(parser, keywordLoc, keyword);
+      if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
+        return failure();
+      segments[3] = firstprivates.size();
+    } else if (keyword == "shared") {
+      if (segments[4])
+        return emitRepeatedClauseError(parser, keywordLoc, keyword);
+      if (parseOperandAndTypeList(parser, shareds, sharedTypes))
+        return failure();
+      segments[4] = shareds.size();
+    } else if (keyword == "copyin") {
+      if (segments[5])
+        return emitRepeatedClauseError(parser, keywordLoc, keyword);
+      if (parseOperandAndTypeList(parser, copyins, copyinTypes))
+        return failure();
+      segments[5] = copyins.size();
+    } else if (keyword == "default") {
+      if (result.attributes.get("default_val"))
+        return emitRepeatedClauseError(parser, keywordLoc, keyword);
+      StringRef defval;
+      if (parser.parseLParen() || parser.parseKeyword(&defval) ||
+          parser.parseRParen())
+        return failure();
+      // The enum cases carry a "def" prefix because "private" is a C++
+      // keyword and cannot be used as a generated enumerator name.
+      SmallString<16> attrval("def");
+      attrval += defval;
+      result.addAttribute("default_val",
+                          parser.getBuilder().getStringAttr(attrval));
+    } else if (keyword == "proc_bind") {
+      if (result.attributes.get("proc_bind_val"))
+        return emitRepeatedClauseError(parser, keywordLoc, keyword);
+      StringRef bind;
+      if (parser.parseLParen() || parser.parseKeyword(&bind) ||
+          parser.parseRParen())
+        return failure();
+      result.addAttribute("proc_bind_val",
+                          parser.getBuilder().getStringAttr(bind));
+    } else {
+      return parser.emitError(keywordLoc)
+             << "unexpected clause '" << keyword << "' on omp.parallel";
+    }
+  }
+
+  // Resolve operands in segment order so that they line up with
+  // `operand_segment_sizes`, propagating any resolution failure.
+  if (segments[0] &&
+      parser.resolveOperand(ifCond, parser.getBuilder().getI1Type(),
+                            result.operands))
+    return failure();
+  if (segments[1] &&
+      parser.resolveOperand(numThreads.first, numThreads.second,
+                            result.operands))
+    return failure();
+  if (segments[2] &&
+      parser.resolveOperands(privates, privateTypes, privates[0].location,
+                             result.operands))
+    return failure();
+  if (segments[3] &&
+      parser.resolveOperands(firstprivates, firstprivateTypes,
+                             firstprivates[0].location, result.operands))
+    return failure();
+  if (segments[4] &&
+      parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
+                             result.operands))
+    return failure();
+  if (segments[5] &&
+      parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
+                             result.operands))
+    return failure();
+
+  result.addAttribute("operand_segment_sizes",
+                      parser.getBuilder().getI32VectorAttr(segments));
+
+  Region *body = result.addRegion();
+  return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
+}
+
 namespace mlir {
 namespace omp {
 #define GET_OP_CLASSES
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
 func @omp_barrier() -> () {
   // CHECK: omp.barrier
@@ -80,3 +80,42 @@
 
   return
 }
+
+func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32) -> () {
+  // CHECK: omp.parallel {
+  omp.parallel {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel num_threads(%{{.*}} : si32) {
+  omp.parallel num_threads(%num_threads : si32) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel private(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) {
+  omp.parallel private(%data_var : memref<i32>, %data_var : memref<i32>) firstprivate(%data_var : memref<i32>) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>) {
+  omp.parallel shared(%data_var : memref<i32>) copyin(%data_var : memref<i32>, %data_var : memref<i32>) {
+    // CHECK: omp.parallel if(%{{.*}}) {
+    omp.parallel if(%if_cond) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) proc_bind(close) {
+  omp.parallel if(%if_cond) num_threads(%num_threads : si32)
+    private(%data_var : memref<i32>) proc_bind(close) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel default(shared) {
+  omp.parallel default(shared) {
+    omp.terminator
+  }
+
+  return
+}