diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp --- a/flang/lib/Lower/ConvertExpr.cpp +++ b/flang/lib/Lower/ConvertExpr.cpp @@ -2754,9 +2754,15 @@ assert(component && "expect component for type-bound procedure call."); fir::ExtendedValue pass = symMap.lookupSymbol(component->GetFirstSymbol()).toExtendedValue(); + mlir::Value passObject = fir::getBase(pass); + if (fir::isa_ref_type(passObject.getType())) + passObject = builder.create( + loc, + passObject.getType().dyn_cast().getEleTy(), + passObject); dispatch = builder.create( loc, funcType.getResults(), builder.getStringAttr(procName), - fir::getBase(pass), operands, nullptr); + passObject, operands, nullptr); } callResult = dispatch.getResult(0); callNumResults = dispatch.getNumResults(); diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -35,6 +35,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/TypeSwitch.h" namespace fir { #define GEN_PASS_DEF_FIRTOLLVMLOWERING @@ -898,15 +899,13 @@ if (bindingTables.empty()) return emitError(loc) << "no binding tables found"; - if (dispatch.getObject() - .getType() - .getEleTy() - .isa()) - TODO(loc, - "fir.dispatch with allocatable or pointer polymorphic entities"); - // Get derived type information. - auto declaredType = dispatch.getObject().getType().getEleTy(); + auto declaredType = llvm::TypeSwitch( + dispatch.getObject().getType().getEleTy()) + .Case( + [](auto p) { return p.getEleTy(); }) + .Default([](mlir::Type t) { return t; }); + assert(declaredType.isa() && "expecting fir.type"); auto recordType = declaredType.dyn_cast(); std::string typeDescName = diff --git a/flang/test/Lower/allocatable-polymorphic.f90 b/flang/test/Lower/allocatable-polymorphic.f90 --- a/flang/test/Lower/allocatable-polymorphic.f90 +++ b/flang/test/Lower/allocatable-polymorphic.f90 @@ -5,11 +5,24 @@ type p1 integer :: a integer :: b + contains + procedure, nopass :: proc1 => proc1_p1 end type type, extends(p1) :: p2 integer :: c + contains + procedure, nopass :: proc1 => proc1_p2 end type + +contains + subroutine proc1_p1() + print*, 'call proc1_p1' + end subroutine + + subroutine proc1_p2() + print*, 'call proc1_p2' + end subroutine end module program test_allocatable @@ -27,6 +40,8 @@ allocate(p1::c3(10)) allocate(p2::c4(20)) + call c1%proc1() + call c2%proc1() end ! CHECK-LABEL: func.func @_QQmain()