diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h --- a/flang/include/flang/Lower/AbstractConverter.h +++ b/flang/include/flang/Lower/AbstractConverter.h @@ -137,8 +137,6 @@ // Types //===--------------------------------------------------------------------===// - /// Generate the type of a DataRef - virtual mlir::Type genType(const Fortran::evaluate::DataRef &) = 0; /// Generate the type of an Expr virtual mlir::Type genType(const SomeExpr &) = 0; /// Generate the type of a Symbol @@ -149,6 +147,8 @@ virtual mlir::Type genType(Fortran::common::TypeCategory tc, int kind, llvm::ArrayRef lenParameters = llvm::None) = 0; + /// Generate the type from a DerivedTypeSpec. + virtual mlir::Type genType(const Fortran::semantics::DerivedTypeSpec &) = 0; /// Generate the type from a Variable virtual mlir::Type genType(const pft::Variable &) = 0; diff --git a/flang/include/flang/Lower/ConvertType.h b/flang/include/flang/Lower/ConvertType.h --- a/flang/include/flang/Lower/ConvertType.h +++ b/flang/include/flang/Lower/ConvertType.h @@ -44,6 +44,7 @@ namespace semantics { class Symbol; +class DerivedTypeSpec; } // namespace semantics namespace lower { @@ -62,6 +63,11 @@ mlir::Type getFIRType(mlir::MLIRContext *ctxt, common::TypeCategory tc, int kind, llvm::ArrayRef); +/// Get a FIR type for a derived type +mlir::Type +translateDerivedTypeToFIRType(Fortran::lower::AbstractConverter &, + const Fortran::semantics::DerivedTypeSpec &); + /// Translate a SomeExpr to an mlir::Type. mlir::Type translateSomeExprToFIRType(Fortran::lower::AbstractConverter &, const SomeExpr &expr); diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -241,26 +241,26 @@ return foldingContext; } - mlir::Type genType(const Fortran::evaluate::DataRef &) override final { - TODO_NOLOC("Not implemented genType DataRef. Needed for more complex " - "expression lowering"); - } mlir::Type genType(const Fortran::lower::SomeExpr &expr) override final { return Fortran::lower::translateSomeExprToFIRType(*this, expr); } mlir::Type genType(Fortran::lower::SymbolRef sym) override final { return Fortran::lower::translateSymbolToFIRType(*this, sym); } - mlir::Type genType(Fortran::common::TypeCategory tc) override final { - TODO_NOLOC("Not implemented genType TypeCategory. Needed for more complex " - "expression lowering"); - } mlir::Type genType(Fortran::common::TypeCategory tc, int kind, llvm::ArrayRef lenParameters) override final { return Fortran::lower::getFIRType(&getMLIRContext(), tc, kind, lenParameters); } + mlir::Type + genType(const Fortran::semantics::DerivedTypeSpec &tySpec) override final { + return Fortran::lower::translateDerivedTypeToFIRType(*this, tySpec); + } + mlir::Type genType(Fortran::common::TypeCategory tc) override final { + TODO_NOLOC("Not implemented genType TypeCategory. Needed for more complex " + "expression lowering"); + } mlir::Type genType(const Fortran::lower::pft::Variable &var) override final { return Fortran::lower::translateVariableToFIRType(*this, var); } diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp --- a/flang/lib/Lower/CallInterface.cpp +++ b/flang/lib/Lower/CallInterface.cpp @@ -215,7 +215,11 @@ dynamicType.GetCharLength()) visitor(toEvExpr(*length)); } else if (dynamicType.category() == common::TypeCategory::Derived) { - TODO(converter.getCurrentLocation(), "walkResultLengths derived type"); + const Fortran::semantics::DerivedTypeSpec &derivedTypeSpec = + dynamicType.GetDerivedTypeSpec(); + if (Fortran::semantics::CountLenParameters(derivedTypeSpec) > 0) + TODO(converter.getCurrentLocation(), + "function result with derived type length parameters"); } } @@ -759,8 +763,10 @@ Fortran::common::TypeCategory cat = dynamicType.category(); // DERIVED if (cat == Fortran::common::TypeCategory::Derived) { - TODO(interface.converter.getCurrentLocation(), - "[translateDynamicType] Derived types"); + if (dynamicType.IsPolymorphic()) + TODO(interface.converter.getCurrentLocation(), + "[translateDynamicType] polymorphic types"); + return getConverter().genType(dynamicType.GetDerivedTypeSpec()); } // CHARACTER with compile time constant length. if (cat == Fortran::common::TypeCategory::Character) diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp --- a/flang/lib/Lower/ConvertExpr.cpp +++ b/flang/lib/Lower/ConvertExpr.cpp @@ -1109,10 +1109,10 @@ } ExtValue gen(const Fortran::evaluate::DataRef &dref) { - TODO(getLoc(), "gen DataRef"); + return std::visit([&](const auto &x) { return gen(x); }, dref.u); } ExtValue genval(const Fortran::evaluate::DataRef &dref) { - TODO(getLoc(), "genval DataRef"); + return std::visit([&](const auto &x) { return genval(x); }, dref.u); } // Helper function to turn the Component structure into a list of nested @@ -1166,10 +1166,18 @@ } ExtValue gen(const Fortran::evaluate::Component &cmpt) { - TODO(getLoc(), "gen Component"); + // Components may be pointer or allocatable. In the gen() path, the mutable + // aspect is lost to simplify handling on the client side. To retain the + // mutable aspect, genMutableBoxValue should be used. + return genComponent(cmpt).match( + [&](const fir::MutableBoxValue &mutableBox) { + return fir::factory::genMutableBoxRead(builder, getLoc(), mutableBox); + }, + [](auto &box) -> ExtValue { return box; }); } + ExtValue genval(const Fortran::evaluate::Component &cmpt) { - TODO(getLoc(), "genval Component"); + return genLoad(gen(cmpt)); } ExtValue genval(const Fortran::semantics::Bound &bound) { @@ -1345,7 +1353,7 @@ mlir::Type genType(const Fortran::evaluate::DynamicType &dt) { if (dt.category() != Fortran::common::TypeCategory::Derived) return converter.genType(dt.category(), dt.kind()); - TODO(getLoc(), "genType Derived Type"); + return converter.genType(dt.GetDerivedTypeSpec()); } /// Lower a function reference diff --git a/flang/lib/Lower/ConvertType.cpp b/flang/lib/Lower/ConvertType.cpp --- a/flang/lib/Lower/ConvertType.cpp +++ b/flang/lib/Lower/ConvertType.cpp @@ -8,6 +8,7 @@ #include "flang/Lower/ConvertType.h" #include "flang/Lower/AbstractConverter.h" +#include "flang/Lower/Mangler.h" #include "flang/Lower/PFTBuilder.h" #include "flang/Lower/Support/Utils.h" #include "flang/Lower/Todo.h" @@ -16,6 +17,7 @@ #include "flang/Semantics/type.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "llvm/Support/Debug.h" #define DEBUG_TYPE "flang-lower-type" @@ -139,7 +141,7 @@ mlir::Type baseType; if (category == Fortran::common::TypeCategory::Derived) { - TODO(converter.getCurrentLocation(), "genExprType derived"); + baseType = genDerivedType(dynamicType->GetDerivedTypeSpec()); } else { // LOGICAL, INTEGER, REAL, COMPLEX, CHARACTER llvm::SmallVector params; @@ -231,8 +233,9 @@ ty = genFIRType(context, tySpec->category(), kind, params); } else if (type->IsPolymorphic()) { TODO(loc, "genSymbolType polymorphic types"); - } else if (type->AsDerived()) { - TODO(loc, "genSymbolType derived type"); + } else if (const Fortran::semantics::DerivedTypeSpec *tySpec = + type->AsDerived()) { + ty = genDerivedType(*tySpec); } else { fir::emitFatalError(loc, "symbol's type must have a type spec"); } @@ -263,6 +266,71 @@ return ty; } + /// Does \p component has non deferred lower bounds that are not compile time + /// constant 1. + static bool componentHasNonDefaultLowerBounds( + const Fortran::semantics::Symbol &component) { + if (const auto *objDetails = + component.detailsIf()) + for (const Fortran::semantics::ShapeSpec &bounds : objDetails->shape()) + if (auto lb = bounds.lbound().GetExplicit()) + if (auto constant = Fortran::evaluate::ToInt64(*lb)) + if (!constant || *constant != 1) + return true; + return false; + } + + mlir::Type genDerivedType(const Fortran::semantics::DerivedTypeSpec &tySpec) { + std::vector> ps; + std::vector> cs; + const Fortran::semantics::Symbol &typeSymbol = tySpec.typeSymbol(); + if (mlir::Type ty = getTypeIfDerivedAlreadyInConstruction(typeSymbol)) + return ty; + auto rec = fir::RecordType::get(context, + Fortran::lower::mangle::mangleName(tySpec)); + // Maintain the stack of types for recursive references. + derivedTypeInConstruction.emplace_back(typeSymbol, rec); + + // Gather the record type fields. + // (1) The data components. + for (const auto &field : + Fortran::semantics::OrderedComponentIterator(tySpec)) { + // Lowering is assuming non deferred component lower bounds are always 1. + // Catch any situations where this is not true for now. + if (componentHasNonDefaultLowerBounds(field)) + TODO(converter.genLocation(field.name()), + "lowering derived type components with non default lower bounds"); + if (IsProcName(field)) + TODO(converter.genLocation(field.name()), "procedure components"); + mlir::Type ty = genSymbolType(field); + // Do not add the parent component (component of the parents are + // added and should be sufficient, the parent component would + // duplicate the fields). + if (field.test(Fortran::semantics::Symbol::Flag::ParentComp)) + continue; + cs.emplace_back(field.name().ToString(), ty); + } + + // (2) The LEN type parameters. + for (const auto ¶m : + Fortran::semantics::OrderParameterDeclarations(typeSymbol)) + if (param->get().attr() == + Fortran::common::TypeParamAttr::Len) + ps.emplace_back(param->name().ToString(), genSymbolType(*param)); + + rec.finalize(ps, cs); + popDerivedTypeInConstruction(); + + if (!ps.empty()) { + // This type is a PDT (parametric derived type). Create the functions to + // use for allocation, dereferencing, and address arithmetic here. + TODO(converter.genLocation(typeSymbol.name()), + "parametrized derived types lowering"); + } + LLVM_DEBUG(llvm::dbgs() << "derived type: " << rec << '\n'); + return rec; + } + // To get the character length from a symbol, make an fold a designator for // the symbol to cover the case where the symbol is an assumed length named // constant and its length comes from its init expression length. @@ -326,7 +394,27 @@ return genSymbolType(var.getSymbol(), var.isHeapAlloc(), var.isPointer()); } -private: + /// Derived type can be recursive. That is, pointer components of a derived + /// type `t` have type `t`. This helper returns `t` if it is already being + /// lowered to avoid infinite loops. + mlir::Type getTypeIfDerivedAlreadyInConstruction( + const Fortran::lower::SymbolRef derivedSym) const { + for (const auto &[sym, type] : derivedTypeInConstruction) + if (sym == derivedSym) + return type; + return {}; + } + + void popDerivedTypeInConstruction() { + assert(!derivedTypeInConstruction.empty()); + derivedTypeInConstruction.pop_back(); + } + + /// Stack derived type being processed to avoid infinite loops in case of + /// recursive derived types. The depth of derived types is expected to be + /// shallow (<10), so a SmallVector is sufficient. + llvm::SmallVector> + derivedTypeInConstruction; Fortran::lower::AbstractConverter &converter; mlir::MLIRContext *context; }; @@ -340,6 +428,12 @@ return genFIRType(context, tc, kind, params); } +mlir::Type Fortran::lower::translateDerivedTypeToFIRType( + Fortran::lower::AbstractConverter &converter, + const Fortran::semantics::DerivedTypeSpec &tySpec) { + return TypeBuilder{converter}.genDerivedType(tySpec); +} + mlir::Type Fortran::lower::translateSomeExprToFIRType( Fortran::lower::AbstractConverter &converter, const SomeExpr &expr) { return TypeBuilder{converter}.genExprType(expr); diff --git a/flang/test/Lower/derived-types.f90 b/flang/test/Lower/derived-types.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Lower/derived-types.f90 @@ -0,0 +1,195 @@ +! Test basic parts of derived type entities lowering +! RUN: bbc -emit-fir %s -o - | FileCheck %s + +! Note: only testing non parametrized derived type here. + +module d + type r + real :: x + end type + type r2 + real :: x_array(10, 20) + end type + type c + character(10) :: ch + end type + type c2 + character(10) :: ch_array(20, 30) + end type + contains + + ! ----------------------------------------------------------------------------- + ! Test simple derived type symbol lowering + ! ----------------------------------------------------------------------------- + + ! CHECK-LABEL: func @_QMdPderived_dummy( + ! CHECK-SAME: %{{.*}}: !fir.ref>{{.*}}, %{{.*}}: !fir.ref>}>>{{.*}}) { + subroutine derived_dummy(some_r, some_c2) + type(r) :: some_r + type(c2) :: some_c2 + end subroutine + + ! CHECK-LABEL: func @_QMdPlocal_derived( + subroutine local_derived() + ! CHECK-DAG: fir.alloca !fir.type<_QMdTc2{ch_array:!fir.array<20x30x!fir.char<1,10>>}> + ! CHECK-DAG: fir.alloca !fir.type<_QMdTr{x:f32}> + type(r) :: some_r + type(c2) :: some_c2 + end subroutine + + ! CHECK-LABEL: func @_QMdPsaved_derived( + subroutine saved_derived() + ! CHECK-DAG: fir.address_of(@_QMdFsaved_derivedEsome_c2) : !fir.ref>}>> + ! CHECK-DAG: fir.address_of(@_QMdFsaved_derivedEsome_r) : !fir.ref> + type(r), save :: some_r + type(c2), save :: some_c2 + call use_symbols(some_r, some_c2) + end subroutine + + + ! ----------------------------------------------------------------------------- + ! Test simple derived type references + ! ----------------------------------------------------------------------------- + + ! CHECK-LABEL: func @_QMdPscalar_numeric_ref( + subroutine scalar_numeric_ref() + ! CHECK: %[[alloc:.*]] = fir.alloca !fir.type<_QMdTr{x:f32}> + type(r) :: some_r + ! CHECK: %[[field:.*]] = fir.field_index x, !fir.type<_QMdTr{x:f32}> + ! CHECK: fir.coordinate_of %[[alloc]], %[[field]] : (!fir.ref>, !fir.field) -> !fir.ref + call real_bar(some_r%x) + end subroutine + + ! CHECK-LABEL: func @_QMdPscalar_character_ref( + subroutine scalar_character_ref() + ! CHECK: %[[alloc:.*]] = fir.alloca !fir.type<_QMdTc{ch:!fir.char<1,10>}> + type(c) :: some_c + ! CHECK: %[[field:.*]] = fir.field_index ch, !fir.type<_QMdTc{ch:!fir.char<1,10>}> + ! CHECK: %[[coor:.*]] = fir.coordinate_of %[[alloc]], %[[field]] : (!fir.ref}>>, !fir.field) -> !fir.ref> + ! CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index + ! CHECK-DAG: %[[conv:.*]] = fir.convert %[[coor]] : (!fir.ref>) -> !fir.ref> + ! CHECK: fir.emboxchar %[[conv]], %c10 : (!fir.ref>, index) -> !fir.boxchar<1> + call char_bar(some_c%ch) + end subroutine + + ! FIXME: coordinate of generated for derived%array_comp(i) are not zero based as they + ! should be. + + ! CHECK-LABEL: func @_QMdParray_comp_elt_ref( + subroutine array_comp_elt_ref() + type(r2) :: some_r2 + ! CHECK: %[[alloc:.*]] = fir.alloca !fir.type<_QMdTr2{x_array:!fir.array<10x20xf32>}> + ! CHECK: %[[field:.*]] = fir.field_index x_array, !fir.type<_QMdTr2{x_array:!fir.array<10x20xf32>}> + ! CHECK: %[[coor:.*]] = fir.coordinate_of %[[alloc]], %[[field]] : (!fir.ref}>>, !fir.field) -> !fir.ref> + ! CHECK-DAG: %[[index1:.*]] = arith.subi %c5{{.*}}, %c1{{.*}} : i64 + ! CHECK-DAG: %[[index2:.*]] = arith.subi %c6{{.*}}, %c1{{.*}} : i64 + ! CHECK: fir.coordinate_of %[[coor]], %[[index1]], %[[index2]] : (!fir.ref>, i64, i64) -> !fir.ref + call real_bar(some_r2%x_array(5, 6)) + end subroutine + + + ! CHECK-LABEL: func @_QMdPchar_array_comp_elt_ref( + subroutine char_array_comp_elt_ref() + type(c2) :: some_c2 + ! CHECK: %[[coor:.*]] = fir.coordinate_of %{{.*}}, %{{.*}} : (!fir.ref>}>>, !fir.field) -> !fir.ref>> + ! CHECK-DAG: %[[index1:.*]] = arith.subi %c5{{.*}}, %c1{{.*}} : i64 + ! CHECK-DAG: %[[index2:.*]] = arith.subi %c6{{.*}}, %c1{{.*}} : i64 + ! CHECK: fir.coordinate_of %[[coor]], %[[index1]], %[[index2]] : (!fir.ref>>, i64, i64) -> !fir.ref> + ! CHECK: fir.emboxchar %{{.*}}, %c10 : (!fir.ref>, index) -> !fir.boxchar<1> + call char_bar(some_c2%ch_array(5, 6)) + end subroutine + + ! CHECK: @_QMdParray_elt_comp_ref + subroutine array_elt_comp_ref() + type(r) :: some_r_array(100) + ! CHECK: %[[alloca:.*]] = fir.alloca !fir.array<100x!fir.type<_QMdTr{x:f32}>> + ! CHECK: %[[index:.*]] = arith.subi %c5{{.*}}, %c1{{.*}} : i64 + ! CHECK: %[[elt:.*]] = fir.coordinate_of %[[alloca]], %[[index]] : (!fir.ref>>, i64) -> !fir.ref> + ! CHECK: %[[field:.*]] = fir.field_index x, !fir.type<_QMdTr{x:f32}> + ! CHECK: fir.coordinate_of %[[elt]], %[[field]] : (!fir.ref>, !fir.field) -> !fir.ref + call real_bar(some_r_array(5)%x) + end subroutine + + ! CHECK: @_QMdPchar_array_elt_comp_ref + subroutine char_array_elt_comp_ref() + type(c) :: some_c_array(100) + ! CHECK: fir.coordinate_of %{{.*}}, %{{.*}} : (!fir.ref}>>>, i64) -> !fir.ref}>> + ! CHECK: fir.coordinate_of %{{.*}}, %{{.*}} : (!fir.ref}>>, !fir.field) -> !fir.ref> + ! CHECK: fir.emboxchar %{{.*}}, %c10{{.*}} : (!fir.ref>, index) -> !fir.boxchar<1> + call char_bar(some_c_array(5)%ch) + end subroutine + + ! ----------------------------------------------------------------------------- + ! Test loading derived type components + ! ----------------------------------------------------------------------------- + + ! Most of the other tests only require lowering code to compute the address of + ! components. This one requires loading a component which tests other code paths + ! in lowering. + + ! CHECK-LABEL: func @_QMdPscalar_numeric_load( + ! CHECK-SAME: %[[arg0:.*]]: !fir.ref> + real function scalar_numeric_load(some_r) + type(r) :: some_r + ! CHECK: %[[field:.*]] = fir.field_index x, !fir.type<_QMdTr{x:f32}> + ! CHECK: %[[coor:.*]] = fir.coordinate_of %[[arg0]], %[[field]] : (!fir.ref>, !fir.field) -> !fir.ref + ! CHECK: fir.load %[[coor]] + scalar_numeric_load = some_r%x + end function + + ! ----------------------------------------------------------------------------- + ! Test returned derived types (no length parameters) + ! ----------------------------------------------------------------------------- + + ! CHECK-LABEL: func @_QMdPbar_return_derived() -> !fir.type<_QMdTr{x:f32}> + function bar_return_derived() + ! CHECK: %[[res:.*]] = fir.alloca !fir.type<_QMdTr{x:f32}> + type(r) :: bar_return_derived + ! CHECK: %[[resLoad:.*]] = fir.load %[[res]] : !fir.ref> + ! CHECK: return %[[resLoad]] : !fir.type<_QMdTr{x:f32}> + end function + + ! CHECK-LABEL: func @_QMdPcall_bar_return_derived( + subroutine call_bar_return_derived() + ! CHECK: %[[tmp:.*]] = fir.alloca !fir.type<_QMdTr{x:f32}> + ! CHECK: %[[call:.*]] = fir.call @_QMdPbar_return_derived() : () -> !fir.type<_QMdTr{x:f32}> + ! CHECK: fir.save_result %[[call]] to %[[tmp]] : !fir.type<_QMdTr{x:f32}>, !fir.ref> + ! CHECK: fir.call @_QPr_bar(%[[tmp]]) : (!fir.ref>) -> () + call r_bar(bar_return_derived()) + end subroutine + + end module + + ! ----------------------------------------------------------------------------- + ! Test derived type with pointer/allocatable components + ! ----------------------------------------------------------------------------- + + module d2 + type recursive_t + real :: x + type(recursive_t), pointer :: ptr + end type + contains + ! CHECK-LABEL: func @_QMd2Ptest_recursive_type( + ! CHECK-SAME: %{{.*}}: !fir.ref>>}>>{{.*}}) { + subroutine test_recursive_type(some_recursive) + type(recursive_t) :: some_recursive + end subroutine + end module + + ! ----------------------------------------------------------------------------- + ! Test global derived type symbol lowering + ! ----------------------------------------------------------------------------- + + module data_mod + use d + type(r) :: some_r + type(c2) :: some_c2 + end module + + ! Test globals + + ! CHECK-DAG: fir.global @_QMdata_modEsome_c2 : !fir.type<_QMdTc2{ch_array:!fir.array<20x30x!fir.char<1,10>>}> + ! CHECK-DAG: fir.global @_QMdata_modEsome_r : !fir.type<_QMdTr{x:f32}> + ! CHECK-DAG: fir.global internal @_QMdFsaved_derivedEsome_c2 : !fir.type<_QMdTc2{ch_array:!fir.array<20x30x!fir.char<1,10>>}> + ! CHECK-DAG: fir.global internal @_QMdFsaved_derivedEsome_r : !fir.type<_QMdTr{x:f32}>