diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90 --- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 +++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90 @@ -442,18 +442,19 @@ ! CHECK: acc.yield ! CHECK-NEXT: }{{$}} - !$acc parallel loop private(a) firstprivate(b) - DO i = 1, n - a(i) = b(i) - END DO - -! CHECK: acc.parallel firstprivate(%[[B]] : !fir.ref>) private(%[[A]] : !fir.ref>) { -! CHECK: acc.loop private(%[[A]] : !fir.ref>) { -! CHECK: fir.do_loop -! CHECK: acc.yield -! CHECK-NEXT: }{{$}} -! CHECK: acc.yield -! CHECK-NEXT: }{{$}} +! TODO: will be updated after lowering change in privatization to MLIR +! !$acc parallel loop private(a) firstprivate(b) +! DO i = 1, n +! a(i) = b(i) +! END DO + +! TODO: acc.parallel firstprivate(%[[B]] : !fir.ref>) private(%[[A]] : !fir.ref>) { +! TODO: acc.loop private(%[[A]] : !fir.ref>) { +! TODO: fir.do_loop +! TODO: acc.yield +! TODO-NEXT: }{{$}} +! TODO: acc.yield +! TODO-NEXT: }{{$}} !$acc parallel loop seq DO i = 1, n diff --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90 --- a/flang/test/Lower/OpenACC/acc-parallel.f90 +++ b/flang/test/Lower/OpenACC/acc-parallel.f90 @@ -288,11 +288,12 @@ !CHECK: acc.detach accPtr(%[[ATTACH_D]] : !fir.ptr) {dataClause = 10 : i64, name = "d"} !CHECK: acc.detach accPtr(%[[ATTACH_E]] : !fir.ptr) {dataClause = 10 : i64, name = "e"} - !$acc parallel private(a) firstprivate(b) private(c) - !$acc end parallel +! TODO: will be updated after lowering change in privatization to MLIR +! !$acc parallel private(a) firstprivate(b) private(c) +! !$acc end parallel -!CHECK: acc.parallel firstprivate(%[[B]] : !fir.ref>) private(%[[A]], %[[C]] : !fir.ref>, !fir.ref>) { -!CHECK: acc.yield -!CHECK-NEXT: }{{$}} +!TODO: acc.parallel firstprivate(%[[B]] : !fir.ref>) private(%[[A]], %[[C]] : !fir.ref>, !fir.ref>) { +!TODO: acc.yield +!TODO-NEXT: }{{$}} end subroutine acc_parallel diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -636,7 +636,8 @@ UnitAttr:$selfAttr, OptionalAttr:$reductionOp, Variadic:$reductionOperands, - Variadic:$gangPrivateOperands, + Variadic:$gangPrivateOperands, + OptionalAttr:$privatizations, Variadic:$gangFirstPrivateOperands, Variadic:$dataClauseOperands, OptionalAttr:$defaultAttr); @@ -659,7 +660,9 @@ type($gangFirstPrivateOperands) `)` | `num_gangs` `(` $numGangs `:` type($numGangs) `)` | `num_workers` `(` $numWorkers `:` type($numWorkers) `)` - | `private` `(` $gangPrivateOperands `:` type($gangPrivateOperands) `)` + | `private` `(` custom( + $gangPrivateOperands, type($gangPrivateOperands), $privatizations) + `)` | `vector_length` `(` $vectorLength `:` type($vectorLength) `)` | `wait` `(` $waitOperands `:` type($waitOperands) `)` | `self` `(` $selfCond `)` diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -436,6 +436,43 @@ return success(); } +//===----------------------------------------------------------------------===// +// Custom parser and printer verifier for private clause +//===----------------------------------------------------------------------===// + +static ParseResult parsePrivatizationList( + mlir::OpAsmParser &parser, + llvm::SmallVectorImpl &operands, + llvm::SmallVectorImpl &types, mlir::ArrayAttr &privatizationSymbols) { + llvm::SmallVector privatizationVec; + if (failed(parser.parseCommaSeparatedList([&]() { + if (parser.parseAttribute(privatizationVec.emplace_back()) || + parser.parseArrow() || + parser.parseOperand(operands.emplace_back()) || + parser.parseColonType(types.emplace_back())) + return failure(); + return success(); + }))) + return failure(); + llvm::SmallVector privatizations(privatizationVec.begin(), + privatizationVec.end()); + privatizationSymbols = ArrayAttr::get(parser.getContext(), privatizations); + return success(); +} + +static void +printPrivatizationList(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::OperandRange privateOperands, + mlir::TypeRange privateTypes, + std::optional privatizations) { + for (unsigned i = 0, e = privatizations->size(); i < e; ++i) { + if (i != 0) + p << ", "; + p << (*privatizations)[i] << " -> " << privateOperands[i] << " : " + << privateOperands[i].getType(); + } +} + //===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// @@ -455,6 +492,45 @@ return success(); } +static LogicalResult +checkPrivatizationList(Operation *op, + std::optional privatizations, + mlir::OperandRange privateOperands) { + if (!privateOperands.empty()) { + if (!privatizations || privatizations->size() != privateOperands.size()) + return op->emitOpError() << "expected as many privatizations symbol " + "reference as private operands"; + } else { + if (privatizations) + return op->emitOpError() << "unexpected privatizations symbol reference"; + return success(); + } + + llvm::DenseSet privates; + for (auto args : llvm::zip(privateOperands, *privatizations)) { + mlir::Value privateOperand = std::get<0>(args); + + if (!privates.insert(privateOperand).second) + return op->emitOpError() << "private operand appears more than once"; + + mlir::Type varType = privateOperand.getType(); + auto symbolRef = std::get<1>(args).cast(); + auto decl = + SymbolTable::lookupNearestSymbolFrom(op, symbolRef); + if (!decl) + return op->emitOpError() << "expected symbol reference " << symbolRef + << " to point to a private declaration"; + + if (decl.getType() && decl.getType() != varType) + return op->emitOpError() + << "expected private (" << varType + << ") to be the same type as private declaration (" + << decl.getType() << ")"; + } + + return success(); +} + unsigned ParallelOp::getNumDataOperands() { return getReductionOperands().size() + getGangPrivateOperands().size() + getGangFirstPrivateOperands().size() + getDataClauseOperands().size(); @@ -471,6 +547,9 @@ } LogicalResult acc::ParallelOp::verify() { + if (failed(checkPrivatizationList(*this, getPrivatizations(), + getGangPrivateOperands()))) + return failure(); return checkDataOperands(*this, getDataClauseOperands()); } diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -114,6 +114,16 @@ // ----- +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> { %lb = arith.constant 0 : index %st = arith.constant 1 : index @@ -126,7 +136,7 @@ %pc = acc.present varPtr(%c : memref<10xf32>) -> memref<10xf32> %pd = acc.present varPtr(%d : memref<10xf32>) -> memref<10xf32> acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { - acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(%c : memref<10xf32>) { + acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(@privatization_memref_10_f32 -> %c : memref<10xf32>) { acc.loop gang { scf.for %x = %lb to %c10 step %st { acc.loop worker { @@ -168,7 +178,7 @@ // CHECK-NEXT: [[NUMGANG:%.*]] = arith.constant 10 : i64 // CHECK-NEXT: [[NUMWORKERS:%.*]] = arith.constant 10 : i64 // CHECK: acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) { -// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private([[ARG2]] : memref<10xf32>) { +// CHECK-NEXT: acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private(@privatization_memref_10_f32 -> [[ARG2]] : memref<10xf32>) { // CHECK-NEXT: acc.loop gang { // CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] { // CHECK-NEXT: acc.loop worker { @@ -358,6 +368,26 @@ // ----- +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init { +^bb0(%arg0: memref<10x10xf32>): + %0 = memref.alloc() : memref<10x10xf32> + acc.yield %0 : memref<10x10xf32> +} destroy { +^bb0(%arg0: memref<10x10xf32>): + memref.dealloc %arg0 : memref<10x10xf32> + acc.terminator +} + func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () { %i64value = arith.constant 1 : i64 %i32value = arith.constant 1 : i32 @@ -394,7 +424,7 @@ } acc.parallel vector_length(%idxValue: index) { } - acc.parallel private(%a, %c : memref<10xf32>, memref<10x10xf32>) firstprivate(%b: memref<10xf32>) { + acc.parallel private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(%b: memref<10xf32>) { } acc.parallel { } attributes {defaultAttr = #acc} @@ -445,7 +475,7 @@ // CHECK-NEXT: } // CHECK: acc.parallel vector_length([[IDXVALUE]] : index) { // CHECK-NEXT: } -// CHECK: acc.parallel firstprivate([[ARGB]] : memref<10xf32>) private([[ARGA]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) { +// CHECK: acc.parallel firstprivate([[ARGB]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) { // CHECK-NEXT: } // CHECK: acc.parallel { // CHECK-NEXT: } attributes {defaultAttr = #acc}