[flang] Take the elemental iteration shape from the passed object

When an elemental type-bound procedure is called on an array (e.g.
`call t%elemental_sub()`) and no ArrayLoad operand supplies the iteration
shape, genIterationShape() used to emit a fatal error. Record the most
recently lowered elemental ProcedureRef and, as a last resort, use the
extents of its passed-object actual argument (or, for a fully resolved
call, of the first present actual argument bound to a polymorphic dummy).
Also reformat the polymorphic passed-object boxing code and use
fir::unwrapSequenceType instead of a hand-written TypeSwitch.

diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -3557,10 +3557,44 @@
     ael.lowerElementalSubroutine(call);
   }
 
+  /// Return the actual argument associated with the passed-object dummy of
+  /// \p procRef, or std::nullopt if it cannot be identified.
+  static std::optional<Fortran::evaluate::ActualArgument>
+  extractPassedArgFromProcRef(const Fortran::evaluate::ProcedureRef &procRef,
+                              Fortran::lower::AbstractConverter &converter) {
+    const Fortran::evaluate::ActualArguments &actuals = procRef.arguments();
+    // First look for passed object in actual arguments.
+    for (const std::optional<Fortran::evaluate::ActualArgument> &arg : actuals)
+      if (arg && arg->isPassedObject())
+        return arg;
+
+    // If passed object is not found by here, it means the call was fully
+    // resolved to the correct procedure. Look for the pass object in the
+    // dummy arguments. Pick the first polymorphic one that is associated
+    // with a present actual argument.
+    Fortran::lower::CallerInterface caller(procRef, converter);
+    unsigned idx = 0;
+    for (const auto &arg : caller.characterize().dummyArguments) {
+      // Do not index past the actual argument list.
+      if (idx >= actuals.size())
+        break;
+      if (const auto *dummy =
+              std::get_if<Fortran::evaluate::characteristics::DummyDataObject>(
+                  &arg.u))
+        if (dummy->type.type().IsPolymorphic() && actuals[idx])
+          return actuals[idx];
+      ++idx;
+    }
+    return std::nullopt;
+  }
+
   // TODO: See the comment in genarr(const Fortran::lower::Parentheses&).
   // This is skipping generation of copy-in/copy-out code for analysis that is
   // required when arguments are in parentheses.
   void lowerElementalSubroutine(const Fortran::lower::SomeExpr &call) {
+    if (const auto *procRef =
+            std::get_if<Fortran::evaluate::ProcedureRef>(&call.u))
+      setLoweredProcRef(procRef);
     auto f = genarr(call);
     llvm::SmallVector<mlir::Value> shape = genIterationShape();
     auto [iterSpace, insPt] = genImplicitLoops(shape, /*innerArg=*/{});
@@ -3979,6 +4013,17 @@
     // Otherwise, use the first ArrayLoad operand shape.
     if (!arrayOperands.empty())
       return getShape(getInducingShapeArrayOperand());
+    // Otherwise, in elemental context, try to find the passed object and
+    // retrieve the iteration shape from it.
+    if (loweredProcRef && loweredProcRef->IsElemental()) {
+      const std::optional<Fortran::evaluate::ActualArgument> passArg =
+          extractPassedArgFromProcRef(*loweredProcRef, converter);
+      if (passArg)
+        if (const auto *passExpr = passArg->UnwrapExpr()) {
+          ExtValue exv = asScalarRef(*passExpr);
+          return fir::factory::getExtents(getLoc(), builder, exv);
+        }
+    }
     fir::emitFatalError(getLoc(),
                         "failed to compute the array expression shape");
   }
@@ -4660,24 +4705,23 @@
             ExtValue exv = asScalarRef(*expr);
             mlir::Value tdesc;
             if (fir::isPolymorphicType(fir::getBase(exv).getType())) {
-              mlir::Type tdescType =
-                  fir::TypeDescType::get(mlir::NoneType::get(builder.getContext()));
-              tdesc = builder.create<fir::BoxTypeDescOp>(
-                  loc, tdescType, fir::getBase(exv));
+              mlir::Type tdescType = fir::TypeDescType::get(
+                  mlir::NoneType::get(builder.getContext()));
+              tdesc = builder.create<fir::BoxTypeDescOp>(loc, tdescType,
+                                                         fir::getBase(exv));
             }
             mlir::Type baseTy =
                 fir::dyn_cast_ptrOrBoxEleTy(fir::getBase(exv).getType());
