[mlir][openacc] Make num_gangs variadic (at most 3 values)

Review notes. This is free-form text before the first "diff --git" header;
patch tools skip it.

* OpenACCOps.td: the `numGangs` argument changes from `Optional` to
  `Variadic` in two op argument lists (the hunks at lines 656 and 802).
  These are presumably acc.parallel and acc.kernels; confirm against the
  file.
* OpenACC.cpp:
  - ParallelOp::getDataOperand now counts every num_gangs operand
    (getNumGangs().size()) instead of at most one.
  - ParallelOp::verify and KernelsOp::verify emit
    "num_gangs expects a maximum of 3 values" when more than 3 values
    are given.
  - KernelsOp::getDataOperand now also skips the num_gangs, num_workers
    and vector_length operands before indexing into the data operands,
    and folds the wait operands into the same counter. The old code
    skipped only async, wait, if and self. It presumably returned the
    wrong operand whenever one of the newly counted clauses was present.
* Tests:
  - invalid.mlir checks the error for a 4-value num_gangs on
    acc.parallel.
  - ops.mlir round-trips a 3-value num_gangs with mixed i64/index
    operands on acc.parallel.

NOTE(review): this copy appears damaged. Its line breaks are collapsed,
and the TableGen constraints seem to have lost their template arguments
(`Variadic:$numGangs` rather than `Variadic<...>:$numGangs`). It will
therefore not apply as-is. TODO: regenerate from the original commit.

NOTE(review): both getDataOperand implementations assume the data clause
operands come after the counted operands, in declaration order. Confirm
this against the .td argument lists.

TODO(review): this patch adds no test for the num_gangs limit on
acc.kernels.

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -656,7 +656,7 @@ UnitAttr:$asyncAttr, Variadic:$waitOperands, UnitAttr:$waitAttr, - Optional:$numGangs, + Variadic:$numGangs, Optional:$numWorkers, Optional:$vectorLength, Optional:$ifCond, @@ -802,7 +802,7 @@ UnitAttr:$asyncAttr, Variadic:$waitOperands, UnitAttr:$waitAttr, - Optional:$numGangs, + Variadic:$numGangs, Optional:$numWorkers, Optional:$vectorLength, Optional:$ifCond, diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -573,7 +573,7 @@ Value ParallelOp::getDataOperand(unsigned i) { unsigned numOptional = getAsync() ? 1 : 0; - numOptional += getNumGangs() ? 1 : 0; + numOptional += getNumGangs().size(); numOptional += getNumWorkers() ? 1 : 0; numOptional += getVectorLength() ? 1 : 0; numOptional += getIfCond() ? 1 : 0; @@ -590,6 +590,8 @@ *this, getReductionRecipes(), getReductionOperands(), "reduction", "reductions", false))) return failure(); + if (getNumGangs().size() > 3) + return emitOpError() << "num_gangs expects a maximum of 3 values"; return checkDataOperands(*this, getDataClauseOperands()); } @@ -631,12 +633,18 @@ Value KernelsOp::getDataOperand(unsigned i) { unsigned numOptional = getAsync() ? 1 : 0; + numOptional += getWaitOperands().size(); + numOptional += getNumGangs().size(); + numOptional += getNumWorkers() ? 1 : 0; + numOptional += getVectorLength() ? 1 : 0; numOptional += getIfCond() ? 1 : 0; numOptional += getSelfCond() ?
1 : 0; - return getOperand(getWaitOperands().size() + numOptional + i); + return getOperand(numOptional + i); } LogicalResult acc::KernelsOp::verify() { + if (getNumGangs().size() > 3) + return emitOpError() << "num_gangs expects a maximum of 3 values"; return checkDataOperands(*this, getDataClauseOperands()); } diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -486,3 +486,10 @@ "test.openacc_dummy_op"() : () -> () acc.yield } + +// ----- + +%i64value = arith.constant 1 : i64 +// expected-error@+1 {{num_gangs expects a maximum of 3 values}} +acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i64, i64) { +} diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -443,6 +443,8 @@ } acc.parallel num_gangs(%idxValue: index) { } + acc.parallel num_gangs(%i64value, %i64value, %idxValue : i64, i64, index) { + } acc.parallel num_workers(%i64value: i64) { } acc.parallel num_workers(%i32value: i32) { @@ -494,6 +496,8 @@ // CHECK-NEXT: } // CHECK: acc.parallel num_gangs([[IDXVALUE]] : index) { // CHECK-NEXT: } +// CHECK: acc.parallel num_gangs([[I64VALUE]], [[I64VALUE]], [[IDXVALUE]] : i64, i64, index) { +// CHECK-NEXT: } // CHECK: acc.parallel num_workers([[I64VALUE]] : i64) { // CHECK-NEXT: } // CHECK: acc.parallel num_workers([[I32VALUE]] : i32) {