diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -38,6 +38,9 @@ /// Returns the element type of this box type. mlir::Type getEleTy() const; + /// Unwrap element type from fir.heap, fir.ptr and fir.array. + mlir::Type unwrapInnerType() const; + /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(mlir::Type type); }; @@ -273,6 +276,10 @@ /// Return true iff `ty` is the type of an ALLOCATABLE entity or value. bool isAllocatableType(mlir::Type ty); +/// Return true iff `ty` is the type of a boxed record type. +/// e.g. !fir.box> +bool isBoxedRecordType(mlir::Type ty); + /// Return true iff `ty` is the type of an polymorphic entity or /// value. bool isPolymorphicType(mlir::Type ty); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -916,7 +916,8 @@ (isPointerCompatible(inType) && isIntegerCompatible(outType)) || (inType.isa() && outType.isa()) || (inType.isa() && outType.isa()) || - (fir::isa_complex(inType) && fir::isa_complex(outType))) + (fir::isa_complex(inType) && fir::isa_complex(outType)) || + (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType))) return mlir::success(); return emitOpError("invalid type conversion"); } diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -262,6 +262,18 @@ return false; } +bool isBoxedRecordType(mlir::Type ty) { + if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) + ty = refTy; + if (auto boxTy = ty.dyn_cast()) { + if (boxTy.getEleTy().isa()) + return true; + mlir::Type innerType = boxTy.unwrapInnerType(); + return innerType && innerType.isa(); + } + return false; +} + static bool isAssumedType(mlir::Type ty) { if (auto boxTy = ty.dyn_cast()) { if (boxTy.getEleTy().isa()) @@ -289,17 +301,8 @@ if (auto clTy = ty.dyn_cast()) { if (clTy.getEleTy().isa()) return true; - mlir::Type innerType = - llvm::TypeSwitch(clTy.getEleTy()) - .Case( - [](auto ty) { - mlir::Type eleTy = ty.getEleTy(); - if (auto seqTy = eleTy.dyn_cast()) - return seqTy.getEleTy(); - return eleTy; - }) - .Default([](mlir::Type) { return mlir::Type{}; }); - return innerType.isa(); + mlir::Type innerType = clTy.unwrapInnerType(); + return innerType && innerType.isa(); } // TYPE(*) return isAssumedType(ty); @@ -982,6 +985,17 @@ [](auto type) { return type.getEleTy(); }); } +mlir::Type BaseBoxType::unwrapInnerType() const { + return llvm::TypeSwitch(getEleTy()) + .Case([](auto ty) { + mlir::Type eleTy = ty.getEleTy(); + if (auto seqTy = eleTy.dyn_cast()) + return seqTy.getEleTy(); + return eleTy; + }) + .Default([](mlir::Type) { return mlir::Type{}; }); +} + //===----------------------------------------------------------------------===// // FIROpsDialect //===----------------------------------------------------------------------===// diff --git a/flang/test/Lower/polymorphic.f90 b/flang/test/Lower/polymorphic.f90 --- a/flang/test/Lower/polymorphic.f90 +++ b/flang/test/Lower/polymorphic.f90 @@ -6,6 +6,8 @@ type p1 integer :: a integer :: b + contains + procedure :: print end type type, extends(p1) :: p2 @@ -27,4 +29,25 @@ ! CHECK: %[[LOAD:.*]] = fir.load %[[COORD]] : !fir.ref ! CHECK: %{{.*}} = fir.call @_FortranAioOutputInteger32(%{{.*}}, %[[LOAD]]) : (!fir.ref, i32) -> i1 + subroutine print(this) + class(p1) :: this + end subroutine + + ! Test passing fir.convert accept fir.box
-> fir.class
+ subroutine check() + type(p1) :: t1 + type(p2) :: t2 + call t1%print() + call t2%print() + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_testPcheck() +! CHECK: %[[DT1:.*]] = fir.alloca !fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}> {bindc_name = "t1", uniq_name = "_QMpolymorphic_testFcheckEt1"} +! CHECK: %[[DT2:.*]] = fir.alloca !fir.type<_QMpolymorphic_testTp2{a:i32,b:i32,c:f32}> {bindc_name = "t2", uniq_name = "_QMpolymorphic_testFcheckEt2"} +! CHECK: %[[BOX1:.*]] = fir.embox %[[DT1]] : (!fir.ref>) -> !fir.box> +! CHECK: %[[CLASS1:.*]] = fir.convert %[[BOX1]] : (!fir.box>) -> !fir.class> +! CHECK: fir.call @_QMpolymorphic_testPprint(%[[CLASS1]]) : (!fir.class>) -> () +! CHECK: %[[BOX2:.*]] = fir.embox %[[DT2]] : (!fir.ref>) -> !fir.box> +! CHECK: %[[CLASS2:.*]] = fir.convert %[[BOX2]] : (!fir.box>) -> !fir.class> +! CHECK: fir.call @_QMpolymorphic_testPprint(%[[CLASS2]]) : (!fir.class>) -> () end module diff --git a/flang/unittests/Optimizer/FIRTypesTest.cpp b/flang/unittests/Optimizer/FIRTypesTest.cpp --- a/flang/unittests/Optimizer/FIRTypesTest.cpp +++ b/flang/unittests/Optimizer/FIRTypesTest.cpp @@ -117,3 +117,32 @@ EXPECT_FALSE(fir::isUnlimitedPolymorphicType(noneTy)); EXPECT_FALSE(fir::isUnlimitedPolymorphicType(seqNoneTy)); } + +// Test fir::isBoxedRecordType from flang/Optimizer/Dialect/FIRType.h. +TEST_F(FIRTypesTest, isBoxedRecordType) { + mlir::Type recTy = fir::RecordType::get(&context, "dt"); + mlir::Type seqRecTy = + fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, recTy); + mlir::Type ty = fir::BoxType::get(recTy); + EXPECT_TRUE(fir::isBoxedRecordType(ty)); + EXPECT_TRUE(fir::isBoxedRecordType(fir::ReferenceType::get(ty))); + + // TYPE(T), ALLOCATABLE + ty = fir::BoxType::get(fir::HeapType::get(recTy)); + EXPECT_TRUE(fir::isBoxedRecordType(ty)); + + // TYPE(T), POINTER + ty = fir::BoxType::get(fir::PointerType::get(recTy)); + EXPECT_TRUE(fir::isBoxedRecordType(ty)); + + // TYPE(T), DIMENSION(10) + ty = fir::BoxType::get(fir::SequenceType::get({10}, recTy)); + EXPECT_TRUE(fir::isBoxedRecordType(ty)); + + // TYPE(T), DIMENSION(:) + ty = fir::BoxType::get(seqRecTy); + EXPECT_TRUE(fir::isBoxedRecordType(ty)); + + EXPECT_FALSE(fir::isBoxedRecordType(fir::BoxType::get( + fir::ReferenceType::get(mlir::IntegerType::get(&context, 32))))); +}