-            mlir::Type innerTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(baseTy)
-                .Case<fir::SequenceType>([](auto ty) { return ty.getEleTy(); })
-                .Default([](mlir::Type t) {return t; });
-
+            mlir::Type innerTy = fir::unwrapSequenceType(baseTy);
             operands.emplace_back([=](IterSpace iters) -> ExtValue {
-              mlir::Value coord = builder.create<fir::CoordinateOp>(
-                  loc, fir::ReferenceType::get(innerTy), fir::getBase(exv), iters.iterVec());
+              mlir::Value coord = builder.create<fir::CoordinateOp>(
+                  loc, fir::ReferenceType::get(innerTy), fir::getBase(exv),
+                  iters.iterVec());
               mlir::Value empty;
               mlir::ValueRange emptyRange;
-              return builder.create<fir::EmboxOp>(loc, fir::ClassType::get(innerTy),
-                                                  coord, empty, empty, emptyRange, tdesc);
+              return builder.create<fir::EmboxOp>(
+                  loc, fir::ClassType::get(innerTy), coord, empty, empty,
+                  emptyRange, tdesc);
             });
           } else {
             PushSemantics(ConstituentSemantics::BoxValue);
@@ -4757,6 +4801,11 @@
   CC genProcRef(const Fortran::evaluate::ProcedureRef &procRef,
                 llvm::Optional<mlir::Type> retTy) {
     mlir::Location loc = getLoc();
+    // Only record elemental references: genIterationShape() only consults
+    // elemental ones, so a non-elemental reference lowered later must not
+    // replace an elemental one recorded earlier.
+    if (procRef.IsElemental())
+      setLoweredProcRef(&procRef);
     if (isOptimizableTranspose(procRef, converter))
       return genTransposeProcRef(procRef);
 
@@ -7018,6 +7067,11 @@
     ubounds = ubs;
   }
 
+  /// Record \p procRef for genIterationShape()'s passed-object fallback.
+  void setLoweredProcRef(const Fortran::evaluate::ProcedureRef *procRef) {
+    loweredProcRef = procRef;
+  }
+
   Fortran::lower::AbstractConverter &converter;
   fir::FirOpBuilder &builder;
   Fortran::lower::StatementContext &stmtCtx;
@@ -7047,6 +7101,9 @@
   // Can the array expression be evaluated in any order?
   // Will be set to false if any of the expression parts prevent this.
   bool unordered = true;
+  // Most recently lowered elemental ProcedureRef. Used to retrieve the
+  // iteration shape from its passed object when no array operand provides it.
+  const Fortran::evaluate::ProcedureRef *loweredProcRef = nullptr;
 };
 } // namespace
 
