diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -107,6 +107,8 @@ void Enter(const parser::OpenMPDeclareTargetConstruct &); void Leave(const parser::OpenMPDeclareTargetConstruct &); + void Enter(const parser::OpenMPAtomicConstruct &); + void Leave(const parser::OpenMPAtomicConstruct &); void Enter(const parser::OpenMPSimpleStandaloneConstruct &); void Leave(const parser::OpenMPSimpleStandaloneConstruct &); void Enter(const parser::OpenMPFlushConstruct &); @@ -140,6 +142,7 @@ void Enter(const parser::OmpClause::Priority &); void Enter(const parser::OmpClause::Private &); void Enter(const parser::OmpClause::Safelen &); + void Enter(const parser::OmpClause::SeqCst &); void Enter(const parser::OmpClause::Shared &); void Enter(const parser::OmpClause::Simdlen &); void Enter(const parser::OmpClause::ThreadLimit &); diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -90,6 +90,23 @@ } } +// Only the atomic-read form is checked for now: push the OMPD_atomic +// directive context (and its clause sets) for it. +void OmpStructureChecker::Enter(const parser::OpenMPAtomicConstruct &x) { + if (const auto *atomicRead{std::get_if<parser::OmpAtomicRead>(&x.u)}) { + const auto &dir{std::get<parser::Verbatim>(atomicRead->t)}; + PushContextAndClauseSets(dir.source, llvm::omp::OMPD_atomic); + } +} + +// Mirror Enter: pop only when Enter pushed a context; popping unconditionally +// would pop a context that was never pushed for the other atomic forms. +void OmpStructureChecker::Leave(const parser::OpenMPAtomicConstruct &x) { + if (std::holds_alternative<parser::OmpAtomicRead>(x.u)) { + dirContext_.pop_back(); + } +} + void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) { const auto &beginBlockDir{std::get(x.t)}; const auto &endBlockDir{std::get(x.t)}; diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -440,7 +440,9 @@ VersionedClause, VersionedClause, VersionedClause, - VersionedClause, + VersionedClause + ]; +
let allowedOnceClauses = [ VersionedClause, VersionedClause, VersionedClause, diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -81,11 +81,11 @@ ``` }]; - let arguments = (ins Optional:$async, - Variadic:$waitOperands, - Optional:$numGangs, - Optional:$numWorkers, - Optional:$vectorLength, + let arguments = (ins Optional>:$async, + Variadic>:$waitOperands, + Optional>:$numGangs, + Optional>:$numWorkers, + Optional>:$vectorLength, Optional:$ifCond, Optional:$selfCond, OptionalAttr:$reductionOp, diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -101,6 +101,24 @@ return success(); } +static ParseResult parseOptionalOperandAndType(OpAsmParser &parser, + StringRef keyword, + bool &hasOptional, + OperationState &result) { + OpAsmParser::OperandType operand; + Type type; + hasOptional = false; + if (succeeded(parser.parseOptionalKeyword(keyword))) { + hasOptional = true; + if (parser.parseLParen() || parser.parseOperand(operand) || + parser.parseColonType(type) || + parser.resolveOperand(operand, type, result.operands) || + parser.parseRParen()) + return failure(); + } + return success(); +} + //===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// @@ -133,17 +151,14 @@ copyoutOperands, noCreateOperands, presentOperands, devicePtrOperands, attachOperands, waitOperands, reductionOperands; SmallVector operandTypes; - OpAsmParser::OperandType async, numGangs, numWorkers, vectorLength, ifCond, - selfCond; + OpAsmParser::OperandType ifCond, selfCond; bool hasAsync = false, hasNumGangs = false, hasNumWorkers = false; bool hasVectorLength = false,
hasIfCond = false, hasSelfCond = false; - - Type indexType = builder.getIndexType(); Type i1Type = builder.getI1Type(); // async()? - if (failed(parseOptionalOperand(parser, ParallelOp::getAsyncKeyword(), async, - indexType, hasAsync, result))) + if (failed(parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(), + hasAsync, result))) return failure(); // wait()? @@ -152,20 +167,19 @@ return failure(); // num_gangs(value)? - if (failed(parseOptionalOperand(parser, ParallelOp::getNumGangsKeyword(), - numGangs, indexType, hasNumGangs, result))) + if (failed(parseOptionalOperandAndType( + parser, ParallelOp::getNumGangsKeyword(), hasNumGangs, result))) return failure(); // num_workers(value)? - if (failed(parseOptionalOperand(parser, ParallelOp::getNumWorkersKeyword(), - numWorkers, indexType, hasNumWorkers, - result))) + if (failed(parseOptionalOperandAndType( + parser, ParallelOp::getNumWorkersKeyword(), hasNumWorkers, result))) return failure(); // vector_length(value)? - if (failed(parseOptionalOperand(parser, ParallelOp::getVectorLengthKeyword(), - vectorLength, indexType, hasVectorLength, - result))) + if (failed(parseOptionalOperandAndType(parser, + ParallelOp::getVectorLengthKeyword(), + hasVectorLength, result))) return failure(); // if()? @@ -270,7 +284,8 @@ // async()? if (Value async = op.async()) - printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ")"; + printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": " + << async.getType() << ")"; // wait()? printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer); @@ -278,17 +293,17 @@ // num_gangs()? if (Value numGangs = op.numGangs()) printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs - << ")"; + << ": " << numGangs.getType() << ")"; // num_workers()? if (Value numWorkers = op.numWorkers()) printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers - << ")"; + << ": " << numWorkers.getType() << ")"; // vector_length()?
if (Value vectorLength = op.vectorLength()) printer << " " << ParallelOp::getVectorLengthKeyword() << "(" - << vectorLength << ")"; + << vectorLength << ": " << vectorLength.getType() << ")"; // if()? if (Value ifCond = op.ifCond()) diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -8,8 +8,9 @@ %c0 = constant 0 : index %c10 = constant 10 : index %c1 = constant 1 : index + %async = constant 1 : i64 - acc.parallel async(%c1) { + acc.parallel async(%async: i64) { acc.loop gang vector { scf.for %arg3 = %c0 to %c10 step %c1 { scf.for %arg4 = %c0 to %c10 step %c1 { @@ -35,7 +36,8 @@ // CHECK-NEXT: %{{.*}} = constant 0 : index // CHECK-NEXT: %{{.*}} = constant 10 : index // CHECK-NEXT: %{{.*}} = constant 1 : index -// CHECK-NEXT: acc.parallel async(%{{.*}}) { +// CHECK-NEXT: [[ASYNC:%.*]] = constant 1 : i64 +// CHECK-NEXT: acc.parallel async([[ASYNC]]: i64) { // CHECK-NEXT: acc.loop gang vector { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { @@ -113,9 +115,11 @@ %lb = constant 0 : index %st = constant 1 : index %c10 = constant 10 : index + %numGangs = constant 10 : i64 + %numWorkers = constant 10 : i64 acc.data present(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) { - acc.parallel num_gangs(%c10) num_workers(%c10) private(%c : memref<10xf32>) { + acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(%c : memref<10xf32>) { acc.loop gang { scf.for %x = %lb to %c10 step %st { acc.loop worker { @@ -154,8 +158,10 @@ // CHECK-NEXT: [[C0:%.*]] = constant 0 : index // CHECK-NEXT: [[C1:%.*]] = constant 1 : index // CHECK-NEXT: [[C10:%.*]] = constant 10 : index +// CHECK-NEXT: [[NUMGANG:%.*]] = constant 10 : i64 +// CHECK-NEXT: [[NUMWORKERS:%.*]] = constant 10 : i64 // CHECK-NEXT: acc.data present(%{{.*}}:
memref<10x10xf32>, %{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10xf32>, %{{.*}}: memref<10xf32>) { -// CHECK-NEXT: acc.parallel num_gangs([[C10]]) num_workers([[C10]]) private([[ARG2]]: memref<10xf32>) { +// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]]: i64) num_workers([[NUMWORKERS]]: i64) private([[ARG2]]: memref<10xf32>) { // CHECK-NEXT: acc.loop gang { // CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { // CHECK-NEXT: acc.loop worker { @@ -267,12 +273,77 @@ func @testparallelop() -> () { - %vectorLength = constant 128 : index - acc.parallel vector_length(%vectorLength) { + %i64value = constant 1 : i64 + %i32value = constant 1 : i32 + %idxValue = constant 1 : index + acc.parallel async(%i64value: i64) { + } + acc.parallel async(%i32value: i32) { + } + acc.parallel async(%idxValue: index) { + } + acc.parallel wait(%i64value: i64) { + } + acc.parallel wait(%i32value: i32) { + } + acc.parallel wait(%idxValue: index) { + } + acc.parallel wait(%i64value: i64, %i32value: i32, %idxValue: index) { + } + acc.parallel num_gangs(%i64value: i64) { + } + acc.parallel num_gangs(%i32value: i32) { + } + acc.parallel num_gangs(%idxValue: index) { + } + acc.parallel num_workers(%i64value: i64) { + } + acc.parallel num_workers(%i32value: i32) { + } + acc.parallel num_workers(%idxValue: index) { + } + acc.parallel vector_length(%i64value: i64) { + } + acc.parallel vector_length(%i32value: i32) { + } + acc.parallel vector_length(%idxValue: index) { } return } -// CHECK: [[VECTORLENGTH:%.*]] = constant 128 : index -// CHECK-NEXT: acc.parallel vector_length([[VECTORLENGTH]]) { +// CHECK-LABEL: func @testparallelop +// CHECK: [[I64VALUE:%.*]] = constant 1 : i64 +// CHECK: [[I32VALUE:%.*]] = constant 1 : i32 +// CHECK: [[IDXVALUE:%.*]] = constant 1 : index +// CHECK: acc.parallel async([[I64VALUE]]: i64) { +// CHECK-NEXT: } +// CHECK: acc.parallel async([[I32VALUE]]: i32) { +// CHECK-NEXT: } +// CHECK: acc.parallel async([[IDXVALUE]]: index) { +// CHECK-NEXT: } +// CHECK:
acc.parallel wait([[I64VALUE]]: i64) { +// CHECK-NEXT: } +// CHECK: acc.parallel wait([[I32VALUE]]: i32) { +// CHECK-NEXT: } +// CHECK: acc.parallel wait([[IDXVALUE]]: index) { +// CHECK-NEXT: } +// CHECK: acc.parallel wait([[I64VALUE]]: i64, [[I32VALUE]]: i32, [[IDXVALUE]]: index) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_gangs([[I64VALUE]]: i64) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_gangs([[I32VALUE]]: i32) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_gangs([[IDXVALUE]]: index) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_workers([[I64VALUE]]: i64) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_workers([[I32VALUE]]: i32) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_workers([[IDXVALUE]]: index) { +// CHECK-NEXT: } +// CHECK: acc.parallel vector_length([[I64VALUE]]: i64) { +// CHECK-NEXT: } +// CHECK: acc.parallel vector_length([[I32VALUE]]: i32) { +// CHECK-NEXT: } +// CHECK: acc.parallel vector_length([[IDXVALUE]]: index) { // CHECK-NEXT: }