diff --git a/flang/include/flang/Optimizer/Builder/BoxValue.h b/flang/include/flang/Optimizer/Builder/BoxValue.h --- a/flang/include/flang/Optimizer/Builder/BoxValue.h +++ b/flang/include/flang/Optimizer/Builder/BoxValue.h @@ -194,11 +194,11 @@ llvm::ArrayRef extents) : AbstractBox{addr}, AbstractArrayBox(extents, lbounds) {} /// Get the fir.box part of the address type. - fir::BoxType getBoxTy() const { + fir::BaseBoxType getBoxTy() const { auto type = getAddr().getType(); if (auto pointedTy = fir::dyn_cast_ptrEleTy(type)) type = pointedTy; - return type.cast(); + return type.cast(); } /// Return the part of the address type after memory and box types. That is /// the element type, maybe wrapped in a fir.array type. diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -762,7 +762,7 @@ OptionalAttr:$accessMap ); - let results = (outs fir_BoxType); + let results = (outs BoxOrClassType); let builders = [ OpBuilder<(ins "llvm::ArrayRef":$resultTypes, diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -87,7 +87,7 @@ /// Is `t` a boxed type? inline bool isa_box_type(mlir::Type t) { - return t.isa(); + return t.isa(); } /// Is `t` a type that is always trivially pass-by-reference? Specifically, this @@ -307,6 +307,14 @@ return type.isa(); } +/// Wrap `eleTy` in a fir.class type when `isPolymorphic` is true, otherwise in a fir.box type.
+inline mlir::Type wrapInClassOrBoxType(mlir::Type eleTy, + bool isPolymorphic = false) { + if (isPolymorphic) + return fir::ClassType::get(eleTy); + return fir::BoxType::get(eleTy); +} + } // namespace fir diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -564,6 +564,10 @@ let genStorageClass = 0; } +// Whether a type is a BaseBoxType +def IsBaseBoxTypePred + : CPred<"$_self.isa<::fir::BaseBoxType>()">; + // Generalized FIR and standard dialect types representing intrinsic types def AnyIntegerLike : TypeConstraint, "any integer">; @@ -596,7 +600,11 @@ fir_LLVMPointerType.predicate]>, "fir.ref or fir.llvm_ptr">; def AnyBoxLike : TypeConstraint, "any box">; + fir_BoxCharType.predicate, fir_BoxProcType.predicate, + fir_ClassType.predicate]>, "any box">; + +def BoxOrClassType : TypeConstraint, "box or class">; def AnyRefOrBoxLike : TypeConstraint, diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h --- a/flang/include/flang/Semantics/tools.h +++ b/flang/include/flang/Semantics/tools.h @@ -183,6 +183,7 @@ const Scope &, bool vectorSubscriptIsOk = false); const Symbol *IsExternalInPureContext(const Symbol &, const Scope &); bool HasCoarray(const parser::Expr &); +bool IsPolymorphic(const Symbol &); bool IsPolymorphicAllocatable(const Symbol &); // Return an error if component symbol is not accessible from scope (7.5.4.8(2)) std::optional CheckAccessibleComponent( diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp --- a/flang/lib/Lower/CallInterface.cpp +++ b/flang/lib/Lower/CallInterface.cpp @@ -797,9 +797,8 @@ Fortran::common::TypeCategory cat = dynamicType.category(); // DERIVED if (cat == Fortran::common::TypeCategory::Derived) { - if (dynamicType.IsPolymorphic()) - TODO(interface.converter.getCurrentLocation(),
"support for polymorphic types"); + if (dynamicType.IsUnlimitedPolymorphic()) + return mlir::NoneType::get(&mlirContext); return getConverter().genType(dynamicType.GetDerivedTypeSpec()); } // CHARACTER with compile time constant length. @@ -860,16 +859,17 @@ type = fir::HeapType::get(type); if (obj.attrs.test(Attrs::Pointer)) type = fir::PointerType::get(type); - mlir::Type boxType = fir::BoxType::get(type); + mlir::Type boxType = + fir::wrapInClassOrBoxType(type, obj.type.type().IsPolymorphic()); if (obj.attrs.test(Attrs::Allocatable) || obj.attrs.test(Attrs::Pointer)) { - // Pass as fir.ref + // Pass as fir.ref or fir.ref mlir::Type boxRefType = fir::ReferenceType::get(boxType); addFirOperand(boxRefType, nextPassedArgPosition(), Property::MutableBox, attrs); addPassedArg(PassEntityBy::MutableBox, entity, characteristics); } else if (dummyRequiresBox(obj)) { - // Pass as fir.box + // Pass as fir.box or fir.class if (isValueAttr) TODO(loc, "assumed shape dummy argument with VALUE attribute"); addFirOperand(boxType, nextPassedArgPosition(), Property::Box, attrs); @@ -954,12 +954,17 @@ assert(typeAndShape && "expect type for non proc pointer result"); mlir::Type mlirType = translateDynamicType(typeAndShape->type()); fir::SequenceType::Shape bounds = getBounds(typeAndShape->shape()); + const auto *resTypeAndShape{result.GetTypeAndShape()}; + bool resIsPolymorphic = + resTypeAndShape && resTypeAndShape->type().IsPolymorphic(); if (!bounds.empty()) mlirType = fir::SequenceType::get(bounds, mlirType); if (result.attrs.test(Attr::Allocatable)) - mlirType = fir::BoxType::get(fir::HeapType::get(mlirType)); + mlirType = fir::wrapInClassOrBoxType(fir::HeapType::get(mlirType), + resIsPolymorphic); if (result.attrs.test(Attr::Pointer)) - mlirType = fir::BoxType::get(fir::PointerType::get(mlirType)); + mlirType = fir::wrapInClassOrBoxType(fir::PointerType::get(mlirType), + resIsPolymorphic); if (fir::isa_char(mlirType)) { // Character scalar results must be passed as arguments
in lowering so diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp --- a/flang/lib/Lower/ConvertExpr.cpp +++ b/flang/lib/Lower/ConvertExpr.cpp @@ -2390,10 +2390,10 @@ llvm::ArrayRef extents, llvm::ArrayRef lengths) { mlir::Type type = base.getType(); - if (type.isa()) + if (type.isa()) return fir::BoxValue(base, /*lbounds=*/{}, lengths, extents); type = fir::unwrapRefType(type); - if (type.isa()) + if (type.isa()) return fir::MutableBoxValue(base, lengths, /*mutableProperties*/ {}); if (auto seqTy = type.dyn_cast()) { if (seqTy.getDimension() != extents.size()) diff --git a/flang/lib/Lower/ConvertType.cpp b/flang/lib/Lower/ConvertType.cpp --- a/flang/lib/Lower/ConvertType.cpp +++ b/flang/lib/Lower/ConvertType.cpp @@ -233,8 +233,8 @@ llvm::SmallVector params; translateLenParameters(params, tySpec->category(), ultimate); ty = genFIRType(context, tySpec->category(), kind, params); - } else if (type->IsPolymorphic()) { - TODO(loc, "support for polymorphic types"); + } else if (type->IsUnlimitedPolymorphic()) { + ty = mlir::NoneType::get(context); } else if (const Fortran::semantics::DerivedTypeSpec *tySpec = type->AsDerived()) { ty = genDerivedType(*tySpec); @@ -253,11 +253,13 @@ translateShape(shape, std::move(*shapeExpr)); ty = fir::SequenceType::get(shape, ty); } if (Fortran::semantics::IsPointer(symbol)) - return fir::BoxType::get(fir::PointerType::get(ty)); + return fir::wrapInClassOrBoxType( + fir::PointerType::get(ty), Fortran::semantics::IsPolymorphic(symbol)); if (Fortran::semantics::IsAllocatable(symbol)) - return fir::BoxType::get(fir::HeapType::get(ty)); + return fir::wrapInClassOrBoxType( + fir::HeapType::get(ty), Fortran::semantics::IsPolymorphic(symbol)); // isPtr and isAlloc are variable that were promoted to be on the // heap or to be pointers, but they do not have Fortran allocatable // or pointer semantics, so do not use box for them.
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp --- a/flang/lib/Lower/ConvertVariable.cpp +++ b/flang/lib/Lower/ConvertVariable.cpp @@ -217,7 +217,7 @@ fir::ExtendedValue exv = globalOpSymMap.lookupSymbol(sym).toExtendedValue(); const auto *mold = exv.getBoxOf(); - fir::BoxType boxType = mold->getBoxTy(); + fir::BaseBoxType boxType = mold->getBoxTy(); mlir::Value box = fir::factory::createUnallocatedBox(builder, loc, boxType, {}); return box; @@ -1650,7 +1650,7 @@ mlir::Value argBox; mlir::Type castTy = builder.getRefType(varType); if (addr) { - if (auto boxTy = addr.getType().dyn_cast()) { + if (auto boxTy = addr.getType().dyn_cast()) { argBox = addr; mlir::Type refTy = builder.getRefType(boxTy.getEleTy()); addr = builder.create(loc, refTy, argBox); diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -3704,7 +3704,7 @@ "MOLD argument required to lower NULL outside of any context"); const auto *mold = args[0].getBoxOf(); assert(mold && "MOLD must be a pointer or allocatable"); - fir::BoxType boxType = mold->getBoxTy(); + fir::BaseBoxType boxType = mold->getBoxTy(); mlir::Value boxStorage = builder.createTemporary(loc, boxType); mlir::Value box = fir::factory::createUnallocatedBox( builder, loc, boxType, mold->nonDeferredLenParams()); diff --git a/flang/lib/Optimizer/Builder/BoxValue.cpp b/flang/lib/Optimizer/Builder/BoxValue.cpp --- a/flang/lib/Optimizer/Builder/BoxValue.cpp +++ b/flang/lib/Optimizer/Builder/BoxValue.cpp @@ -185,7 +185,7 @@ mlir::Type type = fir::dyn_cast_ptrEleTy(getAddr().getType()); if (!type) return false; - auto box = type.dyn_cast(); + auto box = type.dyn_cast(); if (!box) return false; // A boxed value always takes a memory reference, diff --git a/flang/lib/Optimizer/Builder/MutableBox.cpp b/flang/lib/Optimizer/Builder/MutableBox.cpp --- a/flang/lib/Optimizer/Builder/MutableBox.cpp +++
b/flang/lib/Optimizer/Builder/MutableBox.cpp @@ -320,7 +320,7 @@ fir::factory::createUnallocatedBox(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type boxType, mlir::ValueRange nonDeferredParams) { - auto baseAddrType = boxType.dyn_cast().getEleTy(); + auto baseAddrType = boxType.dyn_cast().getEleTy(); if (!fir::isa_ref_type(baseAddrType)) baseAddrType = builder.getRefType(baseAddrType); auto type = fir::unwrapRefType(baseAddrType); diff --git a/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp --- a/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp @@ -13,6 +13,7 @@ #include "flang/Optimizer/CodeGen/CodeGen.h" #include "CGOps.h" +#include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" @@ -84,6 +85,8 @@ // If the embox does not include a shape, then do not convert it if (auto shapeVal = embox.getShape()) return rewriteDynamicShape(embox, rewriter, shapeVal); + if (embox.getType().isa()) + TODO(embox.getLoc(), "embox conversion for fir.class type"); if (auto boxTy = embox.getType().dyn_cast()) if (auto seqTy = boxTy.getEleTy().dyn_cast()) if (!seqTy.hasDynamicExtents()) @@ -274,6 +277,8 @@ target.addIllegalOp(); target.addIllegalOp(); target.addDynamicallyLegalOp([](fir::EmboxOp embox) { + if (embox.getType().isa()) + TODO(embox.getLoc(), "fir.class type CodeGenRewrite"); return !(embox.getShape() || embox.getType() .cast() .getEleTy() diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.h b/flang/lib/Optimizer/CodeGen/TypeConverter.h --- a/flang/lib/Optimizer/CodeGen/TypeConverter.h +++ b/flang/lib/Optimizer/CodeGen/TypeConverter.h @@ -64,6 +64,10 @@ // procedure pointer feature is implemented.
return llvm::None; }); + addConversion([&](fir::ClassType classTy) { + TODO_NOLOC("fir.class type conversion"); + return llvm::None; + }); addConversion( [&](fir::CharacterType charTy) { return convertCharType(charTy); }); addConversion( diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -209,7 +209,7 @@ return llvm::TypeSwitch(t) .Case([](auto p) { return p.getEleTy(); }) - .Case([](auto p) { + .Case([](auto p) { auto eleTy = p.getEleTy(); if (auto ty = fir::dyn_cast_ptrEleTy(eleTy)) return ty; @@ -249,7 +249,7 @@ bool isPointerType(mlir::Type ty) { if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) ty = refTy; - if (auto boxTy = ty.dyn_cast()) + if (auto boxTy = ty.dyn_cast()) return boxTy.getEleTy().isa(); return false; } @@ -257,7 +257,7 @@ bool isAllocatableType(mlir::Type ty) { if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) ty = refTy; - if (auto boxTy = ty.dyn_cast()) + if (auto boxTy = ty.dyn_cast()) return boxTy.getEleTy().isa(); return false; } @@ -265,8 +265,8 @@ bool isUnlimitedPolymorphicType(mlir::Type ty) { if (auto refTy = fir::dyn_cast_ptrEleTy(ty)) ty = refTy; - if (auto boxTy = ty.dyn_cast()) - return boxTy.getEleTy().isa(); + if (auto clTy = ty.dyn_cast()) + return clTy.getEleTy().isa(); return false; } diff --git a/flang/test/Lower/polymorphic-types.f90 b/flang/test/Lower/polymorphic-types.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/polymorphic-types.f90 @@ -0,0 +1,157 @@ +! RUN: bbc -emit-fir %s -o - | FileCheck %s + +! Tests the different possible types involving polymorphic entities. + +module polymorphic_types + type p1 + integer :: a + integer :: b + contains + procedure :: polymorphic_dummy + end type +contains + +! ------------------------------------------------------------------------------ +! Test polymorphic entity types +!
------------------------------------------------------------------------------ + + subroutine polymorphic_dummy(p) + class(p1) :: p + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPpolymorphic_dummy( +! CHECK-SAME: %{{.*}}: !fir.class> + + subroutine polymorphic_dummy_assumed_shape_array(pa) + class(p1) :: pa(:) + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPpolymorphic_dummy_assumed_shape_array( +! CHECK-SAME: %{{.*}}: !fir.class>> + + subroutine polymorphic_dummy_explicit_shape_array(pa) + class(p1) :: pa(10) + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPpolymorphic_dummy_explicit_shape_array( +! CHECK-SAME: %{{.*}}: !fir.class>> + + subroutine polymorphic_allocatable(p) + class(p1), allocatable :: p + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPpolymorphic_allocatable( +! CHECK-SAME: %{{.*}}: !fir.ref>>> + + subroutine polymorphic_pointer(p) + class(p1), pointer :: p + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPpolymorphic_pointer( +! CHECK-SAME: %{{.*}}: !fir.ref>>> + + subroutine polymorphic_allocatable_intentout(p) + class(p1), allocatable, intent(out) :: p + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPpolymorphic_allocatable_intentout( +! CHECK-SAME: %[[ARG0:.*]]: !fir.ref>>> +! CHECK: %[[BOX_NONE:.*]] = fir.convert %[[ARG0]] : (!fir.ref>>>) -> !fir.ref> +! CHECK: %{{.*}} = fir.call @_FortranAAllocatableDeallocate(%[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref>, i1, !fir.box, !fir.ref, i32) -> i32 + +! ------------------------------------------------------------------------------ +! Test unlimited polymorphic dummy argument types +! ------------------------------------------------------------------------------ + + subroutine unlimited_polymorphic_dummy(u) + class(*) :: u + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPunlimited_polymorphic_dummy( +!
CHECK-SAME: %{{.*}}: !fir.class + + subroutine unlimited_polymorphic_assumed_shape_array(ua) + class(*) :: ua(:) + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPunlimited_polymorphic_assumed_shape_array( +! CHECK-SAME: %{{.*}}: !fir.class> + + subroutine unlimited_polymorphic_explicit_shape_array(ua) + class(*) :: ua(20) + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPunlimited_polymorphic_explicit_shape_array( +! CHECK-SAME: %{{.*}}: !fir.class> + + subroutine unlimited_polymorphic_allocatable(p) + class(*), allocatable :: p + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPunlimited_polymorphic_allocatable( +! CHECK-SAME: %{{.*}}: !fir.ref>> + + subroutine unlimited_polymorphic_pointer(p) + class(*), pointer :: p + end subroutine + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPunlimited_polymorphic_pointer( +! CHECK-SAME: %{{.*}}: !fir.ref>> + +! ------------------------------------------------------------------------------ +! Test polymorphic function return types +! ------------------------------------------------------------------------------ + + function ret_polymorphic_allocatable() result(ret) + class(p1), allocatable :: ret + end function + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPret_polymorphic_allocatable() -> !fir.class>> +! CHECK: %[[MEM:.*]] = fir.alloca !fir.class>> {bindc_name = "ret", uniq_name = "_QMpolymorphic_typesFret_polymorphic_allocatableEret"} +! CHECK: %[[ZERO:.*]] = fir.zero_bits !fir.heap> +! CHECK: %[[BOX:.*]] = fir.embox %[[ZERO]] : (!fir.heap>) -> !fir.class>> +! CHECK: fir.store %[[BOX]] to %[[MEM]] : !fir.ref>>> +! CHECK: %[[LOAD:.*]] = fir.load %[[MEM]] : !fir.ref>>> +! CHECK: return %[[LOAD]] : !fir.class>> + + function ret_polymorphic_pointer() result(ret) + class(p1), pointer :: ret + end function + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPret_polymorphic_pointer() -> !fir.class>> +!
CHECK: %[[MEM:.*]] = fir.alloca !fir.class>> {bindc_name = "ret", uniq_name = "_QMpolymorphic_typesFret_polymorphic_pointerEret"} +! CHECK: %[[ZERO:.*]] = fir.zero_bits !fir.ptr> +! CHECK: %[[BOX:.*]] = fir.embox %[[ZERO]] : (!fir.ptr>) -> !fir.class>> +! CHECK: fir.store %[[BOX]] to %[[MEM]] : !fir.ref>>> +! CHECK: %[[LOAD:.*]] = fir.load %[[MEM]] : !fir.ref>>> +! CHECK: return %[[LOAD]] : !fir.class>> + +! ------------------------------------------------------------------------------ +! Test unlimited polymorphic function return types +! ------------------------------------------------------------------------------ + + function ret_unlimited_polymorphic_allocatable() result(ret) + class(*), allocatable :: ret + end function + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPret_unlimited_polymorphic_allocatable() -> !fir.class> +! CHECK: %[[MEM:.*]] = fir.alloca !fir.class> {bindc_name = "ret", uniq_name = "_QMpolymorphic_typesFret_unlimited_polymorphic_allocatableEret"} +! CHECK: %[[ZERO:.*]] = fir.zero_bits !fir.heap +! CHECK: %[[BOX:.*]] = fir.embox %[[ZERO]] : (!fir.heap) -> !fir.class> +! CHECK: fir.store %[[BOX]] to %[[MEM]] : !fir.ref>> +! CHECK: %[[LOAD:.*]] = fir.load %[[MEM]] : !fir.ref>> +! CHECK: return %[[LOAD]] : !fir.class> + + function ret_unlimited_polymorphic_pointer() result(ret) + class(*), pointer :: ret + end function + +! CHECK-LABEL: func.func @_QMpolymorphic_typesPret_unlimited_polymorphic_pointer() -> !fir.class> +! CHECK: %[[MEM:.*]] = fir.alloca !fir.class> {bindc_name = "ret", uniq_name = "_QMpolymorphic_typesFret_unlimited_polymorphic_pointerEret"} +! CHECK: %[[ZERO:.*]] = fir.zero_bits !fir.ptr +! CHECK: %[[BOX:.*]] = fir.embox %[[ZERO]] : (!fir.ptr) -> !fir.class> +! CHECK: fir.store %[[BOX]] to %[[MEM]] : !fir.ref>> +! CHECK: %[[LOAD:.*]] = fir.load %[[MEM]] : !fir.ref>> +! CHECK: return %[[LOAD]] : !fir.class> + +end module