diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -705,8 +705,8 @@
                        Optional<I1>:$ifCond,
                        Optional<I1>:$selfCond,
                        UnitAttr:$selfAttr,
-                       OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
                        Variadic<AnyType>:$reductionOperands,
+                       OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
                        Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
                        OptionalAttr<SymbolRefArrayAttr>:$privatizations,
                        Variadic<AnyType>:$gangFirstPrivateOperands,
@@ -735,7 +735,9 @@
       | `wait` `(` $waitOperands `:` type($waitOperands) `)`
       | `self` `(` $selfCond `)`
       | `if` `(` $ifCond `)`
-      | `reduction` `(` $reductionOperands `:` type($reductionOperands) `)`
+      | `reduction` `(` custom<SymOperandList>(
+            $reductionOperands, type($reductionOperands), $reductionRecipes)
+        `)`
     )
     $region attr-dict-with-keyword
   }];
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -580,7 +580,10 @@
 }
 
 LogicalResult acc::SerialOp::verify() {
-
+  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
+          *this, getReductionRecipes(), getReductionOperands(), "reduction",
+          "reductions")))
+    return failure();
   return checkDataOperands(*this, getDataClauseOperands());
 }
 
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1440,3 +1440,25 @@
 // CHECK-LABEL: func.func @acc_reduc_test(
 // CHECK-SAME: %[[ARG0:.*]]: i64)
 // CHECK: acc.parallel reduction(@reduction_add_i64 -> %[[ARG0]] : i64)
+
+// -----
+
+acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator <add> init {
+^bb0(%0: i64):
+  %1 = arith.constant 0 : i64
+  acc.yield %1 : i64
+} combiner {
+^bb0(%0: i64, %1: i64):
+  %2 = arith.addi %0, %1 : i64
+  acc.yield %2 : i64
+}
+
+func.func @acc_reduc_test(%a : i64) -> () {
+  acc.serial reduction(@reduction_add_i64 -> %a : i64) {
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @acc_reduc_test(
+// CHECK-SAME: %[[ARG0:.*]]: i64)
+// CHECK: acc.serial reduction(@reduction_add_i64 -> %[[ARG0]] : i64)