diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -81,11 +81,11 @@ ``` }]; - let arguments = (ins Optional:$async, - Variadic:$waitOperands, - Optional:$numGangs, - Optional:$numWorkers, - Optional:$vectorLength, + let arguments = (ins Optional:$async, + Variadic:$waitOperands, + Optional:$numGangs, + Optional:$numWorkers, + Optional:$vectorLength, Optional:$ifCond, Optional:$selfCond, OptionalAttr:$reductionOp, diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -138,12 +138,12 @@ bool hasAsync = false, hasNumGangs = false, hasNumWorkers = false; bool hasVectorLength = false, hasIfCond = false, hasSelfCond = false; - Type indexType = builder.getIndexType(); + Type i64Type = builder.getIntegerType(64); Type i1Type = builder.getI1Type(); // async()? if (failed(parseOptionalOperand(parser, ParallelOp::getAsyncKeyword(), async, - indexType, hasAsync, result))) + i64Type, hasAsync, result))) return failure(); // wait()? @@ -153,18 +153,17 @@ // num_gangs(value)? if (failed(parseOptionalOperand(parser, ParallelOp::getNumGangsKeyword(), - numGangs, indexType, hasNumGangs, result))) + numGangs, i64Type, hasNumGangs, result))) return failure(); // num_workers(value)? if (failed(parseOptionalOperand(parser, ParallelOp::getNumWorkersKeyword(), - numWorkers, indexType, hasNumWorkers, - result))) + numWorkers, i64Type, hasNumWorkers, result))) return failure(); // vector_length(value)? if (failed(parseOptionalOperand(parser, ParallelOp::getVectorLengthKeyword(), - vectorLength, indexType, hasVectorLength, + vectorLength, i64Type, hasVectorLength, result))) return failure(); diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -8,8 +8,9 @@ %c0 = constant 0 : index %c10 = constant 10 : index %c1 = constant 1 : index + %async = constant 1 : i64 - acc.parallel async(%c1) { + acc.parallel async(%async) { acc.loop gang vector { scf.for %arg3 = %c0 to %c10 step %c1 { scf.for %arg4 = %c0 to %c10 step %c1 { @@ -35,7 +36,8 @@ // CHECK-NEXT: %{{.*}} = constant 0 : index // CHECK-NEXT: %{{.*}} = constant 10 : index // CHECK-NEXT: %{{.*}} = constant 1 : index -// CHECK-NEXT: acc.parallel async(%{{.*}}) { +// CHECK-NEXT: [[ASYNC:%.*]] = constant 1 : i64 +// CHECK-NEXT: acc.parallel async([[ASYNC]]) { // CHECK-NEXT: acc.loop gang vector { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { @@ -113,9 +115,11 @@ %lb = constant 0 : index %st = constant 1 : index %c10 = constant 10 : index + %numGangs = constant 10 : i64 + %numWorkers = constant 10 : i64 acc.data present(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) { - acc.parallel num_gangs(%c10) num_workers(%c10) private(%c : memref<10xf32>) { + acc.parallel num_gangs(%numGangs) num_workers(%numWorkers) private(%c : memref<10xf32>) { acc.loop gang { scf.for %x = %lb to %c10 step %st { acc.loop worker { @@ -154,8 +158,10 @@ // CHECK-NEXT: [[C0:%.*]] = constant 0 : index // CHECK-NEXT: [[C1:%.*]] = constant 1 : index // CHECK-NEXT: [[C10:%.*]] = constant 10 : index +// CHECK-NEXT: [[NUMGANG:%.*]] = constant 10 : i64 +// CHECK-NEXT: [[NUMWORKERS:%.*]] = constant 10 : i64 // CHECK-NEXT: acc.data present(%{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10xf32>, %{{.*}}: memref<10xf32>) { -// CHECK-NEXT: acc.parallel num_gangs([[C10]]) num_workers([[C10]]) private([[ARG2]]: memref<10xf32>) { +// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]]) num_workers([[NUMWORKERS]]) private([[ARG2]]: memref<10xf32>) { // CHECK-NEXT: acc.loop gang { // CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { // CHECK-NEXT: acc.loop worker { @@ -267,12 +273,24 @@ func @testparallelop() -> () { - %vectorLength = constant 128 : index + %vectorLength = constant 128 : i64 + %vectorLength32 = constant 128 : i64 acc.parallel vector_length(%vectorLength) { } + acc.parallel vector_length(%vectorLength32) { + } + %wait = constant 1 : i64 + acc.parallel wait(%wait: i64) { + } return } -// CHECK: [[VECTORLENGTH:%.*]] = constant 128 : index +// CHECK: [[VECTORLENGTH:%.*]] = constant 128 : i64 +// CHECK: [[VECTORLENGTH32:%.*]] = constant 128 : i64 // CHECK-NEXT: acc.parallel vector_length([[VECTORLENGTH]]) { // CHECK-NEXT: } +// CHECK-NEXT: acc.parallel vector_length([[VECTORLENGTH32]]) { +// CHECK-NEXT: } +// CHECK-NEXT: [[WAIT:%.*]] = constant 1 : i64 +// CHECK-NEXT: acc.parallel wait([[WAIT]]: i64) { +// CHECK-NEXT: }