diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -265,11 +265,13 @@ let isValueList = true; } -def OMP_ORDER_concurrent : ClauseVal<"default",2,0> { let isDefault = 1; } +def OMP_ORDER_concurrent : ClauseVal<"concurrent",1,1> {} +def OMP_ORDER_default : ClauseVal<"default",2,0> { let isDefault = 1; } def OMPC_Order : Clause<"order"> { let clangClass = "OMPOrderClause"; let enumClauseValue = "OrderKind"; let allowedClauseValues = [ + OMP_ORDER_default, OMP_ORDER_concurrent ]; } diff --git a/llvm/unittests/Frontend/OpenMPParsingTest.cpp b/llvm/unittests/Frontend/OpenMPParsingTest.cpp --- a/llvm/unittests/Frontend/OpenMPParsingTest.cpp +++ b/llvm/unittests/Frontend/OpenMPParsingTest.cpp @@ -55,8 +55,9 @@ } TEST(OpenMPParsingTest, getOrderKind) { - EXPECT_EQ(getOrderKind("foobar"), OMP_ORDER_concurrent); - EXPECT_EQ(getOrderKind("default"), OMP_ORDER_concurrent); + EXPECT_EQ(getOrderKind("foobar"), OMP_ORDER_default); + EXPECT_EQ(getOrderKind("default"), OMP_ORDER_default); + EXPECT_EQ(getOrderKind("concurrent"), OMP_ORDER_concurrent); } TEST(OpenMPParsingTest, getProcBindKind) { diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -491,7 +491,6 @@ collapseClause, orderClause, orderedClause, - inclusiveClause, COUNT }; @@ -574,8 +573,7 @@ // segments if (clause == defaultClause || clause == procBindClause || clause == nowaitClause || clause == collapseClause || - clause == orderClause || clause == orderedClause || - clause == inclusiveClause) + clause == orderClause || clause == orderedClause) continue; pos[clause] = currPos++; @@ -593,7 +591,7 @@ bool allowRepeat = false) -> ParseResult { if (!llvm::is_contained(clauses, clause)) return parser.emitError(parser.getCurrentLocation()) - << clauseKeyword << "is not a valid clause for the " << opName + << clauseKeyword << " is not a valid clause for the " << opName << " operation"; if (done[clause] && !allowRepeat) return parser.emitError(parser.getCurrentLocation()) @@ -714,12 +712,7 @@ parser.parseKeyword(&order) || parser.parseRParen()) return failure(); auto attr = parser.getBuilder().getStringAttr(order); - result.addAttribute("order", attr); - } else if (clauseKeyword == "inclusive") { - if (checkAllowed(inclusiveClause)) - return failure(); - auto attr = UnitAttr::get(parser.getBuilder().getContext()); - result.addAttribute("inclusive", attr); + result.addAttribute("order_val", attr); } else { return parser.emitError(parser.getNameLoc()) << clauseKeyword << " is not a valid clause"; @@ -859,11 +852,11 @@ /// /// wsloop ::= `omp.wsloop` loop-control clause-list /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds -/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps +/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps /// steps := `step` `(`ssa-id-list`)` /// clause-list ::= clause clause-list | empty /// clause ::= private | firstprivate | lastprivate | linear | schedule | -// collapse | nowait | ordered | order | inclusive | reduction +// collapse | nowait | ordered | order | reduction static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { // Parse an opening `(` followed by induction variables followed by `)` @@ -890,6 +883,11 @@ parser.resolveOperands(upper, loopVarType, result.operands)) return failure(); + if (succeeded(parser.parseOptionalKeyword("inclusive"))) { + auto attr = UnitAttr::get(parser.getBuilder().getContext()); + result.addAttribute("inclusive", attr); + } + // Parse step values. SmallVector steps; if (parser.parseKeyword("step") || @@ -920,7 +918,11 @@ static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { auto args = op.getRegion().front().getArguments(); p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() - << ") to (" << op.upperBound() << ") step (" << op.step() << ") "; + << ") to (" << op.upperBound() << ") "; + if (op.inclusive()) { + p << "inclusive "; + } + p << "step (" << op.step() << ") "; printDataVars(p, op.private_vars(), "private"); printDataVars(p, op.firstprivate_vars(), "firstprivate"); @@ -946,15 +948,14 @@ if (auto ordered = op.ordered_val()) p << "ordered(" << ordered << ") "; + if (auto order = op.order_val()) + p << "order(" << order << ") "; + if (!op.reduction_vars().empty()) { p << "reduction("; printReductionVarList(p, op.reductions(), op.reduction_vars()); } - if (op.inclusive()) { - p << "inclusive "; - } - p.printRegion(op.region(), /*printEntryBlockArgs=*/false); } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -68,6 +68,61 @@ return } +// ----- + +func @lastprivate_not_allowed(%n : memref) { + // expected-error@+1 {{lastprivate is not a valid clause for the omp.parallel operation}} + omp.parallel lastprivate(%n : memref) {} + return +} + +// ----- + +func @nowait_not_allowed(%n : memref) { + // expected-error@+1 {{nowait is not a valid clause for the omp.parallel operation}} + omp.parallel nowait {} + return +} + +// ----- + +func @linear_not_allowed(%data_var : memref, %linear_var : i32) { + // expected-error@+1 {{linear is not a valid clause for the omp.parallel operation}} + omp.parallel linear(%data_var = %linear_var : memref) {} + return +} + +// ----- + +func @schedule_not_allowed() { + // expected-error@+1 {{schedule is not a valid clause for the omp.parallel operation}} + omp.parallel schedule(static) {} + return +} + +// ----- + +func @collapse_not_allowed() { + // expected-error@+1 {{collapse is not a valid clause for the omp.parallel operation}} + omp.parallel collapse(3) {} + return +} + +// ----- + +func @order_not_allowed() { + // expected-error@+1 {{order is not a valid clause for the omp.parallel operation}} + omp.parallel order(concurrent) {} + return +} + +// ----- + +func @ordered_not_allowed() { + // expected-error@+1 {{ordered is not a valid clause for the omp.parallel operation}} + omp.parallel ordered(2) {} +} + // ----- func @default_once() { @@ -90,6 +145,78 @@ // ----- +func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) { + // expected-error @below {{inclusive is not a valid clause}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) nowait inclusive { + omp.yield + } +} + +// ----- + +func @order_value(%lb : index, %ub : index, %step : index) { + // expected-error @below {{attribute 'order_val' failed to satisfy constraint: OrderKind Clause}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) order(default) { + omp.yield + } +} + +// ----- + +func @shared_not_allowed(%lb : index, %ub : index, %step : index, %var : memref) { + // expected-error @below {{shared is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) shared(%var) { + omp.yield + } +} + +// ----- + +func @copyin(%lb : index, %ub : index, %step : index, %var : memref) { + // expected-error @below {{copyin is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) copyin(%var) { + omp.yield + } +} + +// ----- + +func @if_not_allowed(%lb : index, %ub : index, %step : index, %bool_var : i1) { + // expected-error @below {{if is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) if(%bool_var: i1) { + omp.yield + } +} + +// ----- + +func @num_threads_not_allowed(%lb : index, %ub : index, %step : index, %int_var : i32) { + // expected-error @below {{num_threads is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) num_threads(%int_var: i32) { + omp.yield + } +} + +// ----- + +func @default_not_allowed(%lb : index, %ub : index, %step : index) { + // expected-error @below {{default is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) default(private) { + omp.yield + } +} + +// ----- + +func @proc_bind_not_allowed(%lb : index, %ub : index, %step : index) { + // expected-error @below {{proc_bind is not a valid clause for the omp.wsloop operation}} + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) proc_bind(close) { + omp.yield + } +} + +// ----- + // expected-error @below {{op expects initializer region with one argument of the reduction type}} omp.reduction.declare @add_f32 : f64 init { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -123,7 +123,27 @@ omp.terminator } - return + // CHECK: omp.parallel default(private) + omp.parallel default(private) { + omp.terminator + } + + // CHECK: omp.parallel default(firstprivate) + omp.parallel default(firstprivate) { + omp.terminator + } + + // CHECK: omp.parallel default(shared) + omp.parallel default(shared) { + omp.terminator + } + + // CHECK: omp.parallel default(none) + omp.parallel default(none) { + omp.terminator + } + + return } // CHECK-LABEL: omp_wsloop @@ -207,6 +227,21 @@ omp.yield } + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) + omp.wsloop (%iv) : index = (%lb) to (%ub) inclusive step (%step) { + omp.yield + } + + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) nowait + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) nowait { + omp.yield + } + + // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) nowait order(concurrent) + omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) order(concurrent) nowait { + omp.yield + } + return }