diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -31,12 +31,11 @@
 /// 2.9.2. gang
 /// 2.9.3. worker
 /// 2.9.4. vector
-/// 2.9.5. seq
 ///
 /// Value can be combined bitwise to reflect the mapping applied to the
 /// construct. e.g. `acc.loop gang vector`, the `gang` and `vector` could be
 /// combined and the final mapping value would be 5 (4 | 1).
-enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4, SEQ = 8 };
+enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4 };
 
 } // end namespace acc
 } // end namespace mlir
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -224,6 +224,14 @@
 
 
   let arguments = (ins OptionalAttr:$collapse,
+                       Optional:$gangNum,
+                       Optional:$gangStatic,
+                       Optional:$workerNum,
+                       Optional:$vectorLength,
+                       OptionalAttr:$loopSeq,
+                       OptionalAttr:$loopIndependent,
+                       OptionalAttr:$loopAuto,
+                       Variadic:$tileOperands,
                        Variadic:$privateOperands,
                        OptionalAttr:$reductionOp,
                        Variadic:$reductionOperands);
@@ -235,10 +243,13 @@
   let extraClassDeclaration = [{
     static StringRef getCollapseAttrName() { return "collapse"; }
     static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
-    static StringRef getGangAttrName() { return "gang"; }
-    static StringRef getSeqAttrName() { return "seq"; }
-    static StringRef getVectorAttrName() { return "vector"; }
-    static StringRef getWorkerAttrName() { return "worker"; }
+    static StringRef getGangKeyword() { return "gang"; }
+    static StringRef getGangNumKeyword() { return "num"; }
+    static StringRef getGangStaticKeyword() { return "static"; }
+    static StringRef getSeqKeyword() { return "seq"; }
+    static StringRef getVectorKeyword() { return "vector"; }
+    static StringRef getWorkerKeyword() { return "worker"; }
+    static StringRef getTileKeyword() { return "tile"; }
     static StringRef getPrivateKeyword() { return "private"; }
     static StringRef getReductionKeyword() { return "reduction"; }
   }];
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -476,7 +476,12 @@
 //===----------------------------------------------------------------------===//
 
 /// Parse acc.loop operation
-/// operation := `acc.loop` `gang`? `vector`? `worker`? `seq`?
+/// operation := `acc.loop`
+///                (`gang` (`(` (`num` `:` value)? `,`?
+///                         (`static` `:` value)? `)`)?)?
+///                (`worker` (`(` value `)`)?)?
+///                (`vector` (`(` value `)`)?)?
+///                (`tile` `(` value-list `)`)?
 ///                `private` `(` value-list `)`?
 ///                `reduction` `(` value-list `)`?
 ///                region attr-dict?
@@ -486,22 +491,74 @@
   unsigned executionMapping = 0;
   SmallVector operandTypes;
   SmallVector privateOperands, reductionOperands;
+  SmallVector tileOperands;
+  bool hasWorkerNum = false, hasVectorLength = false, hasGangNum = false;
+  bool hasGangStatic = false;
+  OpAsmParser::OperandType workerNum, vectorLength, gangNum, gangStatic;
+  Type intType = builder.getI64Type();
 
   // gang?
-  if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangAttrName())))
+  if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
     executionMapping |= OpenACCExecMapping::GANG;
 
-  // vector?
-  if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorAttrName())))
-    executionMapping |= OpenACCExecMapping::VECTOR;
+  // optional gang operands; only accepted directly after the `gang` keyword
+  if ((executionMapping & OpenACCExecMapping::GANG) &&
+      succeeded(parser.parseOptionalLParen())) {
+    if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangNumKeyword()))) {
+      hasGangNum = true;
+      if (parser.parseColon() || parser.parseOperand(gangNum) ||
+          parser.resolveOperand(gangNum, intType, result.operands)) {
+        return failure();
+      }
+    }
+    parser.parseOptionalComma();
+    if (succeeded(
+            parser.parseOptionalKeyword(LoopOp::getGangStaticKeyword()))) {
+      hasGangStatic = true;
+      if (parser.parseColon() || parser.parseOperand(gangStatic) ||
+          parser.resolveOperand(gangStatic, intType, result.operands)) {
+        return failure();
+      }
+    }
+    if (failed(parser.parseRParen())) {
+      return failure();
+    }
+  }
 
   // worker?
