diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -244,14 +244,14 @@ let arguments = (ins OptionalAttr:$collapse, - Optional:$gangNum, - Optional:$gangStatic, - Optional:$workerNum, - Optional:$vectorLength, + Optional:$gangNum, + Optional:$gangStatic, + Optional:$workerNum, + Optional:$vectorLength, UnitAttr:$seq, UnitAttr:$independent, UnitAttr:$auto_, - Variadic:$tileOperands, + Variadic:$tileOperands, Variadic:$privateOperands, OptionalAttr:$reductionOp, Variadic:$reductionOperands, diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -545,10 +545,14 @@ //===----------------------------------------------------------------------===// /// Parse acc.loop operation -/// operation := `acc.loop` `gang`? `vector`? `worker`? -/// `private` `(` value-list `)`? -/// `reduction` `(` value-list `)`? -/// region attr-dict? +/// operation := `acc.loop` +///   (`gang` (`(` (`num=` value `:` type)? `,`? +///               (`static=` value `:` type)? `)`)?)? +///   (`worker` (`(` value `:` type `)`)?)? +///   (`vector` (`(` value `:` type `)`)?)? +///   (`tile` `(` (value `:` type)-list `)`)? (`private` `(` value-list `)`)? +///   (`reduction` `(` value-list `)`)? +///   region attr-dict? static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) { Builder &builder = parser.getBuilder(); unsigned executionMapping = OpenACCExecMapping::NONE; @@ -558,7 +562,7 @@ bool hasWorkerNum = false, hasVectorLength = false, hasGangNum = false; bool hasGangStatic = false; OpAsmParser::OperandType workerNum, vectorLength, gangNum, gangStatic; - Type intType = builder.getI64Type(); + Type gangNumType, gangStaticType, workerType, vectorLengthType; // gang?
if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword()))) @@ -568,9 +572,10 @@ if (succeeded(parser.parseOptionalLParen())) { if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangNumKeyword()))) { hasGangNum = true; - parser.parseColon(); - if (parser.parseOperand(gangNum) || - parser.resolveOperand(gangNum, intType, result.operands)) { + // Treat a missing `=` as a parse failure rather than continuing. + if (parser.parseEqual() || parser.parseOperand(gangNum) || + parser.parseColonType(gangNumType) || + parser.resolveOperand(gangNum, gangNumType, result.operands)) { return failure(); } } @@ -578,9 +583,9 @@ if (succeeded( parser.parseOptionalKeyword(LoopOp::getGangStaticKeyword()))) { hasGangStatic = true; - parser.parseColon(); - if (parser.parseOperand(gangStatic) || - parser.resolveOperand(gangStatic, intType, result.operands)) { + if (parser.parseEqual() || parser.parseOperand(gangStatic) || + parser.parseColonType(gangStaticType) || + parser.resolveOperand(gangStatic, gangStaticType, result.operands)) { return failure(); } } @@ -595,8 +600,8 @@ // optional worker operand if (succeeded(parser.parseOptionalLParen())) { hasWorkerNum = true; - if (parser.parseOperand(workerNum) || - parser.resolveOperand(workerNum, intType, result.operands) || + if (parser.parseOperand(workerNum) || parser.parseColonType(workerType) || + parser.resolveOperand(workerNum, workerType, result.operands) || parser.parseRParen()) { return failure(); } @@ -610,7 +615,9 @@ if (succeeded(parser.parseOptionalLParen())) { hasVectorLength = true; if (parser.parseOperand(vectorLength) || - parser.resolveOperand(vectorLength, intType, result.operands) || + parser.parseColonType(vectorLengthType) || + parser.resolveOperand(vectorLength, vectorLengthType, + result.operands) || parser.parseRParen()) { return failure(); } @@ -671,12 +678,14 @@ if (gangNum || gangStatic) { printer << "("; if (gangNum) { - printer << LoopOp::getGangNumKeyword() << ": " << gangNum; + printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": " + << gangNum.getType(); if (gangStatic) printer << ", "; } if
(gangStatic) - printer << LoopOp::getGangStaticKeyword() << ": " << gangStatic; + printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": " + << gangStatic.getType(); printer << ")"; } } @@ -686,7 +695,7 @@ // Print optional worker operand if present if (Value workerNum = op.workerNum()) - printer << "(" << workerNum << ")"; + printer << "(" << workerNum << ": " << workerNum.getType() << ")"; } if (execMapping & OpenACCExecMapping::VECTOR) { @@ -694,7 +703,7 @@ // Print optional vector operand if present if (Value vectorLength = op.vectorLength()) - printer << "(" << vectorLength << ")"; + printer << "(" << vectorLength << ": " << vectorLength.getType() << ")"; } // tile()? diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -1,8 +1,8 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s | FileCheck %s +// RUN: mlir-opt -split-input-file -allow-unregistered-dialect %s | FileCheck %s // Verify the printed output can be parsed. -// RUN: mlir-opt -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt -split-input-file -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect | FileCheck %s // Verify the generic form can be parsed.
-// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> { %c0 = constant 0 : index @@ -58,6 +58,8 @@ // CHECK-NEXT: return %{{.*}} : memref<10x10xf32> // CHECK-NEXT: } +// ----- + func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> { %c0 = constant 0 : index %c10 = constant 10 : index @@ -110,6 +112,7 @@ // CHECK-NEXT: return %{{.*}} : memref<10x10xf32> // CHECK-NEXT: } +// ----- func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> { %lb = constant 0 : index @@ -192,85 +195,133 @@ // CHECK-NEXT: return %{{.*}} : memref<10xf32> // CHECK-NEXT: } -func @testop(%a: memref<10xf32>) -> () { - %workerNum = constant 1 : i64 - %vectorLength = constant 128 : i64 - %gangNum = constant 8 : i64 - %gangStatic = constant 2 : i64 - %tileSize = constant 2 : i64 +// ----- + +func @testloopop() -> () { + %i64Value = constant 1 : i64 + %i32Value = constant 128 : i32 + %idxValue = constant 8 : index + acc.loop gang worker vector { "some.op"() : () -> () acc.yield } - acc.loop gang(num: %gangNum) { + acc.loop gang(num=%i64Value: i64) { + "some.op"() : () -> () + acc.yield + } + acc.loop gang(static=%i64Value: i64) { + "some.op"() : () -> () + acc.yield + } + acc.loop worker(%i64Value: i64) { + "some.op"() : () -> () + acc.yield + } + acc.loop worker(%i32Value: i32) { + "some.op"() : () -> () + acc.yield + } + acc.loop worker(%idxValue: index) { + "some.op"() : () -> () + acc.yield + } + acc.loop vector(%i64Value: i64) { + "some.op"() : () -> () + acc.yield + } + acc.loop vector(%i32Value: i32) { "some.op"() : () -> () acc.yield } - acc.loop gang(static:
%gangStatic) { + acc.loop vector(%idxValue: index) { "some.op"() : () -> () acc.yield } - acc.loop worker(%workerNum) { + acc.loop gang(num=%i64Value: i64) worker vector { "some.op"() : () -> () acc.yield } - acc.loop vector(%vectorLength) { + acc.loop gang(num=%i64Value: i64, static=%i64Value: i64) worker(%i64Value: i64) vector(%i64Value: i64) { "some.op"() : () -> () acc.yield } - acc.loop gang(num: %gangNum) worker vector { + acc.loop gang(num=%i32Value: i32, static=%idxValue: index) { "some.op"() : () -> () acc.yield } - acc.loop gang(num: %gangNum, static: %gangStatic) worker(%workerNum) vector(%vectorLength) { + acc.loop tile(%i64Value: i64, %i64Value: i64) { "some.op"() : () -> () acc.yield } - acc.loop tile(%tileSize : i64, %tileSize : i64) { + acc.loop tile(%i32Value: i32, %i32Value: i32) { "some.op"() : () -> () acc.yield } return } -// CHECK: [[WORKERNUM:%.*]] = constant 1 : i64 -// CHECK-NEXT: [[VECTORLENGTH:%.*]] = constant 128 : i64 -// CHECK-NEXT: [[GANGNUM:%.*]] = constant 8 : i64 -// CHECK-NEXT: [[GANGSTATIC:%.*]] = constant 2 : i64 -// CHECK-NEXT: [[TILESIZE:%.*]] = constant 2 : i64 -// CHECK-NEXT: acc.loop gang worker vector { +// CHECK: [[I64VALUE:%.*]] = constant 1 : i64 +// CHECK-NEXT: [[I32VALUE:%.*]] = constant 128 : i32 +// CHECK-NEXT: [[IDXVALUE:%.*]] = constant 8 : index +// CHECK: acc.loop gang worker vector { +// CHECK-NEXT: "some.op"() : () -> () +// CHECK-NEXT: acc.yield +// CHECK-NEXT: } +// CHECK: acc.loop gang(num=[[I64VALUE]]: i64) { +// CHECK-NEXT: "some.op"() : () -> () +// CHECK-NEXT: acc.yield +// CHECK-NEXT: } +// CHECK: acc.loop gang(static=[[I64VALUE]]: i64) { +// CHECK-NEXT: "some.op"() : () -> () +// CHECK-NEXT: acc.yield +// CHECK-NEXT: } +// CHECK: acc.loop worker([[I64VALUE]]: i64) { +// CHECK-NEXT: "some.op"() : () -> () +// CHECK-NEXT: acc.yield +// CHECK-NEXT: } +// CHECK: acc.loop worker([[I32VALUE]]: i32) { +// CHECK-NEXT: "some.op"() : () -> () +// CHECK-NEXT: acc.yield +// CHECK-NEXT: } +// CHECK: acc.loop
worker([[IDXVALUE]]: index) { +// CHECK-NEXT: "some.op"() : () -> () +// CHECK-NEXT: acc.yield +// CHECK-NEXT: } +// CHECK: acc.loop vector([[I64VALUE]]: i64) { // CHECK-NEXT: "some.op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) { +// CHECK: acc.loop vector([[I32VALUE]]: i32) { // CHECK-NEXT: "some.op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK-NEXT: acc.loop gang(static: [[GANGSTATIC]]) { +// CHECK: acc.loop vector([[IDXVALUE]]: index) { // CHECK-NEXT: "some.op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK-NEXT: acc.loop worker([[WORKERNUM]]) { +// CHECK: acc.loop gang(num=[[I64VALUE]]: i64) worker vector { // CHECK-NEXT: "some.op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK-NEXT: acc.loop vector([[VECTORLENGTH]]) { +// CHECK: acc.loop gang(num=[[I64VALUE]]: i64, static=[[I64VALUE]]: i64) worker([[I64VALUE]]: i64) vector([[I64VALUE]]: i64) { // CHECK-NEXT: "some.op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) worker vector { +// CHECK: acc.loop gang(num=[[I32VALUE]]: i32, static=[[IDXVALUE]]: index) { // CHECK-NEXT: "some.op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]], static: [[GANGSTATIC]]) worker([[WORKERNUM]]) vector([[VECTORLENGTH]]) { +// CHECK: acc.loop tile([[I64VALUE]]: i64, [[I64VALUE]]: i64) { // CHECK-NEXT: "some.op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } -// CHECK-NEXT: acc.loop tile([[TILESIZE]]: i64, [[TILESIZE]]: i64) { +// CHECK: acc.loop tile([[I32VALUE]]: i32, [[I32VALUE]]: i32) { // CHECK-NEXT: "some.op"() : () -> () // CHECK-NEXT: acc.yield // CHECK-NEXT: } +// ----- func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () { %i64value = constant 1 : i64