diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -634,8 +634,9 @@ Optional:$ifCond, Optional:$selfCond, UnitAttr:$selfAttr, - OptionalAttr:$reductionOp, Variadic:$reductionOperands, + // Reduction declaration symbols, one per $reductionOperands entry, in order. + OptionalAttr:$reductions, Variadic:$gangPrivateOperands, OptionalAttr:$privatizations, Variadic:$gangFirstPrivateOperands, @@ -660,14 +661,16 @@ type($gangFirstPrivateOperands) `)` | `num_gangs` `(` $numGangs `:` type($numGangs) `)` | `num_workers` `(` $numWorkers `:` type($numWorkers) `)` - | `private` `(` custom( + | `private` `(` custom( $gangPrivateOperands, type($gangPrivateOperands), $privatizations) `)` | `vector_length` `(` $vectorLength `:` type($vectorLength) `)` | `wait` `(` $waitOperands `:` type($waitOperands) `)` | `self` `(` $selfCond `)` | `if` `(` $ifCond `)` - | `reduction` `(` $reductionOperands `:` type($reductionOperands) `)` + | `reduction` `(` custom( + $reductionOperands, type($reductionOperands), $reductions) + `)` ) $region attr-dict-with-keyword }]; @@ -726,7 +729,7 @@ | `async` `(` $async `:` type($async) `)` | `firstprivate` `(` $gangFirstPrivateOperands `:` type($gangFirstPrivateOperands) `)` - | `private` `(` custom( + | `private` `(` custom( $gangPrivateOperands, type($gangPrivateOperands), $privatizations) `)` | `wait` `(` $waitOperands `:` type($waitOperands) `)` @@ -1060,7 +1063,7 @@ `gang` `` custom($gangNum, type($gangNum), $gangStatic, type($gangStatic), $hasGang) | `worker` `` custom($workerNum, type($workerNum), $hasWorker) | `vector` `` custom($vectorLength, type($vectorLength), $hasVector) - | `private` `(` custom( + | `private` `(` custom( $privateOperands, type($privateOperands), $privatizations) `)` | `tile` `(` $tileOperands `:` type($tileOperands) `)` diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp --- 
a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -440,13 +440,13 @@ -// Custom parser and printer verifier for private clause +// Custom parser and printer verifier for private and reduction clauses //===----------------------------------------------------------------------===// -static ParseResult parsePrivatizationList( +static ParseResult parseSymOperandList( mlir::OpAsmParser &parser, llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &types, mlir::ArrayAttr &privatizationSymbols) { - llvm::SmallVector privatizationVec; + llvm::SmallVectorImpl &types, mlir::ArrayAttr &symbols) { + llvm::SmallVector attributes; if (failed(parser.parseCommaSeparatedList([&]() { - if (parser.parseAttribute(privatizationVec.emplace_back()) || + if (parser.parseAttribute(attributes.emplace_back()) || parser.parseArrow() || parser.parseOperand(operands.emplace_back()) || parser.parseColonType(types.emplace_back())) @@ -454,22 +454,21 @@ return success(); }))) return failure(); - llvm::SmallVector privatizations(privatizationVec.begin(), - privatizationVec.end()); - privatizationSymbols = ArrayAttr::get(parser.getContext(), privatizations); + llvm::SmallVector arrayAttr(attributes.begin(), + attributes.end()); + symbols = ArrayAttr::get(parser.getContext(), arrayAttr); return success(); } -static void -printPrivatizationList(mlir::OpAsmPrinter &p, mlir::Operation *op, - mlir::OperandRange privateOperands, - mlir::TypeRange privateTypes, - std::optional privatizations) { - for (unsigned i = 0, e = privatizations->size(); i < e; ++i) { +static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::OperandRange operands, + mlir::TypeRange types, + std::optional attributes) { + for (unsigned i = 0, e = attributes->size(); i < e; ++i) { if (i != 0) p << ", "; - p << (*privatizations)[i] << " -> " << privateOperands[i] << " : " - << privateOperands[i].getType(); + p << (*attributes)[i] << " -> " << operands[i] << " : " + << operands[i].getType(); } } @@ -492,40 +491,47 @@ return success(); } 
+/// Verifies that `operands` and the `attributes` symbol list pair up one to +/// one: sizes match (both are absent when there are no operands), no operand +/// repeats, and each symbol resolves to a declaration whose type, when set, +/// equals its operand's type. `operandName`/`symbolName` only feed diagnostics. +template static LogicalResult -checkPrivatizationList(Operation *op, - std::optional privatizations, - mlir::OperandRange privateOperands) { - if (!privateOperands.empty()) { - if (!privatizations || privatizations->size() != privateOperands.size()) - return op->emitOpError() << "expected as many privatizations symbol " - "reference as private operands"; +checkSymOperandList(Operation *op, std::optional attributes, + mlir::OperandRange operands, llvm::StringRef operandName, + llvm::StringRef symbolName) { + if (!operands.empty()) { + if (!attributes || attributes->size() != operands.size()) + return op->emitOpError() + << "expected as many " << symbolName << " symbol reference as " + << operandName << " operands"; } else { - if (privatizations) - return op->emitOpError() << "unexpected privatizations symbol reference"; + if (attributes) + return op->emitOpError() + << "unexpected " << symbolName << " symbol reference"; return success(); } - llvm::DenseSet privates; - for (auto args : llvm::zip(privateOperands, *privatizations)) { - mlir::Value privateOperand = std::get<0>(args); + llvm::DenseSet set; + for (auto args : llvm::zip(operands, *attributes)) { + mlir::Value operand = std::get<0>(args); - if (!privates.insert(privateOperand).second) - return op->emitOpError() << "private operand appears more than once"; + if (!set.insert(operand).second) + return op->emitOpError() + << operandName << " operand appears more than once"; - mlir::Type varType = privateOperand.getType(); + mlir::Type varType = operand.getType(); auto symbolRef = std::get<1>(args).cast(); - auto decl = - SymbolTable::lookupNearestSymbolFrom(op, symbolRef); + auto decl = SymbolTable::lookupNearestSymbolFrom(op, symbolRef); if (!decl) - return op->emitOpError() << "expected symbol reference " << symbolRef - << " to point to a private declaration"; + return op->emitOpError() + << "expected symbol reference " << symbolRef << " to point to a " + << operandName << " declaration"; if (decl.getType() 
&& decl.getType() != varType) return op->emitOpError() - << "expected private (" << varType - << ") to be the same type as private declaration (" - << decl.getType() << ")"; + << "expected " << operandName << " (" << varType << ") to be the same type as " + << operandName << " declaration (" << decl.getType() << ")"; } return success(); @@ -547,8 +553,13 @@ } LogicalResult acc::ParallelOp::verify() { - if (failed(checkPrivatizationList(*this, getPrivatizations(), - getGangPrivateOperands()))) + if (failed(checkSymOperandList( + *this, getPrivatizations(), getGangPrivateOperands(), "private", + "privatizations"))) + return failure(); + if (failed(checkSymOperandList( + *this, getReductions(), getReductionOperands(), "reduction", + "reductions"))) return failure(); return checkDataOperands(*this, getDataClauseOperands()); } @@ -726,8 +737,9 @@ if (getSeq() && (getHasGang() || getHasWorker() || getHasVector())) return emitError("gang, worker or vector cannot appear with the seq attr"); - if (failed(checkPrivatizationList(*this, getPrivatizations(), - getPrivateOperands()))) + if (failed(checkSymOperandList( + *this, getPrivatizations(), getPrivateOperands(), "private", + "privatizations"))) return failure(); // Check non-empty body(). diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -1405,3 +1405,13 @@ // CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i64 // CHECK: acc.yield %[[RES]] : i64 // CHECK: } + +func.func @acc_reduc_test(%a : i64) -> () { + acc.parallel reduction(@reduction_add_i64 -> %a : i64) { + } + return +} + +// CHECK-LABEL: func.func @acc_reduc_test( +// CHECK-SAME: %[[ARG0:.*]]: i64) +// CHECK: acc.parallel reduction(@reduction_add_i64 -> %[[ARG0]] : i64)