Subject: [flang][openacc] Lower complex add reduction (with review notes)

Purpose: extend OpenACC reduction lowering in flang/lib/Lower/OpenACC.cpp
to accept complex scalars for the `+` reduction operator.

  * Init-value generator:
    * It gains a complex branch. The branch builds a real constant of the
      component kind (via Fortran::lower::convertReal and
      builder.createRealConstant).
    * It packs that constant into a complex value with
      fir::factory::Complex::createComplex.
    * The previous nested `else { if (floatTy) ... }` is flattened into an
      `else if` chain.
  * Unsupported-type fallback of the init-value generator: changes from
    TODO(loc, "reduction type") to
    llvm::report_fatal_error("Unsupported OpenACC reduction type").
  * Add combiner: gains a complex case. The new test expects it to lower to
    `fir.addc`.
  * New includes: flang/Lower/ConvertType.h and
    flang/Optimizer/Builder/Complex.h, for the helpers above.
  * Test: acc_reduction_add_cmplx checks the generated @reduction_add_z32
    recipe. The init region must insert a 0.0 f32 constant into both
    complex parts, and the combiner must be a single fir.addc.

Review notes (not part of the change itself):

  * NOTE(review): this copy of the patch cannot be applied as-is.
    * Each file's diff has been collapsed onto a single line.
    * Tag-like angle-bracket spans appear to have been stripped, e.g.
      `mlir::dyn_cast_or_null(ty)`, `builder.create(`, a bare `template`,
      and `!fir.ref>`. The digit-only `<4>` in `!fir.complex<4>` survived.
    * Regenerate the patch from the original commit rather than
      hand-restoring the missing template arguments in context lines.
  * NOTE(review): the complex init value reuses one scalar for both parts:
    `createComplex(cmplxTy.getFKind(), init, init)`.
    * For `+` this gives (0, 0), which matches the CHECK lines.
    * This branch may later be reached for a multiplicative operator, and
      getReductionInitValue presumably yields 1 there (not visible in this
      patch; confirm). The result would then be (1, 1) instead of the
      multiplicative identity (1, 0).
    * Consider passing a zero constant as the imaginary part.
  * NOTE(review): in the add combiner, `cmplxTy` is bound but never used.
    An `mlir::isa` check would match the neighbouring float case.
  * NOTE(review): llvm::report_fatal_error is not given `loc`. The failure
    therefore loses the source location that TODO(loc, ...) reported.

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -13,10 +13,12 @@ #include "flang/Lower/OpenACC.h" #include "flang/Common/idioms.h" #include "flang/Lower/Bridge.h" +#include "flang/Lower/ConvertType.h" #include "flang/Lower/PFTBuilder.h" #include "flang/Lower/StatementContext.h" #include "flang/Lower/Support/Utils.h" #include "flang/Optimizer/Builder/BoxValue.h" +#include "flang/Optimizer/Builder/Complex.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/IntrinsicCall.h" #include "flang/Optimizer/Builder/Todo.h" @@ -712,11 +714,17 @@ loc, ty, builder.getFloatAttr(ty, getReductionInitValue(op, ty))); - } else { - if (auto floatTy = mlir::dyn_cast_or_null(ty)) - return builder.create( - loc, ty, - builder.getFloatAttr(ty, getReductionInitValue(op, ty))); + } else if (auto floatTy = mlir::dyn_cast_or_null(ty)) { + return builder.create( + loc, ty, + builder.getFloatAttr(ty, getReductionInitValue(op, ty))); + } else if (auto cmplxTy = mlir::dyn_cast_or_null(ty)) { + mlir::Type floatTy = + Fortran::lower::convertReal(builder.getContext(), cmplxTy.getFKind()); + mlir::Value init = builder.createRealConstant( + loc, floatTy, getReductionInitValue(op, cmplxTy)); + return fir::factory::Complex{builder, loc}.createComplex(cmplxTy.getFKind(), + init, init); } if (auto refTy = mlir::dyn_cast(ty)) { if (auto seqTy = mlir::dyn_cast(refTy.getEleTy())) { @@ -738,7 +746,7 @@ } } - TODO(loc, "reduction type"); + llvm::report_fatal_error("Unsupported OpenACC reduction type"); } template @@ -808,6 +816,8 @@ return builder.create(loc, value1, value2); if (mlir::isa(ty)) return builder.create(loc, value1, value2); + if (auto cmplxTy = mlir::dyn_cast_or_null(ty)) + return builder.create(loc, value1, value2); TODO(loc, "reduction add type"); } diff --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90 --- 
a/flang/test/Lower/OpenACC/acc-reduction.f90 +++ b/flang/test/Lower/OpenACC/acc-reduction.f90 @@ -2,6 +2,19 @@ ! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s +! CHECK-LABEL: acc.reduction.recipe @reduction_add_z32 : !fir.complex<4> reduction_operator init { +! CHECK: ^bb0(%{{.*}}: !fir.complex<4>): +! CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +! CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<4> +! CHECK: %[[UNDEF1:.*]] = fir.insert_value %[[UNDEF]], %[[CST]], [0 : index] : (!fir.complex<4>, f32) -> !fir.complex<4> +! CHECK: %[[UNDEF2:.*]] = fir.insert_value %[[UNDEF1]], %[[CST]], [1 : index] : (!fir.complex<4>, f32) -> !fir.complex<4> +! CHECK: acc.yield %[[UNDEF2]] : !fir.complex<4> +! CHECK: } combiner { +! CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<4>, %[[ARG1:.*]]: !fir.complex<4>): +! CHECK: %[[COMBINED:.*]] = fir.addc %[[ARG0]], %[[ARG1]] : !fir.complex<4> +! CHECK: acc.yield %[[COMBINED]] : !fir.complex<4> +! CHECK: } + ! CHECK-LABEL: acc.reduction.recipe @reduction_neqv_l32 : !fir.logical<4> reduction_operator init { ! CHECK: ^bb0(%{{.*}}: !fir.logical<4>): ! CHECK: %[[CST:.*]] = arith.constant false @@ -729,3 +742,13 @@ ! CHECK-LABEL: func.func @_QPacc_reduction_neqv() ! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref>) -> !fir.ref> {name = "l"} ! CHECK: acc.parallel reduction(@reduction_neqv_l32 -> %[[RED]] : !fir.ref>) + +subroutine acc_reduction_add_cmplx() + complex :: c + !$acc parallel reduction(+:c) + !$acc end parallel +end subroutine + +! CHECK-LABEL: func.func @_QPacc_reduction_add_cmplx() +! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref>) -> !fir.ref> {name = "c"} +! CHECK: acc.parallel reduction(@reduction_add_z32 -> %[[RED]] : !fir.ref>)