diff --git a/flang/test/Lower/polymorphic.f90 b/flang/test/Lower/polymorphic.f90 --- a/flang/test/Lower/polymorphic.f90 +++ b/flang/test/Lower/polymorphic.f90 @@ -10,6 +10,8 @@ procedure :: print procedure :: assign_p1_int procedure :: elemental_fct + procedure :: elemental_sub + procedure, pass(this) :: elemental_sub_pass generic :: assignment(=) => assign_p1_int procedure :: host_assoc end type @@ -50,10 +52,21 @@ ! CHECK: fir.call @_QMpolymorphic_testFhost_assocPinternal(%[[TUPLE]]) {{.*}} : (!fir.ref>>>) -> () elemental integer function elemental_fct(this) - class(p1), intent(In) :: this + class(p1), intent(in) :: this elemental_fct = this%a end function + elemental subroutine elemental_sub(this) + class(p1), intent(inout) :: this + this%a = this%a * this%b + end subroutine + + elemental subroutine elemental_sub_pass(c, this) + integer, intent(in) :: c + class(p1), intent(inout) :: this + this%a = this%a * this%b + c + end subroutine + ! Test correct access to polymorphic entity component. subroutine component_access(p) class(p1) :: p @@ -543,4 +556,117 @@ ! CHECK: %{{.*}} = fir.call @_FortranAioOutputDescriptor(%{{.*}}, %[[BOX_NONE]]) fastmath : (!fir.ref, !fir.box) -> i1 ! 
CHECK: fir.freemem %[[TMP]] : !fir.heap> + subroutine test_elemental_sub_array() + type(p1) :: t(10) + call t%elemental_sub() + call t%elemental_sub_pass(2) + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_testPtest_elemental_sub_array() { +! CHECK: %[[C10:.*]] = arith.constant 10 : index +! CHECK: %[[T:.*]] = fir.alloca !fir.array> {bindc_name = "t", uniq_name = "_QMpolymorphic_testFtest_elemental_sub_arrayEt"} +! CHECK: %[[C1:.*]] = arith.constant 1 : index +! CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[UB:.*]] = arith.subi %[[C10]], %[[C1]] : index +! CHECK: fir.do_loop %[[IND:.*]] = %[[C0]] to %[[UB]] step %[[C1]] { +! CHECK: %[[COORD:.*]] = fir.coordinate_of %[[T]], %[[IND]] : (!fir.ref>>, index) -> !fir.ref> +! CHECK: %[[EMBOXED:.*]] = fir.embox %[[COORD]] : (!fir.ref>) -> !fir.class> +! CHECK: fir.call @_QMpolymorphic_testPelemental_sub(%[[EMBOXED]]) {{.*}} : (!fir.class>) -> () +! CHECK: } +! CHECK: %[[C1:.*]] = arith.constant 1 : index +! CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[UB:.*]] = arith.subi %[[C10]], %[[C1]] : index +! CHECK: fir.do_loop %[[IND:.*]] = %[[C0]] to %[[UB]] step %[[C1]] { +! CHECK: %[[COORD:.*]] = fir.coordinate_of %[[T]], %[[IND]] : (!fir.ref>>, index) -> !fir.ref> +! CHECK: %[[EMBOXED:.*]] = fir.embox %[[COORD]] : (!fir.ref>) -> !fir.class> +! CHECK: fir.call @_QMpolymorphic_testPelemental_sub_pass(%{{.*}}, %[[EMBOXED]]) {{.*}} : (!fir.ref, !fir.class>) -> () +! CHECK: } + + subroutine test_elemental_sub_poly_array(p) + class(p1) :: p(10) + call p%elemental_sub() + call p%elemental_sub_pass(3) + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_testPtest_elemental_sub_poly_array( +! CHECK-SAME: %[[P:.*]]: !fir.class>> {fir.bindc_name = "p"}) { +! CHECK: %[[C10:.*]] = arith.constant 10 : index +! CHECK: %[[TDESC:.*]] = fir.box_tdesc %[[P]] : (!fir.class>>) -> !fir.tdesc +! CHECK: %[[C1:.*]] = arith.constant 1 : index +! 
CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[UB:.*]] = arith.subi %[[C10]], %[[C1]] : index +! CHECK: fir.do_loop %[[IND:.*]] = %[[C0]] to %[[UB]] step %[[C1]] { +! CHECK: %[[COORD:.*]] = fir.coordinate_of %[[P]], %[[IND]] : (!fir.class>>, index) -> !fir.ref> +! CHECK: %[[EMBOXED:.*]] = fir.embox %[[COORD]] tdesc %[[TDESC]] : (!fir.ref>, !fir.tdesc) -> !fir.class> +! CHECK: fir.dispatch "elemental_sub"(%[[EMBOXED]] : !fir.class>) (%[[EMBOXED]] : !fir.class>) {pass_arg_pos = 0 : i32} +! CHECK: } +! CHECK: %[[TDESC:.*]] = fir.box_tdesc %[[P]] : (!fir.class>>) -> !fir.tdesc +! CHECK: %[[C1:.*]] = arith.constant 1 : index +! CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[UB:.*]] = arith.subi %[[C10]], %[[C1]] : index +! CHECK: fir.do_loop %[[IND:.*]] = %[[C0]] to %[[UB]] step %[[C1]] { +! CHECK: %[[COORD:.*]] = fir.coordinate_of %[[P]], %[[IND]] : (!fir.class>>, index) -> !fir.ref> +! CHECK: %[[EMBOXED:.*]] = fir.embox %[[COORD]] tdesc %[[TDESC]] : (!fir.ref>, !fir.tdesc) -> !fir.class> +! CHECK: fir.dispatch "elemental_sub_pass"(%[[EMBOXED]] : !fir.class>) (%{{.*}}, %[[EMBOXED]] : !fir.ref, !fir.class>) {pass_arg_pos = 1 : i32} +! CHECK: } + + subroutine test_elemental_sub_array_assumed(t) + type(p1) :: t(:) + call t%elemental_sub() + call t%elemental_sub_pass(4) + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_testPtest_elemental_sub_array_assumed( +! CHECK-SAME: %[[T:.*]]: !fir.box>> {fir.bindc_name = "t"}) { +! CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[T_DIMS:.*]]:3 = fir.box_dims %[[T]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +! CHECK: %[[C1:.*]] = arith.constant 1 : index +! CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[UB:.*]] = arith.subi %[[T_DIMS]]#1, %[[C1]] : index +! CHECK: fir.do_loop %[[IND:.*]] = %[[C0]] to %[[UB]] step %[[C1]] { +! CHECK: %[[COORD:.*]] = fir.coordinate_of %[[T]], %[[IND]] : (!fir.box>>, index) -> !fir.ref> +! 
CHECK: %[[EMBOXED:.*]] = fir.embox %[[COORD]] : (!fir.ref>) -> !fir.class> +! CHECK: fir.call @_QMpolymorphic_testPelemental_sub(%[[EMBOXED]]) {{.*}} : (!fir.class>) -> () +! CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[T_DIMS:.*]]:3 = fir.box_dims %[[T]], %[[C0]] : (!fir.box>>, index) -> (index, index, index) +! CHECK: %[[C1:.*]] = arith.constant 1 : index +! CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[UB:.*]] = arith.subi %[[T_DIMS]]#1, %[[C1]] : index +! CHECK: fir.do_loop %[[IND:.*]] = %[[C0]] to %[[UB]] step %[[C1]] { +! CHECK: %[[COORD:.*]] = fir.coordinate_of %[[T]], %[[IND]] : (!fir.box>>, index) -> !fir.ref> +! CHECK: %[[EMBOXED:.*]] = fir.embox %[[COORD]] : (!fir.ref>) -> !fir.class> +! CHECK: fir.call @_QMpolymorphic_testPelemental_sub_pass(%{{.*}}, %[[EMBOXED]]) {{.*}} : (!fir.ref, !fir.class>) -> () +! CHECK: } + + subroutine test_elemental_sub_poly_array_assumed(p) + class(p1) :: p(:) + call p%elemental_sub() + call p%elemental_sub_pass(5) + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_testPtest_elemental_sub_poly_array_assumed( +! CHECK-SAME: %[[P:.*]]: !fir.class>> {fir.bindc_name = "p"}) { +! CHECK: %[[TDESC:.*]] = fir.box_tdesc %[[P]] : (!fir.class>>) -> !fir.tdesc +! CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[P_DIMS:.*]]:3 = fir.box_dims %[[P]], %[[C0]] : (!fir.class>>, index) -> (index, index, index) +! CHECK: %[[C1:.*]] = arith.constant 1 : index +! CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[UB:.*]] = arith.subi %[[P_DIMS]]#1, %[[C1]] : index +! CHECK: fir.do_loop %[[IND:.*]] = %[[C0]] to %[[UB]] step %[[C1]] { +! CHECK: %[[COORD:.*]] = fir.coordinate_of %[[P]], %[[IND]] : (!fir.class>>, index) -> !fir.ref> +! CHECK: %[[EMBOXED:.*]] = fir.embox %[[COORD]] tdesc %[[TDESC]] : (!fir.ref>, !fir.tdesc) -> !fir.class> +! CHECK: fir.dispatch "elemental_sub"(%[[EMBOXED]] : !fir.class>) (%[[EMBOXED]] : !fir.class>) {pass_arg_pos = 0 : i32} +! CHECK: } +! 
CHECK: %[[TDESC:.*]] = fir.box_tdesc %[[P]] : (!fir.class>>) -> !fir.tdesc +! CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[P_DIMS:.*]]:3 = fir.box_dims %[[P]], %[[C0]] : (!fir.class>>, index) -> (index, index, index) +! CHECK: %[[C1:.*]] = arith.constant 1 : index +! CHECK: %[[C0:.*]] = arith.constant 0 : index +! CHECK: %[[UB:.*]] = arith.subi %[[P_DIMS]]#1, %[[C1]] : index +! CHECK: fir.do_loop %[[IND:.*]] = %[[C0]] to %[[UB]] step %[[C1]] { +! CHECK: %[[COORD:.*]] = fir.coordinate_of %[[P]], %[[IND]] : (!fir.class>>, index) -> !fir.ref> +! CHECK: %[[EMBOXED:.*]] = fir.embox %[[COORD]] tdesc %[[TDESC]] : (!fir.ref>, !fir.tdesc) -> !fir.class> +! CHECK: fir.dispatch "elemental_sub_pass"(%[[EMBOXED]] : !fir.class>) (%{{.*}}, %[[EMBOXED]] : !fir.ref, !fir.class>) {pass_arg_pos = 1 : i32} +! CHECK: } + end module