[mlir][openacc] acc.parallel: accept integer or index operands (review notes)

This unified diff touches three files:
  * mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td: adds the type
    constraint `def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>;` and
    rewrites the async, waitOperands, numGangs, numWorkers and vectorLength
    entries of the parallel op's `arguments` list.
  * mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp: adds the static helper
    parseOptionalOperandAndType(parser, keyword, result). The helper:
      - parses `keyword(%operand : type)`;
      - resolves the operand with the parsed type into result.operands;
      - returns llvm::None when the keyword is absent, and failure() on a
        malformed clause.
    ParallelOp's parser now uses this helper for async, num_gangs,
    num_workers and vector_length, replacing the fixed index type. The
    operand segment sizes for those clauses come from
    OptionalParseResult::hasValue(). The printer now emits `: type` after
    each of those operands.
  * mlir/test/Dialect/OpenACC/ops.mlir: moves existing tests to the
    `(%v: type)` form. It also rewrites testparallelop to cover i64, i32
    and index operands on async, wait, num_gangs, num_workers and
    vector_length.

NOTE(review): this copy of the patch is damaged. Do not apply it as-is;
regenerate it from source control.
  - Newlines were collapsed, so the diff is no longer line-structured and
    `git apply` / `patch` cannot consume it.
  - Angle-bracketed template arguments were stripped. Examples:
    `Optional:$async`, `Variadic:$waitOperands`, `Optional:$ifCond`,
    `OptionalAttr:$reductionOp`, `SmallVector operandTypes;` and
    `static_cast(...)`.
    Brackets that do not look like HTML tags survived, such as
    `AnyTypeOf<[AnyInteger, Index]>` and `memref<10x10xf32>`. That pattern
    is consistent with HTML-tag stripping.
  - The originals were presumably `Optional<IntOrIndex>` and
    `Variadic<IntOrIndex>` on the added lines, and `Optional<Index>` /
    `Variadic<Index>` on the removed lines. They were presumably also
    `Optional<I1>`, `OptionalAttr<StrAttr>`, `SmallVector<Type, 8>` and
    `static_cast<int32_t>`. TODO confirm against the upstream commit.

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -60,6 +60,9 @@ let cppNamespace = "::mlir::acc"; } +// Type used in operation below +def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>; + //===----------------------------------------------------------------------===// // 2.5.1 parallel Construct //===----------------------------------------------------------------------===// @@ -81,11 +84,11 @@ ``` }]; - let arguments = (ins Optional:$async, - Variadic:$waitOperands, - Optional:$numGangs, - Optional:$numWorkers, - Optional:$vectorLength, + let arguments = (ins Optional:$async, + Variadic:$waitOperands, + Optional:$numGangs, + Optional:$numWorkers, + Optional:$vectorLength, Optional:$ifCond, Optional:$selfCond, OptionalAttr:$reductionOp, diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -101,6 +101,22 @@ return success(); } +static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, + StringRef keyword, + OperationState &result) { + OpAsmParser::OperandType operand; + Type type; + if (succeeded(parser.parseOptionalKeyword(keyword))) { + if (parser.parseLParen() || parser.parseOperand(operand) || + parser.parseColonType(type) || + parser.resolveOperand(operand, type, result.operands) || + parser.parseRParen()) + return failure(); + return success(); + } + return llvm::None; +} + //===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// @@ -133,17 +149,15 @@ copyoutOperands, noCreateOperands, presentOperands, devicePtrOperands, attachOperands, waitOperands, reductionOperands; SmallVector operandTypes; - OpAsmParser::OperandType 
async, numGangs, numWorkers, vectorLength, ifCond, - selfCond; - bool hasAsync = false, hasNumGangs = false, hasNumWorkers = false; - bool hasVectorLength = false, hasIfCond = false, hasSelfCond = false; - - Type indexType = builder.getIndexType(); + OpAsmParser::OperandType ifCond, selfCond; + bool hasIfCond = false, hasSelfCond = false; + OptionalParseResult async, numGangs, numWorkers, vectorLength; Type i1Type = builder.getI1Type(); // async()? - if (failed(parseOptionalOperand(parser, ParallelOp::getAsyncKeyword(), async, - indexType, hasAsync, result))) + async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(), + result); + if (async.hasValue() && failed(*async)) return failure(); // wait()? @@ -152,20 +166,21 @@ return failure(); // num_gangs(value)? - if (failed(parseOptionalOperand(parser, ParallelOp::getNumGangsKeyword(), - numGangs, indexType, hasNumGangs, result))) + numGangs = parseOptionalOperandAndType( + parser, ParallelOp::getNumGangsKeyword(), result); + if (numGangs.hasValue() && failed(*numGangs)) return failure(); // num_workers(value)? - if (failed(parseOptionalOperand(parser, ParallelOp::getNumWorkersKeyword(), - numWorkers, indexType, hasNumWorkers, - result))) + numWorkers = parseOptionalOperandAndType( + parser, ParallelOp::getNumWorkersKeyword(), result); + if (numWorkers.hasValue() && failed(*numWorkers)) return failure(); // vector_length(value)? - if (failed(parseOptionalOperand(parser, ParallelOp::getVectorLengthKeyword(), - vectorLength, indexType, hasVectorLength, - result))) + vectorLength = parseOptionalOperandAndType( + parser, ParallelOp::getVectorLengthKeyword(), result); + if (vectorLength.hasValue() && failed(*vectorLength)) return failure(); // if()? @@ -237,26 +252,27 @@ if (failed(parseRegions(parser, result))) return failure(); - result.addAttribute(ParallelOp::getOperandSegmentSizeAttr(), - builder.getI32VectorAttr( - {static_cast(hasAsync ? 
1 : 0), - static_cast(waitOperands.size()), - static_cast(hasNumGangs ? 1 : 0), - static_cast(hasNumWorkers ? 1 : 0), - static_cast(hasVectorLength ? 1 : 0), - static_cast(hasIfCond ? 1 : 0), - static_cast(hasSelfCond ? 1 : 0), - static_cast(reductionOperands.size()), - static_cast(copyOperands.size()), - static_cast(copyinOperands.size()), - static_cast(copyoutOperands.size()), - static_cast(createOperands.size()), - static_cast(noCreateOperands.size()), - static_cast(presentOperands.size()), - static_cast(devicePtrOperands.size()), - static_cast(attachOperands.size()), - static_cast(privateOperands.size()), - static_cast(firstprivateOperands.size())})); + result.addAttribute( + ParallelOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr( + {static_cast(async.hasValue() ? 1 : 0), + static_cast(waitOperands.size()), + static_cast(numGangs.hasValue() ? 1 : 0), + static_cast(numWorkers.hasValue() ? 1 : 0), + static_cast(vectorLength.hasValue() ? 1 : 0), + static_cast(hasIfCond ? 1 : 0), + static_cast(hasSelfCond ? 1 : 0), + static_cast(reductionOperands.size()), + static_cast(copyOperands.size()), + static_cast(copyinOperands.size()), + static_cast(copyoutOperands.size()), + static_cast(createOperands.size()), + static_cast(noCreateOperands.size()), + static_cast(presentOperands.size()), + static_cast(devicePtrOperands.size()), + static_cast(attachOperands.size()), + static_cast(privateOperands.size()), + static_cast(firstprivateOperands.size())})); // Additional attributes if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) @@ -270,7 +286,8 @@ // async()? if (Value async = op.async()) - printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ")"; + printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": " + << async.getType() << ")"; // wait()? printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer); @@ -278,17 +295,17 @@ // num_gangs()? 
if (Value numGangs = op.numGangs()) printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs - << ")"; + << ": " << numGangs.getType() << ")"; // num_workers()? if (Value numWorkers = op.numWorkers()) printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers - << ")"; + << ": " << numWorkers.getType() << ")"; // vector_length()? if (Value vectorLength = op.vectorLength()) printer << " " << ParallelOp::getVectorLengthKeyword() << "(" - << vectorLength << ")"; + << vectorLength << ": " << vectorLength.getType() << ")"; // if()? if (Value ifCond = op.ifCond()) diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -8,8 +8,9 @@ %c0 = constant 0 : index %c10 = constant 10 : index %c1 = constant 1 : index + %async = constant 1 : i64 - acc.parallel async(%c1) { + acc.parallel async(%async: i64) { acc.loop gang vector { scf.for %arg3 = %c0 to %c10 step %c1 { scf.for %arg4 = %c0 to %c10 step %c1 { @@ -35,7 +36,8 @@ // CHECK-NEXT: %{{.*}} = constant 0 : index // CHECK-NEXT: %{{.*}} = constant 10 : index // CHECK-NEXT: %{{.*}} = constant 1 : index -// CHECK-NEXT: acc.parallel async(%{{.*}}) { +// CHECK-NEXT: [[ASYNC:%.*]] = constant 1 : i64 +// CHECK-NEXT: acc.parallel async([[ASYNC]]: i64) { // CHECK-NEXT: acc.loop gang vector { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { @@ -113,9 +115,11 @@ %lb = constant 0 : index %st = constant 1 : index %c10 = constant 10 : index + %numGangs = constant 10 : i64 + %numWorkers = constant 10 : i64 acc.data present(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) { - acc.parallel num_gangs(%c10) num_workers(%c10) private(%c : memref<10xf32>) { + acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(%c : memref<10xf32>) { acc.loop gang { scf.for %x = 
%lb to %c10 step %st { acc.loop worker { @@ -154,8 +158,10 @@ // CHECK-NEXT: [[C0:%.*]] = constant 0 : index // CHECK-NEXT: [[C1:%.*]] = constant 1 : index // CHECK-NEXT: [[C10:%.*]] = constant 10 : index +// CHECK-NEXT: [[NUMGANG:%.*]] = constant 10 : i64 +// CHECK-NEXT: [[NUMWORKERS:%.*]] = constant 10 : i64 // CHECK-NEXT: acc.data present(%{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10xf32>, %{{.*}}: memref<10xf32>) { -// CHECK-NEXT: acc.parallel num_gangs([[C10]]) num_workers([[C10]]) private([[ARG2]]: memref<10xf32>) { +// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]]: i64) num_workers([[NUMWORKERS]]: i64) private([[ARG2]]: memref<10xf32>) { // CHECK-NEXT: acc.loop gang { // CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { // CHECK-NEXT: acc.loop worker { @@ -267,12 +273,77 @@ func @testparallelop() -> () { - %vectorLength = constant 128 : index - acc.parallel vector_length(%vectorLength) { + %i64value = constant 1 : i64 + %i32value = constant 1 : i32 + %idxValue = constant 1 : index + acc.parallel async(%i64value: i64) { + } + acc.parallel async(%i32value: i32) { + } + acc.parallel async(%idxValue: index) { + } + acc.parallel wait(%i64value: i64) { + } + acc.parallel wait(%i32value: i32) { + } + acc.parallel wait(%idxValue: index) { + } + acc.parallel wait(%i64value: i64, %i32value: i32, %idxValue: index) { + } + acc.parallel num_gangs(%i64value: i64) { + } + acc.parallel num_gangs(%i32value: i32) { + } + acc.parallel num_gangs(%idxValue: index) { + } + acc.parallel num_workers(%i64value: i64) { + } + acc.parallel num_workers(%i32value: i32) { + } + acc.parallel num_workers(%idxValue: index) { + } + acc.parallel vector_length(%i64value: i64) { + } + acc.parallel vector_length(%i32value: i32) { + } + acc.parallel vector_length(%idxValue: index) { } return } -// CHECK: [[VECTORLENGTH:%.*]] = constant 128 : index -// CHECK-NEXT: acc.parallel vector_length([[VECTORLENGTH]]) { +// CHECK-LABEL: func @testparallelop +// 
CHECK: [[I64VALUE:%.*]] = constant 1 : i64 +// CHECK: [[I32VALUE:%.*]] = constant 1 : i32 +// CHECK: [[IDXVALUE:%.*]] = constant 1 : index +// CHECK: acc.parallel async([[I64VALUE]]: i64) { +// CHECK-NEXT: } +// CHECK: acc.parallel async([[I32VALUE]]: i32) { +// CHECK-NEXT: } +// CHECK: acc.parallel async([[IDXVALUE]]: index) { +// CHECK-NEXT: } +// CHECK: acc.parallel wait([[I64VALUE]]: i64) { +// CHECK-NEXT: } +// CHECK: acc.parallel wait([[I32VALUE]]: i32) { +// CHECK-NEXT: } +// CHECK: acc.parallel wait([[IDXVALUE]]: index) { +// CHECK-NEXT: } +// CHECK: acc.parallel wait([[I64VALUE]]: i64, [[I32VALUE]]: i32, [[IDXVALUE]]: index) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_gangs([[I64VALUE]]: i64) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_gangs([[I32VALUE]]: i32) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_gangs([[IDXVALUE]]: index) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_workers([[I64VALUE]]: i64) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_workers([[I32VALUE]]: i32) { +// CHECK-NEXT: } +// CHECK: acc.parallel num_workers([[IDXVALUE]]: index) { +// CHECK-NEXT: } +// CHECK: acc.parallel vector_length([[I64VALUE]]: i64) { +// CHECK-NEXT: } +// CHECK: acc.parallel vector_length([[I32VALUE]]: i32) { +// CHECK-NEXT: } +// CHECK: acc.parallel vector_length([[IDXVALUE]]: index) { // CHECK-NEXT: }