diff --git a/flang/lib/Evaluate/fold-implementation.h b/flang/lib/Evaluate/fold-implementation.h --- a/flang/lib/Evaluate/fold-implementation.h +++ b/flang/lib/Evaluate/fold-implementation.h @@ -1381,6 +1381,34 @@ return result; } +// Returns true when the two array constructors hold the same number of values +// and every pair of corresponding values has the same rank and an equal shape +// that GetShape can determine. MapOperation declines to fold otherwise. +template +bool ShapesMatch(FoldingContext &context, + const ArrayConstructor &leftArrConst, + const ArrayConstructor &rightArrConst) { + auto rightIter{rightArrConst.begin()}; + for (auto &leftValue : leftArrConst) { + if (rightIter == rightArrConst.end()) { + return false; // right has fewer values than left: do not fold + } + auto &leftExpr{std::get>(leftValue.u)}; + auto &rightExpr{std::get>(rightIter->u)}; + if (leftExpr.Rank() != rightExpr.Rank()) { + return false; + } + std::optional leftShape{GetShape(context, leftExpr)}; + std::optional rightShape{GetShape(context, rightExpr)}; + if (!leftShape || !rightShape || *leftShape != *rightShape) { + return false; + } + ++rightIter; + } + // Extra values on the right would otherwise be silently ignored. + return rightIter == rightArrConst.end(); +} + // array * array case template auto MapOperation(FoldingContext &context, @@ -1391,11 +1419,14 @@ auto result{ArrayConstructorFromMold(leftValues, std::move(length))}; auto &leftArrConst{std::get>(leftValues.u)}; if constexpr (common::HasMember) { - common::visit( - [&](auto &&kindExpr) { + bool mapped{common::visit( + [&](auto &&kindExpr) -> bool { using kindType = ResultType; auto &rightArrConst{std::get>(kindExpr.u)}; + if (!ShapesMatch(context, leftArrConst, rightArrConst)) { + return false; + } auto rightIter{rightArrConst.begin()}; for (auto &leftValue : leftArrConst) { CHECK(rightIter != rightArrConst.end()); @@ -1405,10 +1436,17 @@ f(std::move(leftScalar), Expr{std::move(rightScalar)}))); ++rightIter; } + return true; }, - std::move(rightValues.u)); + std::move(rightValues.u))}; + if (!mapped) { + return std::nullopt; + } } else { auto &rightArrConst{std::get>(rightValues.u)}; + if (!ShapesMatch(context, leftArrConst, rightArrConst)) { + return std::nullopt; + } auto rightIter{rightArrConst.begin()}; for (auto &leftValue : leftArrConst) { CHECK(rightIter != rightArrConst.end()); diff --git
a/flang/test/Evaluate/rewrite01.f90 b/flang/test/Evaluate/rewrite01.f90 --- a/flang/test/Evaluate/rewrite01.f90 +++ b/flang/test/Evaluate/rewrite01.f90 @@ -196,7 +196,9 @@ end subroutine !CHECK-LABEL: array_constructor -subroutine array_constructor() +subroutine array_constructor(a, u, v, w, x, y, z) + real :: a(4) + integer :: u(:), v(1), w(2), x(4), y(4), z(2, 2) interface function return_allocatable() real, allocatable :: return_allocatable(:) @@ -204,6 +206,28 @@ end interface !CHECK: PRINT *, size([REAL(4)::return_allocatable(),return_allocatable()]) print *, size([return_allocatable(), return_allocatable()]) + !CHECK: PRINT *, [INTEGER(4)::x+y] + print *, (/x/) + (/y/) ! same shape: folded + !CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::z] + print *, (/x/) + (/z/) ! rank 1 vs rank 2: not folded + !CHECK: PRINT *, [INTEGER(4)::x+y,x+y] + print *, (/x, x/) + (/y, y/) ! every pair has the same shape: folded + !CHECK: PRINT *, [INTEGER(4)::x,x]+[INTEGER(4)::x,z] + print *, (/x, x/) + (/x, z/) ! second pair differs in rank: not folded + !CHECK: PRINT *, [INTEGER(4)::x,w,w]+[INTEGER(4)::w,w,x] + print *, (/x, w, w/) + (/w, w, x/) ! same total size, but pairwise shapes differ: not folded + !CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::1_4,2_4,3_4,4_4] + print *, (/x/) + (/1, 2, 3, 4/) ! array x paired with scalar 1: not folded + !CHECK: PRINT *, [INTEGER(4)::v]+[INTEGER(4)::1_4] + print *, (/v/) + (/1/) ! rank-1 v paired with scalar 1: not folded + !CHECK: PRINT *, [INTEGER(4)::x]+[INTEGER(4)::u] + print *, (/x/) + (/u/) ! u is assumed-shape: not folded + !CHECK: PRINT *, [INTEGER(4)::u]+[INTEGER(4)::u] + print *, (/u/) + (/u/) ! assumed-shape operands: not folded + !CHECK: PRINT *, [REAL(4)::a**x] + print *, (/a/) ** (/x/) ! same shape, real ** integer: folded + !CHECK: PRINT *, [REAL(4)::a]**[INTEGER(4)::z] + print *, (/a/) ** (/z/) ! rank 1 vs rank 2: not folded end subroutine !CHECK-LABEL: array_ctor_implied_do_index diff --git a/flang/test/Lower/array-expression.f90 b/flang/test/Lower/array-expression.f90 --- a/flang/test/Lower/array-expression.f90 +++ b/flang/test/Lower/array-expression.f90 @@ -1158,4 +1158,108 @@ print *, scan(c1, c2) end subroutine +! Check that the expression is folded, with the first operation being an add +! between x and y, resulting in a new temporary array. +! +! CHECK-LABEL: func @_QPtest20a( +!
CHECK-SAME: %[[ARG0:.*]]: !fir.ref> {{.*}}, %[[ARG1:.*]]: !fir.ref> {{.*}}, %[[ARG2:.*]]: !fir.ref> +! CHECK: %[[Z:.*]] = fir.array_load %[[ARG2]]({{.*}}) : (!fir.ref>, !fir.shape<1>) -> !fir.array<4xi32> +! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref>, !fir.shape<1>) -> !fir.array<4xi32> +! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref>, !fir.shape<1>) -> !fir.array<4xi32> +! CHECK: %[[TEMP:.*]] = fir.allocmem !fir.array<4xi32> +! CHECK: %[[TEMP2:.*]] = fir.array_load %[[TEMP]]({{.*}}) : (!fir.heap>, !fir.shape<1>) -> !fir.array<4xi32> +! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP3:.*]] = %[[TEMP2]]) -> (!fir.array<4xi32>) { +! CHECK: %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32 +! CHECK: %[[YI:.*]] = fir.array_fetch %[[Y]], %[[I]] : (!fir.array<4xi32>, index) -> i32 +! CHECK: %[[ADD:.*]] = arith.addi %[[XI]], %[[YI]] : i32 +! CHECK: {{.*}} = fir.array_update %[[TEMP3]], %[[ADD]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32> +! CHECK: } +subroutine test20a(x, y, z) + integer :: x(4), y(4), z(4) + + z = (/x/) + (/y/) +end subroutine + +! Check that the expression is not folded, with the first operations being +! array constructions from x and y. +! +! CHECK-LABEL: func @_QPtest20b( +! CHECK-SAME: %[[ARG0:.*]]: !fir.ref> {{.*}}, %[[ARG1:.*]]: !fir.ref> {{.*}}, %[[ARG2:.*]]: !fir.ref> +! CHECK: %[[Z:.*]] = fir.array_load %[[ARG2]]({{.*}}) : (!fir.ref>, !fir.shape<1>) -> !fir.array<4xi32> +! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref>, !fir.shape<1>) -> !fir.array<4xi32> +! CHECK: %[[TEMP:.*]] = fir.allocmem !fir.array<4xi32> +! CHECK: %[[TEMP2:.*]] = fir.array_load %[[TEMP]]({{.*}}) : (!fir.heap>, !fir.shape<1>) -> !fir.array<4xi32> +! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP3:.*]] = %[[TEMP2]]) -> (!fir.array<4xi32>) { +! CHECK: %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32 +!
CHECK: {{.*}} = fir.array_update %[[TEMP3]], %[[XI]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32> +! CHECK: } +! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref>, !fir.shape<2>) -> !fir.array<2x2xi32> +! CHECK: %[[TEMP4:.*]] = fir.allocmem !fir.array<2x2xi32> +! CHECK: %[[TEMP5:.*]] = fir.array_load %[[TEMP4]]({{.*}}) : (!fir.heap>, !fir.shape<2>) -> !fir.array<2x2xi32> +! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP6:.*]] = %[[TEMP5]]) -> (!fir.array<2x2xi32>) { +! CHECK: {{.*}} = fir.do_loop %[[J:.*]] = {{.*}} iter_args(%[[TEMP7:.*]] = %[[TEMP6]]) -> (!fir.array<2x2xi32>) { +! CHECK: %[[YJI:.*]] = fir.array_fetch %[[Y]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, index, index) -> i32 +! CHECK: {{.*}} = fir.array_update %[[TEMP7]], %[[YJI]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, i32, index, index) -> !fir.array<2x2xi32> +! CHECK: } +! CHECK: } +subroutine test20b(x, y, z) + integer :: x(4), y(2, 2), z(4) + + z = (/x/) + (/y/) +end subroutine + +! CHECK-LABEL: func @_QPtest20c( +! CHECK-SAME: %[[ARG0:.*]]: !fir.ref> {{.*}}, %[[ARG1:.*]]: !fir.ref> {{.*}} + +! (/x/): x is copied into a temporary, which is then memcpy'd into the constructor's result buffer +! CHECK: %[[X:.*]] = fir.array_load %[[ARG0]]({{.*}}) : (!fir.ref>, !fir.shape<1>) -> !fir.array<4xi32> +! CHECK: %[[ACX_MEM:.*]] = fir.allocmem !fir.array<4xi32> +! CHECK: %[[ACX:.*]] = fir.array_load %[[ACX_MEM]]({{.*}}) : (!fir.heap>, !fir.shape<1>) -> !fir.array<4xi32> +! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[ACX]]) -> (!fir.array<4xi32>) { +! CHECK: %[[XI:.*]] = fir.array_fetch %[[X]], %[[I]] : (!fir.array<4xi32>, index) -> i32 +! CHECK: {{.*}} = fir.array_update %[[TEMP]], %[[XI]], %[[I]] : (!fir.array<4xi32>, i32, index) -> !fir.array<4xi32> +! CHECK: } +! CHECK: %[[T:.*]] = fir.coordinate_of %[[ACX_MEM2:.*]], %{{.*}} : (!fir.heap>, index) -> !fir.ref +! CHECK: %[[T1:.*]] = fir.convert %[[T]] : (!fir.ref) -> !fir.ref +! CHECK: %[[T2:.*]] = fir.convert %[[ACX_MEM]] : (!fir.heap>) -> !fir.ref +!
CHECK: fir.call @llvm.memcpy.p0.p0.i64(%[[T1]], %[[T2]], {{.*}}) +! CHECK: %[[ACX2:.*]] = fir.array_load %[[ACX_MEM2]]({{.*}}) : (!fir.heap>, !fir.shape<1>) -> !fir.array<4xi32> + +! (/y/): likewise for the 2x2 array y; the copy is then loaded as a flat 4-element array +! CHECK: %[[Y:.*]] = fir.array_load %[[ARG1]]({{.*}}) : (!fir.ref>, !fir.shape<2>) -> !fir.array<2x2xi32> +! CHECK: %[[ACY_MEM:.*]] = fir.allocmem !fir.array<2x2xi32> +! CHECK: %[[ACY:.*]] = fir.array_load %[[ACY_MEM]]({{.*}}) : (!fir.heap>, !fir.shape<2>) -> !fir.array<2x2xi32> +! CHECK: {{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[ACY]]) -> (!fir.array<2x2xi32>) { +! CHECK: {{.*}} = fir.do_loop %[[J:.*]] = {{.*}} iter_args(%[[TEMP2:.*]] = %[[TEMP]]) -> (!fir.array<2x2xi32>) { +! CHECK: %[[YJI:.*]] = fir.array_fetch %[[Y]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, index, index) -> i32 +! CHECK: {{.*}} = fir.array_update %[[TEMP2]], %[[YJI]], %[[J]], %[[I]] : (!fir.array<2x2xi32>, i32, index, index) -> !fir.array<2x2xi32> +! CHECK: } +! CHECK: } +! CHECK: %[[T:.*]] = fir.coordinate_of %[[ACY_MEM2:.*]], {{.*}} : (!fir.heap>, index) -> !fir.ref +! CHECK: %[[T1:.*]] = fir.convert %[[T]] : (!fir.ref) -> !fir.ref +! CHECK: %[[T2:.*]] = fir.convert %[[ACY_MEM]] : (!fir.heap>) -> !fir.ref +! CHECK: fir.call @llvm.memcpy.p0.p0.i64(%[[T1]], %[[T2]], {{.*}}) +! CHECK: %[[ACY2:.*]] = fir.array_load %[[ACY_MEM2]]({{.*}}) : (!fir.heap>, !fir.shape<1>) -> !fir.array<4xi32> + +! (/x/) /= (/y/): not folded, since x and y differ in shape; the two buffers are compared element by element +! CHECK: %[[RES_MEM:.*]] = fir.allocmem !fir.array<4x!fir.logical<4>> +! CHECK: %[[RES:.*]] = fir.array_load %[[RES_MEM]]({{.*}}) : (!fir.heap>>, !fir.shape<1>) -> !fir.array<4x!fir.logical<4>> +! CHECK: %{{.*}} = fir.do_loop %[[I:.*]] = {{.*}} iter_args(%[[TEMP:.*]] = %[[RES]]) -> (!fir.array<4x!fir.logical<4>>) { +! CHECK: %[[XI:.*]] = fir.array_fetch %[[ACX2]], %[[I]] : (!fir.array<4xi32>, index) -> i32 +! CHECK: %[[YI:.*]] = fir.array_fetch %[[ACY2]], %[[I]] : (!fir.array<4xi32>, index) -> i32 +! CHECK: %[[T1:.*]] = arith.cmpi ne, %[[XI]], %[[YI]] : i32 +!
CHECK: %[[T2:.*]] = fir.convert %[[T1]] : (i1) -> !fir.logical<4> +! CHECK: {{.*}} = fir.array_update %[[TEMP]], %[[T2]], %[[I]] : (!fir.array<4x!fir.logical<4>>, !fir.logical<4>, index) -> !fir.array<4x!fir.logical<4>> +! CHECK: } + +! any((/x/) /= (/y/)): the logical temporary is boxed and passed to the runtime routine _FortranAAny +! CHECK: %[[T1:.*]] = fir.embox %[[RES_MEM]]({{.*}}) : (!fir.heap>>, !fir.shape<1>) -> !fir.box>> +! CHECK: %[[T2:.*]] = fir.convert %[[T1]] : (!fir.box>>) -> !fir.box +! CHECK: fir.call @_FortranAAny(%[[T2]], {{.*}}){{.*}} : (!fir.box, !fir.ref, i32, i32) -> i1 +subroutine test20c(x, y) + integer :: x(4), y(2, 2) + + if (any((/x/) /= (/y/))) print *, "different" +end subroutine + ! CHECK: func private @_QPbar(