diff --git a/flang/lib/Lower/ConvertArrayConstructor.cpp b/flang/lib/Lower/ConvertArrayConstructor.cpp --- a/flang/lib/Lower/ConvertArrayConstructor.cpp +++ b/flang/lib/Lower/ConvertArrayConstructor.cpp @@ -201,7 +201,89 @@ }; using InlinedTempStrategy = InlinedTempStrategyImpl; -// TODO: add and implement AsElementalStrategy. +/// Class that implements the "as function of the indices" lowering strategy. +/// It will lower [(scalar_expr(i), i=l,u,s)] to: +/// ``` +/// %extent = max((%u-%l+%s)/%s, 0) +/// %shape = fir.shape %extent +/// %elem = hlfir.elemental %shape { +/// ^bb0(%pos:index): +/// %i = %l+(%pos-1)*%s +/// %value = scalar_expr(%i) +/// hlfir.yield_element %value +/// } +/// ``` +/// That way, no temporary is created in lowering, and if the array constructor +/// is part of a more complex elemental expression, or an assignment, it will be +/// trivial to "inline" it in the expression or assignment loops if allowed by +/// alias analysis. +/// This lowering is however only possible for the form of array constructors as +/// in the illustration above. It could be extended to deeper independent +/// implied-do nest and wrapped in an hlfir.reshape to a rank 1 array. But this +/// op does not exist yet, so this is left for the future if it appears +/// profitable. +class AsElementalStrategy { +public: + /// The constructor only gathers the operands to create the hlfir.elemental. + AsElementalStrategy(mlir::Location loc, fir::FirOpBuilder &builder, + fir::SequenceType declaredType, mlir::Value extent, + llvm::ArrayRef lengths) + : shape{builder.genShape(loc, {extent})}, + lengthParams{lengths.begin(), lengths.end()}, exprType{getExprType( + declaredType)} {} + + static hlfir::ExprType getExprType(fir::SequenceType declaredType) { + // Note: 7.8 point 4: the dynamic type of an array constructor is its static + // type, it is not polymorphic. 
+ return hlfir::ExprType::get(declaredType.getContext(), + declaredType.getShape(), + declaredType.getEleTy(), + /*isPolymorphic=*/false); + } + + /// Create the hlfir.elemental and compute the ac-implied-do-index value + /// given the lower bound and stride (compute "%i" in the illustration above). + mlir::Value startImpliedDo(mlir::Location loc, fir::FirOpBuilder &builder, + mlir::Value lower, mlir::Value upper, + mlir::Value stride) { + assert(!elementalOp && "expected only one implied-do"); + mlir::Value one = + builder.createIntegerConstant(loc, builder.getIndexType(), 1); + elementalOp = + builder.create(loc, exprType, shape, lengthParams); + builder.setInsertionPointToStart(elementalOp.getBody()); + // implied-do-index = lower+((pos-1)*stride), pos being the 1-based index + mlir::Value diff = builder.create( + loc, elementalOp.getIndices()[0], one); + mlir::Value mul = builder.create(loc, diff, stride); + mlir::Value add = builder.create(loc, lower, mul); + return add; + } + + /// Yield the scalar ac-value, converted to the element type if trivial. + void pushValue(mlir::Location loc, fir::FirOpBuilder &builder, + hlfir::Entity value) { + assert(value.isScalar() && "cannot use hlfir.elemental with array values"); + assert(elementalOp && "array constructor must contain an outer implied-do"); + mlir::Value elementResult = value; + if (fir::isa_trivial(elementResult.getType())) + elementResult = + builder.createConvert(loc, exprType.getElementType(), elementResult); + builder.create(loc, elementResult); + } + + /// Return the created hlfir.elemental. + hlfir::Entity finishArrayCtorLowering(mlir::Location loc, + fir::FirOpBuilder &builder) { + return hlfir::Entity{elementalOp}; + } + +private: + mlir::Value shape; + llvm::SmallVector lengthParams; + hlfir::ExprType exprType; + hlfir::ElementalOp elementalOp{}; +}; // TODO: add and implement RuntimeTempStrategy. 
@@ -237,7 +319,9 @@ } private: - std::variant implVariant; + std::variant + implVariant; }; } // namespace @@ -312,17 +396,20 @@ struct ArrayCtorAnalysis { template ArrayCtorAnalysis( + Fortran::evaluate::FoldingContext &, const Fortran::evaluate::ArrayConstructor &arrayCtorExpr); // Can the array constructor easily be rewritten into an hlfir.elemental ? - bool isSingleImpliedDoWithOneScalarExpr() const { + bool isSingleImpliedDoWithOneScalarPureExpr() const { return !anyArrayExpr && isPerfectLoopNest && - innerNumberOfExprIfPrefectNest == 1 && depthIfPerfectLoopNest == 1; + innerNumberOfExprIfPrefectNest == 1 && depthIfPerfectLoopNest == 1 && + innerExprIsPureIfPerfectNest; } - bool anyImpliedDo{false}; - bool anyArrayExpr{false}; - bool isPerfectLoopNest{true}; + bool anyImpliedDo = false; + bool anyArrayExpr = false; + bool isPerfectLoopNest = true; + bool innerExprIsPureIfPerfectNest = false; std::int64_t innerNumberOfExprIfPrefectNest = 0; std::int64_t depthIfPerfectLoopNest = 0; }; @@ -330,6 +417,7 @@ template ArrayCtorAnalysis::ArrayCtorAnalysis( + Fortran::evaluate::FoldingContext &foldingContext, const Fortran::evaluate::ArrayConstructor &arrayCtorExpr) { llvm::SmallVector *> arrayValueListStack{&arrayCtorExpr}; @@ -339,8 +427,10 @@ std::int64_t localNumberOfExpr = 0; // Loop though the ac-value of an ac-value list, and add any nested // ac-value-list of ac-implied-do to the stack. + const Fortran::evaluate::ArrayConstructorValues *currentArrayValueList = + arrayValueListStack.pop_back_val(); for (const Fortran::evaluate::ArrayConstructorValue &acValue : - *arrayValueListStack.pop_back_val()) + *currentArrayValueList) std::visit(Fortran::common::visitors{ [&](const Fortran::evaluate::ImpliedDo &impledDo) { arrayValueListStack.push_back(&impledDo.values()); @@ -355,11 +445,16 @@ if (localNumberOfImpliedDo == 0) { // Leaf ac-value-list in the array constructor ac-value tree. 
- if (isPerfectLoopNest) + if (isPerfectLoopNest) { // This this the only leaf of the array-constructor (the array // constructor is a nest of single implied-do with a list of expression // in the last deeper implied do). e.g: "[((i+j, i=1,n)j=1,m)]". innerNumberOfExprIfPrefectNest = localNumberOfExpr; + if (localNumberOfExpr == 1) + innerExprIsPureIfPerfectNest = !Fortran::evaluate::FindImpureCall( + foldingContext, toEvExpr(std::get>( + currentArrayValueList->begin()->u))); + } } else if (localNumberOfImpliedDo == 1 && localNumberOfExpr == 0) { // Perfect implied-do nest new level. ++depthIfPerfectLoopNest; @@ -432,7 +527,7 @@ mlir::Type elementType = LengthAndTypeCollector::collect( loc, converter, arrayCtorExpr, symMap, stmtCtx, lengths); // Run an analysis of the array constructor ac-value. - ArrayCtorAnalysis analysis(arrayCtorExpr); + ArrayCtorAnalysis analysis(converter.getFoldingContext(), arrayCtorExpr); bool needToEvaluateOneExprToGetLengthParameters = failedToGatherLengthParameters(elementType, lengths); @@ -443,8 +538,11 @@ TODO(loc, "Lowering of array constructor requiring the runtime"); auto declaredType = fir::SequenceType::get({typeExtent}, elementType); - if (analysis.isSingleImpliedDoWithOneScalarExpr()) - TODO(loc, "Lowering of array constructor as hlfir.elemental"); + // Note: array constructors whose ac-value contains impure calls are currently + // not rewritten to hlfir.elemental because impure calls must be evaluated + // in order, and hlfir.elemental has no way to indicate that yet. 
+ if (analysis.isSingleImpliedDoWithOneScalarPureExpr()) + return AsElementalStrategy(loc, builder, declaredType, extent, lengths); if (analysis.anyImpliedDo) return InlinedTempStrategy(loc, builder, declaredType, extent, lengths); diff --git a/flang/test/Lower/HLFIR/array-ctor-as-elemental.f90 b/flang/test/Lower/HLFIR/array-ctor-as-elemental.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/HLFIR/array-ctor-as-elemental.f90 @@ -0,0 +1,110 @@ +! Test lowering of array constructors as hlfir.elemental. +! RUN: bbc -emit-fir -hlfir -o - %s | FileCheck %s + +subroutine test_as_simple_elemental(n) + integer :: n + call takes_int([(n+i, i=1,4)]) +end subroutine +! CHECK-LABEL: func.func @_QPtest_as_simple_elemental( +! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare {{.*}}En +! CHECK: %[[VAL_2:.*]] = arith.constant 4 : index +! CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1> +! CHECK: %[[VAL_4:.*]] = arith.constant 1 : i64 +! CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_4]] : (i64) -> index +! CHECK: %[[VAL_6:.*]] = arith.constant 1 : i64 +! CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_6]] : (i64) -> index +! CHECK: %[[VAL_8:.*]] = arith.constant 1 : index +! CHECK: %[[VAL_9:.*]] = hlfir.elemental %[[VAL_3]] : (!fir.shape<1>) -> !hlfir.expr<4xi32> { +! CHECK: ^bb0(%[[VAL_10:.*]]: index): +! CHECK: %[[VAL_11:.*]] = arith.subi %[[VAL_10]], %[[VAL_8]] : index +! CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_11]], %[[VAL_7]] : index +! CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_5]], %[[VAL_12]] : index +! CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref +! CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_13]] : (index) -> i32 +! CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : i32 +! CHECK: hlfir.yield_element %[[VAL_16]] : i32 +! CHECK: } +! CHECK: fir.call +! 
CHECK: hlfir.destroy %[[VAL_9]] : !hlfir.expr<4xi32> + +subroutine test_as_strided_elemental(lb, ub, stride) + integer(8) :: lb, ub, stride + call takes_int([(i, i=lb,ub,stride)]) +end subroutine +! CHECK-LABEL: func.func @_QPtest_as_strided_elemental( +! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare {{.*}}Elb +! CHECK: %[[VAL_4:.*]]:2 = hlfir.declare {{.*}}Estride +! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Eub +! CHECK: %[[VAL_6:.*]] = arith.constant 0 : i64 +! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref +! CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref +! CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_7]], %[[VAL_8]] : i64 +! CHECK: %[[VAL_10:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref +! CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_10]] : i64 +! CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref +! CHECK: %[[VAL_13:.*]] = arith.divsi %[[VAL_11]], %[[VAL_12]] : i64 +! CHECK: %[[VAL_14:.*]] = arith.constant 0 : i64 +! CHECK: %[[VAL_15:.*]] = arith.cmpi sgt, %[[VAL_13]], %[[VAL_14]] : i64 +! CHECK: %[[VAL_16:.*]] = arith.select %[[VAL_15]], %[[VAL_13]], %[[VAL_14]] : i64 +! CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_6]], %[[VAL_16]] : i64 +! CHECK: %[[VAL_18:.*]] = fir.convert %[[VAL_17]] : (i64) -> index +! CHECK: %[[VAL_19:.*]] = fir.shape %[[VAL_18]] : (index) -> !fir.shape<1> +! CHECK: %[[VAL_20:.*]] = fir.load %[[VAL_3]]#0 : !fir.ref +! CHECK: %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (i64) -> index +! CHECK: %[[VAL_22:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref +! CHECK: %[[VAL_23:.*]] = fir.convert %[[VAL_22]] : (i64) -> index +! CHECK: %[[VAL_24:.*]] = arith.constant 1 : index +! CHECK: %[[VAL_25:.*]] = hlfir.elemental %[[VAL_19]] : (!fir.shape<1>) -> !hlfir.expr { +! CHECK: ^bb0(%[[VAL_26:.*]]: index): +! CHECK: %[[VAL_27:.*]] = arith.subi %[[VAL_26]], %[[VAL_24]] : index +! CHECK: %[[VAL_28:.*]] = arith.muli %[[VAL_27]], %[[VAL_23]] : index +! CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_21]], %[[VAL_28]] : index +! 
CHECK: %[[VAL_30:.*]] = fir.convert %[[VAL_29]] : (index) -> i32 +! CHECK: hlfir.yield_element %[[VAL_30]] : i32 +! CHECK: } +! CHECK: fir.call +! CHECK: hlfir.destroy %[[VAL_25]] : !hlfir.expr + +subroutine test_as_elemental_with_pure_call(n) + interface + integer pure function foo(i) + integer, value :: i + end function + end interface + integer :: n + call takes_int([(foo(i), i=1,4)]) +end subroutine +! CHECK-LABEL: func.func @_QPtest_as_elemental_with_pure_call( +! CHECK-SAME: %[[VAL_0:.*]]: !fir.ref {fir.bindc_name = "n"}) { +! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFtest_as_elemental_with_pure_callEn"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_2:.*]] = arith.constant 4 : index +! CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1> +! CHECK: %[[VAL_4:.*]] = arith.constant 1 : i64 +! CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_4]] : (i64) -> index +! CHECK: %[[VAL_6:.*]] = arith.constant 1 : i64 +! CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_6]] : (i64) -> index +! CHECK: %[[VAL_8:.*]] = arith.constant 1 : index +! CHECK: %[[VAL_9:.*]] = hlfir.elemental %[[VAL_3]] : (!fir.shape<1>) -> !hlfir.expr<4xi32> { +! CHECK: ^bb0(%[[VAL_10:.*]]: index): +! CHECK: %[[VAL_11:.*]] = arith.subi %[[VAL_10]], %[[VAL_8]] : index +! CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_11]], %[[VAL_7]] : index +! CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_5]], %[[VAL_12]] : index +! CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (index) -> i32 +! CHECK: %[[VAL_15:.*]] = fir.call @_QPfoo(%[[VAL_14]]) fastmath : (i32) -> i32 +! CHECK: hlfir.yield_element %[[VAL_15]] : i32 +! CHECK: } +! CHECK: fir.call +! CHECK: hlfir.destroy %[[VAL_9]] : !hlfir.expr<4xi32> + +! CHECK-LABEL: func.func @_QPtest_with_impure_call( +subroutine test_with_impure_call(n) + interface + integer function impure_foo(i) + integer, value :: i + end function + end interface + integer :: n + call takes_int([(impure_foo(i), i=1,4)]) +end subroutine +! 
CHECK-NOT: hlfir.elemental +! CHECK: return