diff --git a/flang/include/flang/Lower/CallInterface.h b/flang/include/flang/Lower/CallInterface.h --- a/flang/include/flang/Lower/CallInterface.h +++ b/flang/include/flang/Lower/CallInterface.h @@ -284,6 +284,14 @@ /// procedure. bool isIndirectCall() const; + /// Returns true if this is a call of a type-bound procedure with a + /// polymorphic entity. + bool requireDispatchCall() const; + + /// Get the passed-object argument index. nullopt if there is no passed-object + /// index. + std::optional getPassArgIndex() const; + /// Return the procedure symbol if this is a call to a user defined /// procedure. const Fortran::semantics::Symbol *getProcedureSymbol() const; @@ -372,6 +380,10 @@ /// called through pointers or not. bool isIndirectCall() const { return false; } + /// On the callee side it does not matter whether the procedure is called + /// through dynamic dispatch or not. + bool requireDispatchCall() const { return false; }; + /// Return the procedure symbol if this is a call to a user defined /// procedure. const Fortran::semantics::Symbol *getProcedureSymbol() const; diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -203,7 +203,7 @@ } /// Get the memory reference type of the data pointer from the box type, -inline mlir::Type boxMemRefType(fir::BoxType t) { +inline mlir::Type boxMemRefType(fir::BaseBoxType t) { auto eleTy = t.getEleTy(); if (!eleTy.isa()) eleTy = fir::ReferenceType::get(t); diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp --- a/flang/lib/Lower/CallInterface.cpp +++ b/flang/lib/Lower/CallInterface.cpp @@ -88,6 +88,36 @@ return false; } +bool Fortran::lower::CallerInterface::requireDispatchCall() const { + // calls with NOPASS attribute still have their component so check if it is + // polymorphic. + if (const Fortran::evaluate::Component *component = + procRef.proc().GetComponent()) { + if (Fortran::semantics::IsPolymorphic(component->GetFirstSymbol())) + return true; + } + // calls with PASS attribute have the passed-object already set in its + // arguments. Just check if their is one. + std::optional passArg = getPassArgIndex(); + if (passArg) + return true; + return false; +} + +std::optional +Fortran::lower::CallerInterface::getPassArgIndex() const { + unsigned passArgIdx = 0; + std::optional passArg = std::nullopt; + for (const auto &arg : getCallDescription().arguments()) { + if (arg && arg->isPassedObject()) { + passArg = passArgIdx; + break; + } + ++passArgIdx; + } + return passArg; +} + const Fortran::semantics::Symbol * Fortran::lower::CallerInterface::getIfIndirectCallSymbol() const { if (const Fortran::semantics::Symbol *symbol = procRef.proc().GetSymbol()) diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp --- a/flang/lib/Lower/ConvertExpr.cpp +++ b/flang/lib/Lower/ConvertExpr.cpp @@ -1993,8 +1993,10 @@ } mlir::Value base = fir::getBase(array); - auto seqTy = - fir::dyn_cast_ptrOrBoxEleTy(base.getType()).cast(); + mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(base.getType()); + if (auto classTy = eleTy.dyn_cast()) + eleTy = classTy.getEleTy(); + auto seqTy = eleTy.cast(); assert(args.size() == seqTy.getDimension()); mlir::Type ty = builder.getRefType(seqTy.getEleTy()); auto addr = builder.create(loc, ty, base, args); @@ -2727,11 +2729,47 @@ if (addHostAssociations) operands.push_back(converter.hostAssocTupleValue()); - auto call = builder.create(loc, funcType.getResults(), - funcSymbolAttr, operands); + mlir::Value callResult; + unsigned callNumResults; + if (caller.requireDispatchCall()) { + // Procedure call requiring a dynamic dispatch. Call is created with + // fir.dispatch. + + // Get the raw procedure name. The procedure name is not mangled in the + // binding table. + const auto &ultimateSymbol = + caller.getCallDescription().proc().GetSymbol()->GetUltimate(); + auto procName = toStringRef(ultimateSymbol.name()); + + fir::DispatchOp dispatch; + if (std::optional passArg = caller.getPassArgIndex()) { + // PASS, PASS(arg-name) + dispatch = builder.create( + loc, funcType.getResults(), procName, operands[*passArg], operands, + builder.getI32IntegerAttr(*passArg)); + } else { + // NOPASS + const Fortran::evaluate::Component *component = + caller.getCallDescription().proc().GetComponent(); + assert(component && "expect component for type-bound procedure call."); + fir::ExtendedValue pass = + symMap.lookupSymbol(component->GetFirstSymbol()).toExtendedValue(); + dispatch = builder.create(loc, funcType.getResults(), + procName, fir::getBase(pass), + operands, nullptr); + } + callResult = dispatch.getResult(0); + callNumResults = dispatch.getNumResults(); + } else { + // Standard procedure call with fir.call. + auto call = builder.create(loc, funcType.getResults(), + funcSymbolAttr, operands); + callResult = call.getResult(0); + callNumResults = call.getNumResults(); + } if (caller.mustSaveResult()) - builder.create(loc, call.getResult(0), + builder.create(loc, callResult, fir::getBase(allocatedResult.value()), arrayResultShape, resultLengths); @@ -2754,7 +2792,7 @@ return mlir::Value{}; // subroutine call // For now, Fortran return values are implemented with a single MLIR // function return value. - assert(call.getNumResults() == 1 && + assert(callNumResults == 1 && "Expected exactly one result in FUNCTION call"); // Call a BIND(C) function that return a char. @@ -2764,10 +2802,10 @@ funcType.getResults()[0].dyn_cast(); mlir::Value len = builder.createIntegerConstant( loc, builder.getCharacterLengthType(), charTy.getLen()); - return fir::CharBoxValue{call.getResult(0), len}; + return fir::CharBoxValue{callResult, len}; } - return call.getResult(0); + return callResult; } /// Like genExtAddr, but ensure the address returned is a temporary even if \p @@ -6012,7 +6050,7 @@ } static mlir::Type unwrapBoxEleTy(mlir::Type ty) { - if (auto boxTy = ty.dyn_cast()) + if (auto boxTy = ty.dyn_cast()) return fir::unwrapRefType(boxTy.getEleTy()); return ty; } @@ -7150,7 +7188,7 @@ // Need an intermediate dereference if the boxed value // appears in the middle of the component path or if it is // on the right and this is not a pointer assignment. - if (auto boxTy = ty.dyn_cast()) { + if (auto boxTy = ty.dyn_cast()) { auto currentFunc = components.getExtendCoorRef(); auto loc = getLoc(); auto *bldr = &converter.getFirOpBuilder(); @@ -7161,7 +7199,7 @@ deref = true; } } - } else if (auto boxTy = ty.dyn_cast()) { + } else if (auto boxTy = ty.dyn_cast()) { ty = fir::unwrapRefType(boxTy.getEleTy()); auto recTy = ty.cast(); ty = recTy.getType(name); @@ -7247,7 +7285,7 @@ // assignment, then insert the dereference of the box before any // conversion and store. if (!isPointerAssignment()) { - if (auto boxTy = eleTy.dyn_cast()) { + if (auto boxTy = eleTy.dyn_cast()) { eleTy = fir::boxMemRefType(boxTy); addr = builder.create(loc, eleTy, addr); eleTy = fir::unwrapRefType(eleTy); diff --git a/flang/lib/Lower/Mangler.cpp b/flang/lib/Lower/Mangler.cpp --- a/flang/lib/Lower/Mangler.cpp +++ b/flang/lib/Lower/Mangler.cpp @@ -155,6 +155,10 @@ llvm::report_fatal_error( "only derived type instances can be mangled"); }, + [&](const Fortran::semantics::ProcBindingDetails &procBinding) + -> std::string { + return mangleName(procBinding.symbol(), keepExternalInScope); + }, [](const auto &) -> std::string { TODO_NOLOC("symbol mangling"); }, }, ultimateSymbol.details()); diff --git a/flang/lib/Optimizer/Builder/BoxValue.cpp b/flang/lib/Optimizer/Builder/BoxValue.cpp --- a/flang/lib/Optimizer/Builder/BoxValue.cpp +++ b/flang/lib/Optimizer/Builder/BoxValue.cpp @@ -204,7 +204,7 @@ /// Debug verifier for BoxValue ctor. There is no guarantee this will /// always be called. bool fir::BoxValue::verify() const { - if (!addr.getType().isa()) + if (!addr.getType().isa()) return false; if (!lbounds.empty() && lbounds.size() != rank()) return false; diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -460,7 +460,7 @@ mlir::Value fir::FirOpBuilder::createBox(mlir::Location loc, const fir::ExtendedValue &exv) { mlir::Value itemAddr = fir::getBase(exv); - if (itemAddr.getType().isa()) + if (itemAddr.getType().isa()) return itemAddr; auto elementType = fir::dyn_cast_ptrEleTy(itemAddr.getType()); if (!elementType) { @@ -741,7 +741,7 @@ fir::FirOpBuilder &builder, mlir::Type valTy, mlir::Value boxVal) { - if (auto boxTy = valTy.dyn_cast()) { + if (auto boxTy = valTy.dyn_cast()) { auto eleTy = fir::unwrapAllRefAndSeqType(boxTy.getEleTy()); if (auto recTy = eleTy.dyn_cast()) { if (recTy.getNumLenParams() > 0) { @@ -795,7 +795,7 @@ fir::factory::getTypeParams(mlir::Location loc, fir::FirOpBuilder &builder, fir::ArrayLoadOp load) { mlir::Type memTy = load.getMemref().getType(); - if (auto boxTy = memTy.dyn_cast()) + if (auto boxTy = memTy.dyn_cast()) return getFromBox(loc, builder, boxTy, load.getMemref()); return load.getTypeparams(); } diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -917,7 +917,8 @@ (inType.isa() && outType.isa()) || (inType.isa() && outType.isa()) || (fir::isa_complex(inType) && fir::isa_complex(outType)) || - (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType))) + (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) || + (fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType))) return mlir::success(); return emitOpError("invalid type conversion"); } diff --git a/flang/test/Lower/dispatch.f90 b/flang/test/Lower/dispatch.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/dispatch.f90 @@ -0,0 +1,176 @@ +! RUN: bbc -polymorphic-type -emit-fir %s -o - | FileCheck %s + +! Tests the different possible type involving polymorphic entities. + +module call_dispatch + + interface + subroutine nopass_defferred(x) + real :: x(:) + end subroutine + end interface + + type p1 + integer :: a + integer :: b + contains + procedure, nopass :: tbp_nopass + procedure :: tbp_pass + procedure, pass(this) :: tbp_pass_arg0 + procedure, pass(this) :: tbp_pass_arg1 + + procedure, nopass :: proc1 => p1_proc1_nopass + procedure :: proc2 => p1_proc2 + procedure, pass(this) :: proc3 => p1_proc3_arg0 + procedure, pass(this) :: proc4 => p1_proc4_arg1 + + procedure, nopass :: p1_fct1_nopass + procedure :: p1_fct2 + procedure, pass(this) :: p1_fct3_arg0 + procedure, pass(this) :: p1_fct4_arg1 + end type + + type, abstract :: a1 + real :: a + real :: b + contains + procedure(nopass_defferred), deferred, nopass :: nopassd + end type + + contains + +! ------------------------------------------------------------------------------ +! Test lowering of type-bound procedure call on polymorphic entities +! ------------------------------------------------------------------------------ + + function p1_fct1_nopass() + real :: p1_fct1_nopass + end function + ! CHECK-LABEL: func.func @_QMcall_dispatchPp1_fct1_nopass() -> f32 + + function p1_fct2(p) + real :: p1_fct2 + class(p1) :: p + end function + ! CHECK-LABEL: func.func @_QMcall_dispatchPp1_fct2(%{{.*}}: !fir.class>) -> f32 + + function p1_fct3_arg0(this) + real :: p1_fct2 + class(p1) :: this + end function + ! CHECK-LABEL: func.func @_QMcall_dispatchPp1_fct3_arg0(%{{.*}}: !fir.class>) -> f32 + + function p1_fct4_arg1(i, this) + real :: p1_fct2 + integer :: i + class(p1) :: this + end function + ! CHECK-LABEL: func.func @_QMcall_dispatchPp1_fct4_arg1(%{{.*}}: !fir.ref, %{{.*}}: !fir.class>) -> f32 + + subroutine p1_proc1_nopass() + end subroutine + ! CHECK-LABEL: func.func @_QMcall_dispatchPp1_proc1_nopass() + + subroutine p1_proc2(p) + class(p1) :: p + end subroutine + ! CHECK-LABEL: func.func @_QMcall_dispatchPp1_proc2(%{{.*}}: !fir.class>) + + subroutine p1_proc3_arg0(this) + class(p1) :: this + end subroutine + ! CHECK-LABEL: func.func @_QMcall_dispatchPp1_proc3_arg0(%{{.*}}: !fir.class>) + + subroutine p1_proc4_arg1(i, this) + integer, intent(in) :: i + class(p1) :: this + end subroutine + ! CHECK-LABEL: func.func @_QMcall_dispatchPp1_proc4_arg1(%{{.*}}: !fir.ref, %{{.*}}: !fir.class>) + + subroutine tbp_nopass() + end subroutine + ! CHECK-LABEL: func.func @_QMcall_dispatchPtbp_nopass() + + subroutine tbp_pass(t) + class(p1) :: t + end subroutine + ! CHECK-LABEL: func.func @_QMcall_dispatchPtbp_pass(%{{.*}}: !fir.class>) + + subroutine tbp_pass_arg0(this) + class(p1) :: this + end subroutine + ! CHECK-LABEL: func.func @_QMcall_dispatchPtbp_pass_arg0(%{{.*}}: !fir.class>) + + subroutine tbp_pass_arg1(i, this) + integer, intent(in) :: i + class(p1) :: this + end subroutine + ! CHECK-LABEL: func.func @_QMcall_dispatchPtbp_pass_arg1(%{{.*}}: !fir.ref, %{{.*}}: !fir.class>) + + subroutine check_dispatch(p) + class(p1) :: p + real :: a + + call p%tbp_nopass() + call p%tbp_pass() + call p%tbp_pass_arg0() + call p%tbp_pass_arg1(1) + + call p%proc1() + call p%proc2() + call p%proc3() + call p%proc4(1) + + a = p%p1_fct1_nopass() + a = p%p1_fct2() + a = p%p1_fct3_arg0() + a = p%p1_fct4_arg1(1) + end subroutine + +! CHECK-LABEL: func.func @_QMcall_dispatchPcheck_dispatch( +! CHECK-SAME: %[[P:.*]]: !fir.class> {fir.bindc_name = "p"}) { +! CHECK: fir.dispatch "tbp_nopass"(%[[P]] : !fir.class>){{$}} +! CHECK: fir.dispatch "tbp_pass"(%[[P]] : !fir.class>) (%[[P]] : !fir.class>) {pass_arg_pos = 0 : i32} +! CHECK: fir.dispatch "tbp_pass_arg0"(%[[P]] : !fir.class>) (%[[P]] : !fir.class>) {pass_arg_pos = 0 : i32} +! CHECK: fir.dispatch "tbp_pass_arg1"(%[[P]] : !fir.class>) (%{{.*}}, %[[P]] : !fir.ref, !fir.class>) {pass_arg_pos = 1 : i32} + +! CHECK: fir.dispatch "proc1"(%[[P]] : !fir.class>){{$}} +! CHECK: fir.dispatch "proc2"(%[[P]] : !fir.class>) (%[[P]] : !fir.class>) {pass_arg_pos = 0 : i32} +! CHECK: fir.dispatch "proc3"(%[[P]] : !fir.class>) (%[[P]] : !fir.class>) {pass_arg_pos = 0 : i32} +! CHECK: fir.dispatch "proc4"(%[[P]] : !fir.class>) (%{{.*}}, %[[P]] : !fir.ref, !fir.class>) {pass_arg_pos = 1 : i32} + +! CHECK: %{{.*}} = fir.dispatch "p1_fct1_nopass"(%[[P]] : !fir.class>) -> f32{{$}} +! CHECK: %{{.*}} = fir.dispatch "p1_fct2"(%[[P]] : !fir.class>) (%[[P]] : !fir.class>) -> f32 {pass_arg_pos = 0 : i32} +! CHECK: %{{.*}} = fir.dispatch "p1_fct3_arg0"(%[[P]] : !fir.class>) (%[[P]] : !fir.class>) -> f32 {pass_arg_pos = 0 : i32} +! CHECK: %{{.*}} = fir.dispatch "p1_fct4_arg1"(%[[P]] : !fir.class>) (%{{.*}}, %[[P]] : !fir.ref, !fir.class>) -> f32 {pass_arg_pos = 1 : i32} + + subroutine check_dispatch_deferred(a, x) + class(a1) :: a + real :: x(:) + call a%nopassd(x) + end subroutine + +! CHECK-LABEL: func.func @_QMcall_dispatchPcheck_dispatch_deferred( +! CHECK-SAME: %[[ARG0:.*]]: !fir.class> {fir.bindc_name = "a"}, +! CHECK-SAME: %[[ARG1:.*]]: !fir.box> {fir.bindc_name = "x"}) { +! CHECK: fir.dispatch "nopassd"(%[[ARG0]] : !fir.class>) (%[[ARG1]] : !fir.box>) + +! ------------------------------------------------------------------------------ +! Test that direct call is emitted when the type is known +! ------------------------------------------------------------------------------ + + subroutine check_nodispatch(t) + type(p1) :: t + call t%tbp_nopass() + call t%tbp_pass() + call t%tbp_pass_arg0() + call t%tbp_pass_arg1(1) + end subroutine + +! CHECK-LABEL: func.func @_QMcall_dispatchPcheck_nodispatch +! CHECK: fir.call @_QMcall_dispatchPtbp_nopass +! CHECK: fir.call @_QMcall_dispatchPtbp_pass +! CHECK: fir.call @_QMcall_dispatchPtbp_pass_arg0 +! CHECK: fir.call @_QMcall_dispatchPtbp_pass_arg1 + +end module