diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -139,7 +139,6 @@ Optional:$varPtrPtr, Variadic:$bounds, /* rank-0 to rank-{n-1} */ DefaultValuedAttr:$dataClause, - OptionalAttr:$decomposedFrom, DefaultValuedAttr:$structured, DefaultValuedAttr:$implicit, OptionalAttr:$name); @@ -242,7 +241,6 @@ OpenACC_PointerLikeTypeInterface:$accPtr, Variadic:$bounds, DefaultValuedAttr:$dataClause, - OptionalAttr:$decomposedFrom, DefaultValuedAttr:$structured, DefaultValuedAttr:$implicit, OptionalAttr:$name); diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -88,10 +88,13 @@ // CopyinOp //===----------------------------------------------------------------------===// LogicalResult acc::CopyinOp::verify() { + // Test for all clauses this operation can be decomposed from: if (getDataClause() != acc::DataClause::acc_copyin && - getDataClause() != acc::DataClause::acc_copyin_readonly) { + getDataClause() != acc::DataClause::acc_copyin_readonly && + getDataClause() != acc::DataClause::acc_copy) { return emitError( - "data clause associated with copyin operation must match its intent"); + "data clause associated with copyin operation must match its intent" + " or specify original clause this operation was decomposed from"); } return success(); } @@ -104,16 +107,22 @@ // CreateOp //===----------------------------------------------------------------------===// LogicalResult acc::CreateOp::verify() { + // Test for all clauses this operation can be decomposed from: if (getDataClause() != acc::DataClause::acc_create && - getDataClause() != acc::DataClause::acc_create_zero) { + getDataClause() != acc::DataClause::acc_create_zero && + getDataClause() != acc::DataClause::acc_copyout && + getDataClause() != acc::DataClause::acc_copyout_zero) { return emitError( - "data clause associated with create operation must match its intent"); + "data clause associated with create operation must match its intent" + " or specify original clause this operation was decomposed from"); } return success(); } bool acc::CreateOp::isCreateZero() { - return getDataClause() == acc::DataClause::acc_create_zero; + // The zero modifier is encoded in the data clause. + return getDataClause() == acc::DataClause::acc_create_zero || + getDataClause() == acc::DataClause::acc_copyout_zero; } //===----------------------------------------------------------------------===// @@ -142,7 +151,13 @@ // GetDevicePtrOp //===----------------------------------------------------------------------===// LogicalResult acc::GetDevicePtrOp::verify() { - if (getDataClause() != acc::DataClause::acc_getdeviceptr) { + // This operation is also created for use in unstructured constructs + // when we need an "accPtr" to feed to exit operation. Thus we test + // for those cases as well: + if (getDataClause() != acc::DataClause::acc_getdeviceptr && + getDataClause() != acc::DataClause::acc_copyout && + getDataClause() != acc::DataClause::acc_delete && + getDataClause() != acc::DataClause::acc_detach) { return emitError("getDevicePtr mismatch"); } return success(); @@ -152,10 +167,13 @@ // CopyoutOp //===----------------------------------------------------------------------===// LogicalResult acc::CopyoutOp::verify() { + // Test for all clauses this operation can be decomposed from: if (getDataClause() != acc::DataClause::acc_copyout && - getDataClause() != acc::DataClause::acc_copyout_zero) { + getDataClause() != acc::DataClause::acc_copyout_zero && + getDataClause() != acc::DataClause::acc_copy) { return emitError( - "data clause associated with copyout operation must match its intent"); + "data clause associated with copyout operation must match its intent" + " or specify original clause this operation was decomposed from"); } if (!getVarPtr() || !getAccPtr()) { return emitError("must have both host and device pointers"); @@ -171,9 +189,13 @@ // DeleteOp //===----------------------------------------------------------------------===// LogicalResult acc::DeleteOp::verify() { - if (getDataClause() != acc::DataClause::acc_delete) { + // Test for all clauses this operation can be decomposed from: + if (getDataClause() != acc::DataClause::acc_delete && + getDataClause() != acc::DataClause::acc_create && + getDataClause() != acc::DataClause::acc_create_zero) { return emitError( - "data clause associated with delete operation must match its intent"); + "data clause associated with delete operation must match its intent" + " or specify original clause this operation was decomposed from"); } if (!getVarPtr() && !getAccPtr()) { return emitError("must have either host or device pointer"); @@ -185,9 +207,12 @@ // DetachOp //===----------------------------------------------------------------------===// LogicalResult acc::DetachOp::verify() { - if (getDataClause() != acc::DataClause::acc_detach) { + // Test for all clauses this operation can be decomposed from: + if (getDataClause() != acc::DataClause::acc_detach && + getDataClause() != acc::DataClause::acc_attach) { return emitError( - "data clause associated with detach operation must match its intent"); + "data clause associated with detach operation must match its intent" + " or specify original clause this operation was decomposed from"); } if (!getVarPtr() && !getAccPtr()) { return emitError("must have either host or device pointer"); diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -994,19 +994,19 @@ acc.kernels dataOperands(%copyinreadonly : memref<10xf32>) { } - %copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 3} + %copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = 3} acc.serial dataOperands(%copyinfromcopy : memref<10xf32>) { } - acc.copyout accPtr(%copyinfromcopy : memref<10xf32>) to varPtr(%a : memref<10xf32>) {decomposedFrom = 3} + acc.copyout accPtr(%copyinfromcopy : memref<10xf32>) to varPtr(%a : memref<10xf32>) {dataClause = 3} %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32> %createimplicit = acc.create varPtr(%c : memref<10x20xf32>) -> memref<10x20xf32> {implicit = true} acc.parallel dataOperands(%create, %createimplicit : memref<10xf32>, memref<10x20xf32>) { } - acc.delete accPtr(%create : memref<10xf32>) {decomposedFrom = 7} - acc.delete accPtr(%createimplicit : memref<10x20xf32>) {decomposedFrom = 7, implicit = true} + acc.delete accPtr(%create : memref<10xf32>) {dataClause = 7} + acc.delete accPtr(%createimplicit : memref<10x20xf32>) {dataClause = 7, implicit = true} - %copyoutzero = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 5} + %copyoutzero = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = 5} acc.parallel dataOperands(%copyoutzero: memref<10xf32>) { } acc.copyout accPtr(%copyoutzero : memref<10xf32>) to varPtr(%a : memref<10xf32>) {dataClause = 5} @@ -1014,12 +1014,12 @@ %attach = acc.attach varPtr(%b : memref>) -> memref> acc.parallel dataOperands(%attach : memref>) { } - acc.detach accPtr(%attach : memref>) {decomposedFrom = 10} + acc.detach accPtr(%attach : memref>) {dataClause = 10} - %copyinparent = acc.copyin varPtr(%a : memref<10xf32>) varPtrPtr(%b : memref>) -> memref<10xf32> {decomposedFrom = 3} + %copyinparent = acc.copyin varPtr(%a : memref<10xf32>) varPtrPtr(%b : memref>) -> memref<10xf32> {dataClause = 3} acc.parallel dataOperands(%copyinparent : memref<10xf32>) { } - acc.copyout accPtr(%copyinparent : memref<10xf32>) to varPtr(%a : memref<10xf32>) {decomposedFrom = 3} + acc.copyout accPtr(%copyinparent : memref<10xf32>) to varPtr(%a : memref<10xf32>) {dataClause = 3} %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -1037,10 +1037,10 @@ } %bounds1partial = acc.bounds lowerbound(%c4 : index) upperbound(%c9 : index) stride(%c1 : index) - %copyinpartial = acc.copyin varPtr(%a : memref<10xf32>) bounds(%bounds1partial) -> memref<10xf32> {decomposedFrom = 3} + %copyinpartial = acc.copyin varPtr(%a : memref<10xf32>) bounds(%bounds1partial) -> memref<10xf32> {dataClause = 3} acc.parallel dataOperands(%copyinpartial : memref<10xf32>) { } - acc.copyout accPtr(%copyinpartial : memref<10xf32>) bounds(%bounds1partial) to varPtr(%a : memref<10xf32>) {decomposedFrom = 3} + acc.copyout accPtr(%copyinpartial : memref<10xf32>) bounds(%bounds1partial) to varPtr(%a : memref<10xf32>) {dataClause = 3} return } @@ -1058,28 +1058,28 @@ // CHECK: [[COPYINRO:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {dataClause = 2 : i64} // CHECK-NEXT: acc.kernels dataOperands([[COPYINRO]] : memref<10xf32>) { // CHECK-NEXT: } -// CHECK: [[COPYINCOPY:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 3 : i64} +// CHECK: [[COPYINCOPY:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {dataClause = 3 : i64} // CHECK-NEXT: acc.serial dataOperands([[COPYINCOPY]] : memref<10xf32>) { // CHECK-NEXT: } -// CHECK-NEXT: acc.copyout accPtr([[COPYINCOPY]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {decomposedFrom = 3 : i64} +// CHECK-NEXT: acc.copyout accPtr([[COPYINCOPY]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {dataClause = 3 : i64} // CHECK: [[CREATE:%.*]] = acc.create varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> // CHECK-NEXT: [[CREATEIMP:%.*]] = acc.create varPtr([[ARGC]] : memref<10x20xf32>) -> memref<10x20xf32> {implicit = true} // CHECK-NEXT: acc.parallel dataOperands([[CREATE]], [[CREATEIMP]] : memref<10xf32>, memref<10x20xf32>) { // CHECK-NEXT: } -// CHECK-NEXT: acc.delete accPtr([[CREATE]] : memref<10xf32>) {decomposedFrom = 7 : i64} -// CHECK-NEXT: acc.delete accPtr([[CREATEIMP]] : memref<10x20xf32>) {decomposedFrom = 7 : i64, implicit = true} -// CHECK: [[COPYOUTZ:%.*]] = acc.create varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 5 : i64} +// CHECK-NEXT: acc.delete accPtr([[CREATE]] : memref<10xf32>) {dataClause = 7 : i64} +// CHECK-NEXT: acc.delete accPtr([[CREATEIMP]] : memref<10x20xf32>) {dataClause = 7 : i64, implicit = true} +// CHECK: [[COPYOUTZ:%.*]] = acc.create varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {dataClause = 5 : i64} // CHECK-NEXT: acc.parallel dataOperands([[COPYOUTZ]] : memref<10xf32>) { // CHECK-NEXT: } // CHECK-NEXT: acc.copyout accPtr([[COPYOUTZ]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {dataClause = 5 : i64} // CHECK: [[ATTACH:%.*]] = acc.attach varPtr([[ARGB]] : memref>) -> memref> // CHECK-NEXT: acc.parallel dataOperands([[ATTACH]] : memref>) { // CHECK-NEXT: } -// CHECK-NEXT: acc.detach accPtr([[ATTACH]] : memref>) {decomposedFrom = 10 : i64} -// CHECK: [[COPYINP:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) varPtrPtr([[ARGB]] : memref>) -> memref<10xf32> {decomposedFrom = 3 : i64} +// CHECK-NEXT: acc.detach accPtr([[ATTACH]] : memref>) {dataClause = 10 : i64} +// CHECK: [[COPYINP:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) varPtrPtr([[ARGB]] : memref>) -> memref<10xf32> {dataClause = 3 : i64} // CHECK-NEXT: acc.parallel dataOperands([[COPYINP]] : memref<10xf32>) { // CHECK-NEXT: } -// CHECK-NEXT: acc.copyout accPtr([[COPYINP]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {decomposedFrom = 3 : i64} +// CHECK-NEXT: acc.copyout accPtr([[COPYINP]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {dataClause = 3 : i64} // CHECK-DAG: [[CON0:%.*]] = arith.constant 0 : index // CHECK-DAG: [[CON1:%.*]] = arith.constant 1 : index // CHECK-DAG: [[CON4:%.*]] = arith.constant 4 : index @@ -1092,10 +1092,10 @@ // CHECK-NEXT: acc.parallel dataOperands([[COPYINF1]], [[COPYINF2]] : memref<10xf32>, memref<10x20xf32>) { // CHECK-NEXT: } // CHECK: [[BOUNDS1P:%.*]] = acc.bounds lowerbound([[CON4]] : index) upperbound([[CON9]] : index) stride([[CON1]] : index) -// CHECK-NEXT: [[COPYINPART:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) bounds([[BOUNDS1P]]) -> memref<10xf32> {decomposedFrom = 3 : i64} +// CHECK-NEXT: [[COPYINPART:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) bounds([[BOUNDS1P]]) -> memref<10xf32> {dataClause = 3 : i64} // CHECK-NEXT: acc.parallel dataOperands([[COPYINPART]] : memref<10xf32>) { // CHECK-NEXT: } -// CHECK-NEXT: acc.copyout accPtr([[COPYINPART]] : memref<10xf32>) bounds([[BOUNDS1P]]) to varPtr([[ARGA]] : memref<10xf32>) {decomposedFrom = 3 : i64} +// CHECK-NEXT: acc.copyout accPtr([[COPYINPART]] : memref<10xf32>) bounds([[BOUNDS1P]]) to varPtr([[ARGA]] : memref<10xf32>) {dataClause = 3 : i64} // ----- @@ -1103,7 +1103,7 @@ %copyin = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {structured = false} acc.enter_data dataOperands(%copyin : memref<10xf32>) - %devptr = acc.getdeviceptr varPtr(%a : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 4} + %devptr = acc.getdeviceptr varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = 4} acc.exit_data dataOperands(%devptr : memref<10xf32>) acc.copyout accPtr(%devptr : memref<10xf32>) to varPtr(%a : memref<10xf32>) {structured = false} @@ -1113,6 +1113,6 @@ // CHECK: func.func @testunstructuredclauseops([[ARGA:%.*]]: memref<10xf32>) { // CHECK: [[COPYIN:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {structured = false} // CHECK-NEXT: acc.enter_data dataOperands([[COPYIN]] : memref<10xf32>) -// CHECK: [[DEVPTR:%.*]] = acc.getdeviceptr varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 4 : i64} +// CHECK: [[DEVPTR:%.*]] = acc.getdeviceptr varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {dataClause = 4 : i64} // CHECK-NEXT: acc.exit_data dataOperands([[DEVPTR]] : memref<10xf32>) // CHECK-NEXT: acc.copyout accPtr([[DEVPTR]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {structured = false}