Summary: add an optional gang "dim" operand to acc.loop.

* OpenACCOps.td: LoopOp gains an optional $gangDim operand, declared
  between $gangNum and $gangStatic, and a getGangDimKeyword() helper
  returning "dim". The gang custom assembly directive now also receives
  $gangDim and type($gangDim).
* mlir OpenACC.cpp: parseGangClause additionally tries the "dim" keyword
  (after num, before static) and reports "expect at least one of num,
  dim or static values" when none of the three is present.
  printGangClause prints dim=<value> : <type> after num and before
  static, separated by ", ". LoopOp's verifier rejects an op that has
  both num and dim ("num is not allowed when dim is specified").
* flang Lower/OpenACC.cpp: a gangDim value is declared and passed to
  addOperand between gangNum and gangStatic, so the operand segments
  are built in the same order as the operands declared in
  OpenACCOps.td.
* Tests: invalid.mlir updates the expected diagnostic for an empty
  gang() and adds a num+dim rejection case. ops.mlir adds a dim+static
  round-trip check.

NOTE(review): none of the lines added to the flang file assign gangDim.
Lowering therefore presumably always passes a null dim value for now.
Confirm whether a follow-up lowers the Fortran gang dim argument.
NOTE(review): the verifier checks only that num and dim are mutually
exclusive; it does not inspect the dim value. If the OpenACC spec
restricts dim to the constants 1, 2 or 3, that is not enforced here.
TODO confirm.
NOTE(review): in this copy of the patch, several template argument
lists appear to be stripped (e.g. "Optional:$gangDim",
"llvm::SmallVector operands", "std::optional &gangDim"). The hunk line
breaks are also collapsed, so this text will not apply as-is.

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -638,6 +638,7 @@ mlir::Value workerNum; mlir::Value vectorNum; mlir::Value gangNum; + mlir::Value gangDim; mlir::Value gangStatic; llvm::SmallVector tileOperands, privateOperands, reductionOperands; @@ -720,6 +721,7 @@ llvm::SmallVector operands; llvm::SmallVector operandSegments; addOperand(operands, operandSegments, gangNum); + addOperand(operands, operandSegments, gangDim); addOperand(operands, operandSegments, gangStatic); addOperand(operands, operandSegments, workerNum); addOperand(operands, operandSegments, vectorNum); diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -815,6 +815,7 @@ let arguments = (ins OptionalAttr:$collapse, Optional:$gangNum, + Optional:$gangDim, Optional:$gangStatic, Optional:$workerNum, Optional:$vectorLength, @@ -836,13 +837,14 @@ let extraClassDeclaration = [{ static StringRef getAutoAttrStrName() { return "auto"; } static StringRef getGangNumKeyword() { return "num"; } + static StringRef getGangDimKeyword() { return "dim"; } static StringRef getGangStaticKeyword() { return "static"; } }]; let hasCustomAssemblyFormat = 1; let assemblyFormat = [{ oilist( - `gang` `` custom($gangNum, type($gangNum), $gangStatic, type($gangStatic), $hasGang) + `gang` `` custom($gangNum, type($gangNum), $gangDim, type($gangDim), $gangStatic, type($gangStatic), $hasGang) | `worker` `` custom($workerNum, type($workerNum), $hasWorker) | `vector` `` custom($vectorLength, type($vectorLength), $hasVector) | `private` `(` $privateOperands `:` type($privateOperands) `)` diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp 
@@ -406,14 +406,15 @@ return success(); } -static ParseResult -parseGangClause(OpAsmParser &parser, - std::optional &gangNum, - Type &gangNumType, - std::optional &gangStatic, - Type &gangStaticType, UnitAttr &hasGang) { +static ParseResult parseGangClause( + OpAsmParser &parser, std::optional &gangNum, + Type &gangNumType, std::optional &gangDim, + Type &gangDimType, + std::optional &gangStatic, + Type &gangStaticType, UnitAttr &hasGang) { hasGang = UnitAttr::get(parser.getBuilder().getContext()); gangNum = std::nullopt; + gangDim = std::nullopt; gangStatic = std::nullopt; bool needComa = false; @@ -432,6 +433,9 @@ if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(), gangNum, gangNumType, needComa, newValue))) return failure(); + if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(), gangDim, + gangDimType, needComa, newValue))) + return failure(); if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(), gangStatic, gangStaticType, needComa, newValue))) @@ -447,9 +451,9 @@ break; } - if (!gangNum && !gangStatic) { + if (!gangNum && !gangDim && !gangStatic) { parser.emitError(parser.getCurrentLocation(), - "expect num and/or static value(s)"); + "expect at least one of num, dim or static values"); return failure(); } @@ -460,13 +464,19 @@ } void printGangClause(OpAsmPrinter &p, Operation *op, Value gangNum, - Type gangNumType, Value gangStatic, Type gangStaticType, - UnitAttr hasGang) { - if (gangNum || gangStatic) { + Type gangNumType, Value gangDim, Type gangDimType, + Value gangStatic, Type gangStaticType, UnitAttr hasGang) { + if (gangNum || gangStatic || gangDim) { p << "("; if (gangNum) { p << LoopOp::getGangNumKeyword() << "=" << gangNum << " : " << gangNumType; + if (gangStatic || gangDim) + p << ", "; + } + if (gangDim) { + p << LoopOp::getGangDimKeyword() << "=" << gangDim << " : " + << gangDimType; if (gangStatic) p << ", "; } @@ -535,6 +545,10 @@ if (getRegion().empty()) return emitError("expected non-empty body."); + // num 
is not allowed when dim is specified. + if (getGangNum() && getGangDim()) + return emitError("num is not allowed when dim is specified"); + return success(); } diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -277,8 +277,17 @@ // ----- -// expected-error@+1 {{expect num and/or static value(s)}} +// expected-error@+1 {{expect at least one of num, dim or static values}} acc.loop gang() { "test.openacc_dummy_op"() : () -> () acc.yield } + +// ----- + +%i64Value = arith.constant 1 : i64 +// expected-error@+1 {{num is not allowed when dim is specified}} +acc.loop gang(dim=%i64Value: i64, num=%i64Value: i64) { + "test.openacc_dummy_op"() : () -> () + acc.yield +} diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -266,6 +266,10 @@ "test.openacc_dummy_op"() : () -> () acc.yield } + acc.loop gang(dim=%i64Value : i64, static=%i64Value: i64) { + "test.openacc_dummy_op"() : () -> () + acc.yield + } return } @@ -332,6 +336,10 @@ // CHECK-NEXT: "test.openacc_dummy_op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } +// CHECK: acc.loop gang(dim=[[I64VALUE]] : i64, static=[[I64VALUE]] : i64) { +// CHECK-NEXT: "test.openacc_dummy_op"() : () -> () +// CHECK-NEXT: acc.yield +// CHECK-NEXT: } // -----