-  if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerAttrName())))
+  if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
     executionMapping |= OpenACCExecMapping::WORKER;
 
-  // seq?
-  if (succeeded(parser.parseOptionalKeyword(LoopOp::getSeqAttrName())))
-    executionMapping |= OpenACCExecMapping::SEQ;
+  // optional worker operand; only accepted directly after `worker`
+  if ((executionMapping & OpenACCExecMapping::WORKER) &&
+      succeeded(parser.parseOptionalLParen())) {
+    hasWorkerNum = true;
+    if (parser.parseOperand(workerNum) ||
+        parser.resolveOperand(workerNum, intType, result.operands) ||
+        parser.parseRParen()) {
+      return failure();
+    }
+  }
+
+  // vector?
+  if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
+    executionMapping |= OpenACCExecMapping::VECTOR;
+
+  // optional vector operand; only accepted directly after `vector`
+  if ((executionMapping & OpenACCExecMapping::VECTOR) &&
+      succeeded(parser.parseOptionalLParen())) {
+    hasVectorLength = true;
+    if (parser.parseOperand(vectorLength) ||
+        parser.resolveOperand(vectorLength, intType, result.operands) ||
+        parser.parseRParen()) {
+      return failure();
+    }
+  }
+
+  // tile()?
+  if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
+                              operandTypes, result)))
+    return failure();
 
   // private()?
   if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
@@ -526,7 +583,12 @@
 
   result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
                       builder.getI32VectorAttr(
-                          {static_cast(privateOperands.size()),
+                          {static_cast(hasGangNum ? 1 : 0),
+                           static_cast(hasGangStatic ? 1 : 0),
+                           static_cast(hasWorkerNum ? 1 : 0),
+                           static_cast(hasVectorLength ? 1 : 0),
+                           static_cast(tileOperands.size()),
+                           static_cast(privateOperands.size()),
                            static_cast(reductionOperands.size())}));
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
@@ -544,17 +606,45 @@
           ? op.getAttrOfType(LoopOp::getExecutionMappingAttrName())
                 .getInt()
           : 0;
-  if ((execMapping & OpenACCExecMapping::GANG) == OpenACCExecMapping::GANG)
-    printer << " " << LoopOp::getGangAttrName();
 
-  if ((execMapping & OpenACCExecMapping::WORKER) == OpenACCExecMapping::WORKER)
-    printer << " " << LoopOp::getWorkerAttrName();
+  if ((execMapping & OpenACCExecMapping::GANG) == OpenACCExecMapping::GANG) {
+    printer << " " << LoopOp::getGangKeyword();
+    Value gangNum = op.gangNum();
+    Value gangStatic = op.gangStatic();
+
+    // Print optional gang operands
+    if (gangNum || gangStatic) {
+      printer << "(";
+      if (gangNum)
+        printer << LoopOp::getGangNumKeyword() << ": " << gangNum;
+      if (gangNum && gangStatic)
+        printer << ", ";
+      if (gangStatic)
+        printer << LoopOp::getGangStaticKeyword() << ": " << gangStatic;
+      printer << ")";
+    }
+  }
 
-  if ((execMapping & OpenACCExecMapping::VECTOR) == OpenACCExecMapping::VECTOR)
-    printer << " " << LoopOp::getVectorAttrName();
+  if ((execMapping & OpenACCExecMapping::WORKER) ==
+      OpenACCExecMapping::WORKER) {
+    printer << " " << LoopOp::getWorkerKeyword();
+
+    // Print optional worker operand if present
+    if (Value workerNum = op.workerNum())
+      printer << "(" << workerNum << ")";
+  }
+
+  if ((execMapping & OpenACCExecMapping::VECTOR) ==
+      OpenACCExecMapping::VECTOR) {
+    printer << " " << LoopOp::getVectorKeyword();
+
+    // Print optional vector operand if present
+    if (Value vectorLength = op.vectorLength())
+      printer << "(" << vectorLength << ")";
+  }
 
