diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -65,6 +65,9 @@
     are a variadic list of variables that specify the data sharing attribute of
     those variables.
 
+    The $allocators_vars and $allocate_vars parameters are variadic lists of variables
+    that specify the memory allocators used to obtain storage for private variables.
+
     The optional $proc_bind_val attribute controls the thread affinity for the execution
     of the parallel region.
   }];
@@ -76,6 +79,8 @@
              Variadic<AnyType>:$firstprivate_vars,
              Variadic<AnyType>:$shared_vars,
              Variadic<AnyType>:$copyin_vars,
+             Variadic<AnyType>:$allocators_vars,
+             Variadic<AnyType>:$allocate_vars,
              OptionalAttr<ProcBindKind>:$proc_bind_val);
 
   let regions = (region AnyRegion:$region);
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -65,6 +65,50 @@
   return success();
 }
 
+/// Parse the parenthesized body of an allocate clause: a comma-separated list
+/// of `allocator -> allocate` operand pairs, each operand followed by its type.
+///
+/// allocate-list ::= `(` allocate-pair (`,` allocate-pair)* `)`
+/// allocate-pair ::= ssa-id-and-type `->` ssa-id-and-type
+/// ssa-id-and-type ::= ssa-id `:` type
+///
+/// The operand before each `->` is appended to operandsAllocator and
+/// typesAllocator, the operand after it to operandsAllocate and typesAllocate.
+static ParseResult parseAllocateAndAllocator(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
+    SmallVectorImpl<Type> &typesAllocate,
+    SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
+    SmallVectorImpl<Type> &typesAllocator) {
+  if (parser.parseLParen())
+    return failure();
+
+  do {
+    OpAsmParser::OperandType operand;
+    Type type;
+
+    // Left of `->`: recorded as an allocator.
+    if (parser.parseOperand(operand) || parser.parseColonType(type))
+      return failure();
+    operandsAllocator.push_back(operand);
+    typesAllocator.push_back(type);
+
+    if (parser.parseArrow())
+      return failure();
+
+    // Right of `->`: recorded as an allocate operand.
+    if (parser.parseOperand(operand) || parser.parseColonType(type))
+      return failure();
+    operandsAllocate.push_back(operand);
+    typesAllocate.push_back(type);
+  } while (succeeded(parser.parseOptionalComma()));
+
+  if (parser.parseRParen())
+    return failure();
+
+  return success();
+}
+
 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
   p << "omp.parallel";
 
@@ -84,10 +128,29 @@
       }
     }
   };
+
+  // Print the allocate clause as
+  // `allocate(allocator : type -> allocate : type, ...)`. The two ranges are
+  // indexed in lockstep, so they are expected to have the same length.
+  auto printAllocateAndAllocator = [&p](OperandRange varsAllocator,
+                                        OperandRange varsAllocate) {
+    if (varsAllocate.empty())
+      return;
+    p << " allocate(";
+    for (unsigned i = 0, e = varsAllocate.size(); i < e; ++i) {
+      if (i != 0)
+        p << ", ";
+      p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
+      p << varsAllocate[i] << " : " << varsAllocate[i].getType();
+    }
+    p << ")";
+  };
+
   printDataVars("private", op.private_vars());
   printDataVars("firstprivate", op.firstprivate_vars());
   printDataVars("shared", op.shared_vars());
   printDataVars("copyin", op.copyin_vars());
+  printAllocateAndAllocator(op.allocators_vars(), op.allocate_vars());
 
   if (auto def = op.default_val())
     p << " default(" << def->drop_front(3) << ")";
@@ -118,6 +181,8 @@
 /// firstprivate ::= `firstprivate` operand-and-type-list
 /// shared ::= `shared` operand-and-type-list
 /// copyin ::= `copyin` operand-and-type-list
+/// allocate ::= `allocate` `(` allocate-pair (`,` allocate-pair)* `)`
+/// allocate-pair ::= ssa-id-and-type `->` ssa-id-and-type
 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
 ///
@@ -134,7 +199,11 @@
   llvm::SmallVector<Type, 4> sharedTypes;
   llvm::SmallVector<OpAsmParser::OperandType, 4> copyins;
   llvm::SmallVector<Type, 4> copyinTypes;
-  std::array<int, 6> segments{0, 0, 0, 0, 0, 0};
+  llvm::SmallVector<OpAsmParser::OperandType, 4> allocates;
+  llvm::SmallVector<Type, 4> allocateTypes;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> allocators;
+  llvm::SmallVector<Type, 4> allocatorTypes;
+  std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0};
   llvm::StringRef keyword;
   bool defaultVal = false;
   bool procBind = false;
@@ -145,6 +214,8 @@
   const int firstprivateClausePos = 3;
   const int sharedClausePos = 4;
   const int copyinClausePos = 5;
+  const int allocatorPos = 6;
+  const int allocateClausePos = 7;
   const llvm::StringRef opName = result.name.getStringRef();
 
   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
@@ -192,6 +263,15 @@
       if (parseOperandAndTypeList(parser, copyins, copyinTypes))
         return failure();
       segments[copyinClausePos] = copyins.size();
