diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -4564,6 +4564,14 @@
         return [=](IterSpace iters) -> ExtValue {
           return placeScalarValueInMemory(builder, loc, cc(iters), storageType);
         };
+      } else if (isArray(x)) {
+        // An array reference is needed, but the indices used in its path must
+        // still be retrieved by value.
+        assert(!nextPathSemant && "Next path semantics already set!");
+        nextPathSemant = ConstituentSemantics::RefTransparent;
+        CC cc = genarr(x);
+        assert(!nextPathSemant && "Next path semantics wasn't used!");
+        return cc;
       }
     }
     return genarr(x);
@@ -6617,9 +6625,9 @@
     mlir::IndexType idxTy = builder.getIndexType();
     mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
     bool atBase = true;
-    auto saveSemant = semant;
-    if (isProjectedCopyInCopyOut())
-      semant = ConstituentSemantics::RefTransparent;
+    PushSemantics(isProjectedCopyInCopyOut()
+                      ? ConstituentSemantics::RefTransparent
+                      : nextPathSemantics());
     unsigned index = 0;
     for (const auto &v : llvm::reverse(revPath)) {
       std::visit(
@@ -6728,7 +6736,6 @@
       atBase = false;
       ++index;
     }
-    semant = saveSemant;
     ty = fir::unwrapSequenceType(ty);
     components.applied = true;
     return ty;
@@ -7131,6 +7138,18 @@
     return semant == ConstituentSemantics::ByValueArg;
   }
 
+  /// Semantics to use when lowering the next array path. A pending override in
+  /// `nextPathSemant` is consumed (reset) here; otherwise `semant` is returned.
+  inline ConstituentSemantics nextPathSemantics() {
+    if (nextPathSemant) {
+      ConstituentSemantics sema = nextPathSemant.value();
+      nextPathSemant.reset();
+      return sema;
+    }
+
+    return semant;
+  }
+
   /// Can the loops over the expression be unordered?
   inline bool isUnordered() const { return unordered; }
 
@@ -7179,6 +7198,7 @@
   Fortran::lower::ExplicitIterSpace *explicitSpace = nullptr;
   Fortran::lower::ImplicitIterSpace *implicitSpace = nullptr;
   ConstituentSemantics semant = ConstituentSemantics::RefTransparent;
+  std::optional<ConstituentSemantics> nextPathSemant;
   /// `lbounds`, `ubounds` are used in POINTER value assignments, which may only
   /// occur in an explicit iteration space.
   std::optional<llvm::SmallVector<mlir::Value>> lbounds;
diff --git a/flang/test/Lower/array-elemental-calls-3.f90 b/flang/test/Lower/array-elemental-calls-3.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/array-elemental-calls-3.f90
@@ -0,0 +1,57 @@
+! RUN: bbc -o - -emit-fir %s | FileCheck %s
+
+! Test lowering of elemental calls with array arguments that use array
+! elements as indices.
+! As reported in issue #62981, wrong code was being generated in this case.
+
+module test_ops
+  implicit none
+  interface
+    integer elemental function elem_func_i(i)
+      integer, intent(in) :: i
+    end function
+    real elemental function elem_func_r(r)
+      real, intent(in) :: r
+    end function
+  end interface
+
+  integer :: a(3), b(3), v(3), i, j, k, l
+  real :: x(2), y(2), u
+
+contains
+! CHECK-LABEL: func @_QMtest_opsPcheck_array_elems_as_indices() {
+subroutine check_array_elems_as_indices()
+! CHECK: %[[A_ADDR:.*]] = fir.address_of(@_QMtest_opsEa) : !fir.ref<!fir.array<3xi32>>
+! CHECK: %[[V_ADDR:.*]] = fir.address_of(@_QMtest_opsEv) : !fir.ref<!fir.array<3xi32>>
+! CHECK: %[[V:.*]] = fir.array_load %[[V_ADDR]](%{{.*}}) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.array<3xi32>
+! CHECK: %[[A:.*]] = fir.array_load %[[A_ADDR]](%{{.*}}) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.array<3xi32>
+! CHECK: fir.do_loop
+  forall (i=1:3)
+! CHECK: %{{.*}} = fir.array_fetch %[[V]], %{{.*}} : (!fir.array<3xi32>, index) -> i32
+! CHECK: fir.do_loop
+! CHECK: %[[ELEM:.*]] = fir.array_access %[[A]], %{{.*}} : (!fir.array<3xi32>, index) -> !fir.ref<i32>
+! CHECK: %{{.*}} = fir.call @_QPelem_func_i(%[[ELEM]]){{.*}} : (!fir.ref<i32>) -> i32
+    b(i:i) = elem_func_i(a(v(i):v(i)))
+  end forall
+end subroutine
+
+! CHECK-LABEL: func @_QMtest_opsPcheck_not_assert() {
+subroutine check_not_assert()
+  ! Implicit path.
+  b = 10 + elem_func_i(a)
+
+  ! Expression as argument, instead of variable.
+  forall (i=1:3)
+    b(i:i) = elem_func_i(a(i:i) + a(i:i))
+  end forall
+
+  ! Nested elemental function calls.
+  y = elem_func_r(cos(x))
+  y = elem_func_r(cos(x) + u)
+
+  ! Array constructors as elemental function arguments.
+  y = atan2( (/ (real(i, 4), i = 1, 2) /),
+             real( (/ (i, i = j, k, l) /), 4) )
+end subroutine
+
+end module