-  if ((execMapping & OpenACCExecMapping::SEQ) == OpenACCExecMapping::SEQ)
-    printer << " " << LoopOp::getSeqAttrName();
+  // tile()?
+  printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer);
 
   // private()?
   printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer);
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -62,7 +62,7 @@
   %c1 = constant 1 : index
 
   acc.parallel {
-    acc.loop seq {
+    acc.loop {
       scf.for %arg3 = %c0 to %c10 step %c1 {
         scf.for %arg4 = %c0 to %c10 step %c1 {
           scf.for %arg5 = %c0 to %c10 step %c1 {
@@ -76,7 +76,7 @@
         }
       }
       acc.yield
-    }
+    } attributes {loopSeq = true}
     acc.yield
   }
 
@@ -88,7 +88,7 @@
 // CHECK-NEXT:   %{{.*}} = constant 10 : index
 // CHECK-NEXT:   %{{.*}} = constant 1 : index
 // CHECK-NEXT:   acc.parallel {
-// CHECK-NEXT:     acc.loop seq {
+// CHECK-NEXT:     acc.loop {
 // CHECK-NEXT:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 // CHECK-NEXT:         scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 // CHECK-NEXT:           scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
@@ -102,7 +102,7 @@
 // CHECK-NEXT:         }
 // CHECK-NEXT:       }
 // CHECK-NEXT:       acc.yield
-// CHECK-NEXT:     }
+// CHECK-NEXT:     } attributes {loopSeq = true}
 // CHECK-NEXT:     acc.yield
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return %{{.*}} : memref<10x10xf32>
@@ -128,7 +128,7 @@
             acc.yield
           }
 
-          acc.loop seq {
+          acc.loop {
             // for i = 0 to 10 step 1
             //   d[x] += c[i]
             scf.for %i = %lb to %c10 step %st {
@@ -138,7 +138,7 @@
               store %z, %d[%x] : memref<10xf32>
             }
             acc.yield
-          }
+          } attributes {loopSeq = true}
         }
         acc.yield
       }
@@ -167,7 +167,7 @@
 // CHECK-NEXT:             }
 // CHECK-NEXT:             acc.yield
 // CHECK-NEXT:           }
-// CHECK-NEXT:           acc.loop seq {
+// CHECK-NEXT:           acc.loop {
 // CHECK-NEXT:             scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
 // CHECK-NEXT:               %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
 // CHECK-NEXT:               %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
@@ -175,7 +175,7 @@
 // CHECK-NEXT:               store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
 // CHECK-NEXT:             }
 // CHECK-NEXT:             acc.yield
-// CHECK-NEXT:           }
+// CHECK-NEXT:           } attributes {loopSeq = true}
 // CHECK-NEXT:         }
 // CHECK-NEXT:         acc.yield
 // CHECK-NEXT:       }
@@ -184,4 +184,51 @@
 // CHECK-NEXT:     acc.terminator
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return %{{.*}} : memref<10xf32>
-// CHECK-NEXT: }
\ No newline at end of file
+// CHECK-NEXT: }
+
+func @testop() -> () {
+  %workerNum = constant 1 : i64
+  %vectorLength = constant 128 : i64
+  %gangNum = constant 8 : i64
+  %gangStatic = constant 2 : i64
+  %tileSize = constant 2 : i64
+  acc.loop gang worker vector {
+  }
+  acc.loop gang(num: %gangNum) {
+  }
+  acc.loop gang(static: %gangStatic) {
+  }
+  acc.loop worker(%workerNum) {
+  }
+  acc.loop vector(%vectorLength) {
+  }
+  acc.loop gang(num: %gangNum) worker vector {
+  }
+  acc.loop gang(num: %gangNum, static: %gangStatic) worker(%workerNum) vector(%vectorLength) {
+  }
+  acc.loop tile(%tileSize : i64, %tileSize : i64) {
+  }
+  return
+}
+
+// CHECK: [[WORKERNUM:%.*]] = constant 1 : i64
+// CHECK-NEXT: [[VECTORLENGTH:%.*]] = constant 128 : i64
+// CHECK-NEXT: [[GANGNUM:%.*]] = constant 8 : i64
+// CHECK-NEXT: [[GANGSTATIC:%.*]] = constant 2 : i64
+// CHECK-NEXT: [[TILESIZE:%.*]] = constant 2 : i64
+// CHECK-NEXT: acc.loop gang worker vector {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(static: [[GANGSTATIC]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop worker([[WORKERNUM]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop vector([[VECTORLENGTH]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) worker vector {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]], static: [[GANGSTATIC]]) worker([[WORKERNUM]]) vector([[VECTORLENGTH]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop tile([[TILESIZE]]: i64, [[TILESIZE]]: i64) {
+// CHECK-NEXT: }