+    } else if (keyword == "allocate") {
+      // fail if there was already another allocate clause
+      if (segments[allocateClausePos])
+        return allowedOnce(parser, "allocate", opName);
+      if (parseAllocateAndAllocator(parser, allocates, allocateTypes,
+                                    allocators, allocatorTypes))
+        return failure();
+      segments[allocateClausePos] = allocates.size();
+      segments[allocatorPos] = allocators.size();
     } else if (keyword == "default") {
       // fail if there was already another default clause
       if (defaultVal)
@@ -260,6 +340,20 @@
                            result.operands);
   }
 
+  // Add allocator parameters. They are resolved before the allocate parameters
+  // so the operand order matches the segment order (allocatorPos precedes
+  // allocateClausePos).
+  if (segments[allocatorPos] &&
+      parser.resolveOperands(allocators, allocatorTypes, allocators[0].location,
+                             result.operands))
+    return failure();
+
+  // Add allocate parameters
+  if (segments[allocateClausePos] &&
+      parser.resolveOperands(allocates, allocateTypes, allocates[0].location,
+                             result.operands))
+    return failure();
+
   result.addAttribute("operand_segment_sizes",
                       parser.getBuilder().getI32VectorAttr(segments));
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -51,37 +51,43 @@
 }
 
 func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32) -> () {
-  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
-  "omp.parallel" (%if_cond, %num_threads, %data_var, %data_var, %data_var, %data_var) ({
+  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
+  "omp.parallel" (%if_cond, %num_threads, %data_var, %data_var, %data_var, %data_var, %data_var, %data_var) ({
 
   // test without if condition
-  // CHECK: omp.parallel num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
-    "omp.parallel"(%num_threads, %data_var, %data_var, %data_var, %data_var) ({
+  // CHECK: omp.parallel num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
+    "omp.parallel"(%num_threads, %data_var, %data_var, %data_var, %data_var, %data_var, %data_var) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[0,1,1,1,1,1]>: vector<6xi32>, default_val = "defshared"} : (si32, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
+    }) {operand_segment_sizes = dense<[0,1,1,1,1,1,1,1]>: vector<8xi32>, default_val = "defshared"} : (si32, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
 
   // test without num_threads
-  // CHECK: omp.parallel if(%{{.*}}) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
-    "omp.parallel"(%if_cond, %data_var, %data_var, %data_var, %data_var) ({
+  // CHECK: omp.parallel if(%{{.*}}) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
+    "omp.parallel"(%if_cond, %data_var, %data_var, %data_var, %data_var, %data_var, %data_var) ({
+      omp.terminator
+    }) {operand_segment_sizes = dense<[1,0,1,1,1,1,1,1]> : vector<8xi32>} : (i1, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
+
+  // test without allocate
+  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
+    "omp.parallel"(%if_cond, %num_threads, %data_var, %data_var, %data_var, %data_var) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[1,0,1,1,1,1]> : vector<6xi32>} : (i1, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
+    }) {operand_segment_sizes = dense<[1,1,1,1,1,1,0,0]> : vector<8xi32>} : (i1, si32, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
 
     omp.terminator
-  }) {operand_segment_sizes = dense<[1,1,1,1,1,1]> : vector<6xi32>, proc_bind_val = "spread"} : (i1, si32, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
+  }) {operand_segment_sizes = dense<[1,1,1,1,1,1,1,1]> : vector<8xi32>, proc_bind_val = "spread"} : (i1, si32, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
 
   // test with multiple parameters for single variadic argument
-  // CHECK: omp.parallel private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
-  "omp.parallel" (%data_var, %data_var, %data_var, %data_var, %data_var) ({
+  // CHECK: omp.parallel private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
+  "omp.parallel" (%data_var, %data_var, %data_var, %data_var, %data_var, %data_var, %data_var) ({
     omp.terminator
-  }) {operand_segment_sizes = dense<[0,0,1,2,1,1]> : vector<6xi32>} : (memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
+  }) {operand_segment_sizes = dense<[0,0,1,2,1,1,1,1]> : vector<8xi32>} : (memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
 
   return
 }
 
-func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32) -> () {
+func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32, %allocator : si32) -> () {
   // CHECK: omp.parallel
   omp.parallel {
     omp.terminator
@@ -92,6 +98,17 @@
     omp.terminator
   }
 
+  // CHECK: omp.parallel allocate(%{{.*}} : si32 -> %{{.*}} : memref<i32>)
+  omp.parallel allocate(%allocator : si32 -> %data_var : memref<i32>) {
+    omp.terminator
+  }
+
+  // test that the generic form lists allocator operands before allocate operands
+  // CHECK: omp.parallel allocate(%{{.*}} : si32 -> %{{.*}} : memref<i32>)
+  "omp.parallel"(%allocator, %data_var) ({
+    omp.terminator
+  }) {operand_segment_sizes = dense<[0,0,0,0,0,0,1,1]> : vector<8xi32>} : (si32, memref<i32>) -> ()
+
   // CHECK: omp.parallel private(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>)
   omp.parallel private(%data_var : memref<i32>, %data_var : memref<i32>) firstprivate(%data_var : memref<i32>) {
     omp.terminator