diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -836,7 +836,8 @@ let results = (outs BoxOrClassType); let assemblyFormat = [{ - $box (`(` $shape^ `)`)? (`[` $slice^ `]`)? attr-dict `:` functional-type(operands, results) + $box (`(` $shape^ `)`)? (`[` $slice^ `]`)? + attr-dict `:` functional-type(operands, results) }]; let hasVerifier = 1; diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -91,6 +91,11 @@ fir::emitFatalError(val.getLoc(), "must be a constant"); } +static unsigned getTypeDescFieldId(mlir::Type ty) { + auto isArray = fir::dyn_cast_ptrOrBoxEleTy(ty).isa(); + return isArray ? kOptTypePtrPosInBox : kDimsPosInBox; +} + namespace { /// FIR conversion pattern template template @@ -244,6 +249,18 @@ return getBoxEleTy(type, {kAddrPosInBox}); } + /// Read the address of the type descriptor from a box. + mlir::Value + loadTypeDescAddress(mlir::Location loc, mlir::Type ty, mlir::Value box, + mlir::ConversionPatternRewriter &rewriter) const { + unsigned typeDescFieldId = getTypeDescFieldId(ty); + mlir::Type tdescType = lowerTy().convertTypeDescType(rewriter.getContext()); + auto pty = mlir::LLVM::LLVMPointerType::get(tdescType); + mlir::LLVM::GEPOp p = genGEP(loc, pty, rewriter, box, 0, + static_cast(typeDescFieldId)); + return rewriter.create(loc, tdescType, p); + } + // Load the attribute from the \p box and perform a check against \p maskValue // The final comparison is implemented as `(attribute & maskValue) != 0`. 
mlir::Value genBoxAttributeCheck(mlir::Location loc, mlir::Value box, @@ -940,9 +957,7 @@ typeDescTy = global.getType(); } - auto isArray = fir::dyn_cast_ptrOrBoxEleTy(passedObject.getType()) - .template isa(); - unsigned typeDescFieldId = isArray ? kOptTypePtrPosInBox : kDimsPosInBox; + unsigned typeDescFieldId = getTypeDescFieldId(passedObject.getType()); auto descPtr = adaptor.getOperands()[0] .getType() @@ -1458,7 +1473,8 @@ template std::tuple consDescriptorPrefix(BOX box, mlir::ConversionPatternRewriter &rewriter, - unsigned rank, mlir::ValueRange lenParams) const { + unsigned rank, mlir::ValueRange lenParams, + mlir::Value typeDesc = {}) const { auto loc = box.getLoc(); auto boxTy = box.getType().template dyn_cast(); auto convTy = this->lowerTy().convertBoxType(boxTy, rank); @@ -1492,11 +1508,10 @@ this->genI32Constant(loc, rewriter, hasAddendum ? 1 : 0)); if (hasAddendum) { - auto isArray = - fir::dyn_cast_ptrOrBoxEleTy(boxTy).template isa(); - unsigned typeDescFieldId = isArray ? kOptTypePtrPosInBox : kDimsPosInBox; - auto typeDesc = - getTypeDescriptor(box, rewriter, loc, unwrapIfDerived(boxTy)); + unsigned typeDescFieldId = getTypeDescFieldId(boxTy); + if (!typeDesc) + typeDesc = + getTypeDescriptor(box, rewriter, loc, unwrapIfDerived(boxTy)); descriptor = insertField(rewriter, loc, descriptor, {typeDescFieldId}, typeDesc, /*bitCast=*/true); @@ -1849,8 +1864,16 @@ if (recTy.getNumLenParams() != 0) TODO(loc, "reboxing descriptor of derived type with length parameters"); } - auto [boxTy, dest, eleSize] = - consDescriptorPrefix(rebox, rewriter, rebox.getOutRank(), lenParams); + + // Rebox on polymorphic entities needs to carry over the dynamic type. 
+ mlir::Value typeDescAddr; + if (rebox.getBox().getType().isa() && + rebox.getType().isa()) + typeDescAddr = loadTypeDescAddress(loc, rebox.getBox().getType(), + loweredBox, rewriter); + + auto [boxTy, dest, eleSize] = consDescriptorPrefix( + rebox, rewriter, rebox.getOutRank(), lenParams, typeDescAddr); // Read input extents, strides, and base address llvm::SmallVector inputExtents; @@ -2835,7 +2858,7 @@ mlir::LogicalResult matchAndRewrite(fir::LoadOp load, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - if (auto boxTy = load.getType().dyn_cast()) { + if (auto boxTy = load.getType().dyn_cast()) { // fir.box is a special case because it is considered as an ssa values in // fir, but it is lowered as a pointer to a descriptor. So // fir.ref and fir.box end up being the same llvm types and diff --git a/flang/test/Lower/allocatable-polymorphic.f90 b/flang/test/Lower/allocatable-polymorphic.f90 --- a/flang/test/Lower/allocatable-polymorphic.f90 +++ b/flang/test/Lower/allocatable-polymorphic.f90 @@ -7,12 +7,14 @@ integer :: b contains procedure, nopass :: proc1 => proc1_p1 + procedure :: proc2 => proc2_p1 end type type, extends(p1) :: p2 integer :: c contains - procedure, nopass :: proc1 => proc1_p2 + procedure, nopass :: proc1 => proc1_p2 + procedure :: proc2 => proc2_p2 end type contains @@ -23,6 +25,16 @@ subroutine proc1_p2() print*, 'call proc1_p2' end subroutine + + subroutine proc2_p1(this) + class(p1) :: this + print*, 'call proc2_p1' + end subroutine + + subroutine proc2_p2(this) + class(p2) :: this + print*, 'call proc2_p2' + end subroutine end module program test_allocatable @@ -42,6 +54,9 @@ call c1%proc1() call c2%proc1() + + call c1%proc2() + call c2%proc2() end ! CHECK-LABEL: func.func @_QQmain() @@ -85,7 +100,6 @@ ! CHECK: %[[RANK:.*]] = arith.constant 1 : i32 ! CHECK: %[[C0:.*]] = arith.constant 0 : i32 ! 
CHECK: fir.call @_FortranAAllocatableInitDerived(%[[C3_CAST]], %[[TYPE_DESC_P1_CAST]], %[[RANK]], %[[C0]]) : (!fir.ref>, !fir.ref, i32, i32) -> none -! CHECK: %[[C1:.*]] = arith.constant 1 : index ! CHECK: %[[C10:.*]] = arith.constant 10 : i32 ! CHECK: %[[C0:.*]] = arith.constant 0 : i32 ! CHECK: %[[C3_CAST:.*]] = fir.convert %[[C3]] : (!fir.ref>>>>) -> !fir.ref> @@ -101,22 +115,34 @@ ! CHECK: %[[RANK:.*]] = arith.constant 1 : i32 ! CHECK: %[[C0:.*]] = arith.constant 0 : i32 ! CHECK: fir.call @_FortranAAllocatableInitDerived(%[[C4_CAST]], %[[TYPE_DESC_P2_CAST]], %[[RANK]], %[[C0]]) : (!fir.ref>, !fir.ref, i32, i32) -> none -! CHECK: %[[C1:.*]] = arith.constant 1 : index +! CHECK: %[[CST1:.*]] = arith.constant 1 : index ! CHECK: %[[C20:.*]] = arith.constant 20 : i32 ! CHECK: %[[C0:.*]] = arith.constant 0 : i32 ! CHECK: %[[C4_CAST:.*]] = fir.convert %[[C4]] : (!fir.ref>>>>) -> !fir.ref> -! CHECK: %[[C1_I64:.*]] = fir.convert %[[C1]] : (index) -> i64 +! CHECK: %[[C1_I64:.*]] = fir.convert %[[CST1]] : (index) -> i64 ! CHECK: %[[C20_I64:.*]] = fir.convert %[[C20]] : (i32) -> i64 ! CHECK: %{{.*}} = fir.call @_FortranAAllocatableSetBounds(%[[C4_CAST]], %[[C0]], %[[C1_I64]], %[[C20_I64]]) : (!fir.ref>, i32, i64, i64) -> none ! CHECK: %[[C4_CAST:.*]] = fir.convert %[[C4]] : (!fir.ref>>>>) -> !fir.ref> ! CHECK: %{{.*}} = fir.call @_FortranAAllocatableAllocate(%[[C4_CAST]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, i1, !fir.box, !fir.ref, i32) -> i32 +! Check fir.rebox for fir.class +! CHECK: %[[C1_LOAD:.*]] = fir.load %[[C1]] : !fir.ref>>> +! CHECK: %[[C1_REBOX:.*]] = fir.rebox %[[C1_LOAD]] : (!fir.class>>) -> !fir.class> +! CHECK: fir.dispatch "proc2"(%[[C1_REBOX]] : !fir.class>) (%{{.*}} : !fir.class>) {pass_arg_pos = 0 : i32} + +! CHECK: %[[C2_LOAD:.*]] = fir.load %[[C2]] : !fir.ref>>> +! CHECK: %[[C2_REBOX:.*]] = fir.rebox %[[C2_LOAD]] : (!fir.class>>) -> !fir.class> +! 
CHECK: fir.dispatch "proc2"(%[[C2_REBOX]] : !fir.class>) (%{{.*}} : !fir.class>) {pass_arg_pos = 0 : i32} ! Check code generation of allocate runtime calls for polymoprhic entities. This ! is done from Fortran so we don't have a file full of auto-generated type info ! in order to perform the checks. ! LLVM-LABEL: define void @_QQmain() +! LLVM: %[[TMP1:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } +! LLVM: %[[TMP2:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } +! LLVM: %[[TMP3:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } +! LLVM: %[[TMP4:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } ! LLVM: %{{.*}} = call {} @_FortranAAllocatableInitDerived(ptr @_QFEp, ptr @_QMpolyE.dt.p1, i32 0, i32 0) ! LLVM: %{{.*}} = call i32 @_FortranAAllocatableAllocate(ptr @_QFEp, i1 false, ptr null, ptr @_QQcl.{{.*}}, i32 {{.*}}) ! LLVM: %{{.*}} = call {} @_FortranAAllocatableInitDerived(ptr @_QFEc1, ptr @_QMpolyE.dt.p1, i32 0, i32 0) @@ -129,3 +155,20 @@ ! LLVM: %{{.*}} = call {} @_FortranAAllocatableInitDerived(ptr @_QFEc4, ptr @_QMpolyE.dt.p2, i32 1, i32 0) ! LLVM: %{{.*}} = call {} @_FortranAAllocatableSetBounds(ptr @_QFEc4, i32 0, i64 1, i64 20) ! LLVM: %{{.*}} = call i32 @_FortranAAllocatableAllocate(ptr @_QFEc4, i1 false, ptr null, ptr @_QQcl.{{.*}}, i32 {{.*}}) +! LLVM: call void %{{.*}}() + +! LLVM: %[[C1_LOAD:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }, ptr @_QFEc1 +! LLVM: store { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[C1_LOAD]], ptr %[[TMP4]] +! LLVM: %[[GEP_TDESC_C1:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }, ptr %[[TMP4]], i32 0, i32 7 +! LLVM: %[[TDESC_C1:.*]] = load ptr, ptr %[[GEP_TDESC_C1]] +! LLVM: %[[BOX0:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } { ptr undef, i64 ptrtoint (ptr getelementptr (%_QMpolyTp1, ptr null, i32 1) to i64), i32 20180515, i8 0, i8 42, i8 0, i8 1, ptr undef, [1 x i64] undef }, ptr %[[TDESC_C1]], 7 +! 
LLVM: %[[BOX1:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[BOX0]], ptr %{{.*}}, 0 +! LLVM: store { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[BOX1]], ptr %[[TMP3]] + +! LLVM: %[[LOAD_C2:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }, ptr @_QFEc2 +! LLVM: store { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[LOAD_C2]], ptr %[[TMP2]] +! LLVM: %[[GEP_TDESC_C2:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }, ptr %[[TMP2]], i32 0, i32 7 +! LLVM: %[[TDESC_C2:.*]] = load ptr, ptr %[[GEP_TDESC_C2]] +! LLVM: %[[BOX0:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } { ptr undef, i64 ptrtoint (ptr getelementptr (%_QMpolyTp1, ptr null, i32 1) to i64), i32 20180515, i8 0, i8 42, i8 0, i8 1, ptr undef, [1 x i64] undef }, ptr %[[TDESC_C2]], 7 +! LLVM: %[[BOX1:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[BOX0]], ptr %{{.*}}, 0 +! LLVM: store { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[BOX1]], ptr %[[TMP1]]