Index: flang/include/flang/Lower/CallInterface.h =================================================================== --- flang/include/flang/Lower/CallInterface.h +++ flang/include/flang/Lower/CallInterface.h @@ -414,6 +414,10 @@ const Fortran::evaluate::ProcedureDesignator &proc, Fortran::lower::AbstractConverter &); +/// Translate the type of a derived type argument passed by value into the type used to pass it under the target ABI. +mlir::Type classifyDerivedValueArgumentType(Fortran::lower::AbstractConverter &, + mlir::Type); + } // namespace Fortran::lower #endif // FORTRAN_LOWER_FIRBUILDER_H Index: flang/lib/Lower/CallInterface.cpp =================================================================== --- flang/lib/Lower/CallInterface.cpp +++ flang/lib/Lower/CallInterface.cpp @@ -18,9 +18,11 @@ #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "flang/Optimizer/Support/FIRContext.h" #include "flang/Optimizer/Support/InternalNames.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" +#include "llvm/ADT/Triple.h" //===----------------------------------------------------------------------===// // BIND(C) mangling helpers @@ -886,7 +888,11 @@ if (isBindC) { passBy = PassEntityBy::Value; prop = Property::Value; - passType = type; + if (fir::isa_derived(type)) + passType = Fortran::lower::classifyDerivedValueArgumentType( + interface.converter, type); + else + passType = type; } else { passBy = PassEntityBy::BaseAddressValueAttribute; } @@ -1239,3 +1245,232 @@ return fir::factory::getCharacterProcedureTupleType(procType); return procType; } + +//===--------------------------------------------------------------------===// +// ABI helpers of a call or function for argument and return type conversion. +// This is for derived type passed by value (interoperable with C) for now. +// The derived type cannot have type parameters, TBP, pointer or allocatables. +// That is, the derived type has compile-time known size. 
+// Note: +// This must be consistent with ABI info in clang/lib/CodeGen/TargetInfo.cpp. All sizes and alignments computed by the helpers below are expressed in bits. NOTE(review): isAArch64HomogeneousAggregate only records the first base type and never compares later members against it, so mixed real kinds appear to be accepted -- confirm against the AAPCS64 rules. +//===--------------------------------------------------------------------===// + +static inline bool +isAArch64HomogeneousAggregateSmallEnough(std::uint64_t members) { + return members <= 4; +} + +static inline bool isAArch64HomogeneousAggregateBaseType(mlir::Type ty) { + return fir::isa_real(ty); +} + +static bool isAArch64HomogeneousAggregate(mlir::Type ty, mlir::Type &base, + std::uint64_t &members) { + if (ty.isa()) { + auto seqTy = ty.dyn_cast(); + assert(!seqTy.hasDynamicExtents() && + "unexpected dynamic size for interoperable derived type component"); + std::uint64_t nElements = 1; + for (unsigned i = 0; i < seqTy.getDimension(); ++i) + nElements *= seqTy.getShape()[i]; + if (nElements == 0) + return false; + if (!isAArch64HomogeneousAggregate(seqTy.getEleTy(), base, members)) + return false; + members *= nElements; + } else if (auto recTy = ty.dyn_cast_or_null()) { + members = 0; + for (int i = 0; i < (int)recTy.getTypeList().size(); ++i) { + mlir::Type fieldTy = recTy.getTypeList()[i].second; + while (fieldTy.isa()) { + auto seqTy = fieldTy.dyn_cast(); + assert( + !seqTy.hasDynamicExtents() && + "unexpected dynamic size for interoperable derived type component"); + std::uint64_t nElements = 1; + for (unsigned j = 0; j < seqTy.getDimension(); ++j) + nElements *= seqTy.getShape()[j]; + if (nElements == 0) + return false; + fieldTy = seqTy.getEleTy(); + } + + if (auto fieldRecTy = fieldTy.dyn_cast_or_null()) + if (fieldRecTy.getNumFields() == 0) + continue; + + std::uint64_t fieldMembers; + if (!isAArch64HomogeneousAggregate(recTy.getTypeList()[i].second, base, + fieldMembers)) + return false; + + members = members + fieldMembers; + } + if (!base) + return false; + } else { + members = 1; + if (auto cplx = ty.dyn_cast_or_null()) { + members = 2; + ty = cplx.getElementType(); + } + + if (!isAArch64HomogeneousAggregateBaseType(ty)) + return false; + + if (!base) + base = 
ty; + } + return members > 0 && isAArch64HomogeneousAggregateSmallEnough(members); +} + +static std::uint64_t getScalarSize(mlir::Type scalarTy, std::uint64_t &elems) { + std::uint64_t scalarSize = 0; + if (fir::isa_integer(scalarTy) || fir::isa_real(scalarTy)) { + scalarSize = scalarTy.getIntOrFloatBitWidth(); + } else if (auto cplx = scalarTy.dyn_cast_or_null()) { + scalarSize = 8 * cplx.getFKind(); + elems = 2 * elems; + } else if (auto lgTy = scalarTy.dyn_cast_or_null()) { + assert(lgTy.getFKind() == 1); + scalarSize = 8; + } else if (auto ct = scalarTy.dyn_cast_or_null()) { + // FIXME: gfortran does not support length more than 1, ifort supports. A CHARACTER(KIND=1, LEN=1) occupies one byte, i.e. 8 bits, like the LOGICAL(KIND=1) case above. + assert(ct.getLen() == 1 && ct.getFKind() == 1); + scalarSize = 8; + } + return scalarSize; +} + +// FIXME: Replace this with a known target-dependent function. Returns the size in bits of derived type `ty`, padded to its alignment, and raises `align` to the largest component alignment (in bits). NOTE(review): divides by `fieldAlignSize`/`align`, which remain 0 when no component yields a non-zero scalar size -- confirm callers rule that out. +static std::uint64_t getAArch64DerivedTypeSize(mlir::Type ty, + std::uint64_t &align) { + assert(fir::isa_derived(ty)); + auto recTy = ty.dyn_cast(); + std::uint64_t size = 0; + std::uint64_t lastFieldSize = 0; + std::uint64_t aggregateSize = 0; + + for (int i = 0; i < (int)recTy.getTypeList().size(); ++i) { + mlir::Type fieldTy = recTy.getTypeList()[i].second; + std::uint64_t nElements = 1; + std::uint64_t fieldAlignSize = 0; + if (fir::isa_derived(fieldTy)) { + std::uint64_t dtSize = getAArch64DerivedTypeSize(fieldTy, fieldAlignSize); + nElements = dtSize / fieldAlignSize; + } else if (auto seqTy = fieldTy.dyn_cast_or_null()) { + for (unsigned i = 0; i < seqTy.getDimension(); ++i) + nElements *= seqTy.getShape()[i]; + if (nElements == 0) { + fieldAlignSize = 0; + continue; + } else { + mlir::Type seqEleTy = seqTy.getEleTy(); + if (fir::isa_derived(seqEleTy)) { + std::uint64_t dtSize = + getAArch64DerivedTypeSize(seqEleTy, fieldAlignSize); + nElements = dtSize / fieldAlignSize * nElements; + } else { + fieldAlignSize = getScalarSize(seqEleTy, nElements); + } + } + } else { + fieldAlignSize = getScalarSize(fieldTy, nElements); + } + if (lastFieldSize < 
fieldAlignSize) + aggregateSize = (aggregateSize + fieldAlignSize - 1) / fieldAlignSize * + fieldAlignSize; + if (fieldAlignSize > align) { + align = fieldAlignSize; + aggregateSize = (aggregateSize + nElements * fieldAlignSize + align - 1) / + align * align; + } else { + aggregateSize += nElements * fieldAlignSize; + } + size = (aggregateSize + align - 1) / align * align; + } + return size; +} + +// FIXME: Replace this with a known target-dependent function. Returns the largest component alignment (in bits) of derived type `ty`, looking through arrays and nested derived types. +static std::uint64_t getAArch64UnjustedTypeAlign(mlir::Type ty) { + assert(fir::isa_derived(ty)); + auto recTy = ty.dyn_cast(); + std::uint64_t align = 0; + std::uint64_t fieldAlignSize = 0; + for (int i = 0; i < (int)recTy.getTypeList().size(); ++i) { + mlir::Type fieldTy = recTy.getTypeList()[i].second; + std::uint64_t nElements = 1; + if (fir::isa_derived(fieldTy)) { + (void)getAArch64DerivedTypeSize(fieldTy, fieldAlignSize); + } else if (auto seqTy = fieldTy.dyn_cast_or_null()) { + for (unsigned i = 0; i < seqTy.getDimension(); ++i) + nElements *= seqTy.getShape()[i]; + if (nElements == 0) { + fieldAlignSize = 0; + } else { + mlir::Type seqEleTy = seqTy.getEleTy(); + if (fir::isa_derived(seqEleTy)) + (void)getAArch64DerivedTypeSize(seqEleTy, fieldAlignSize); + else + fieldAlignSize = getScalarSize(seqEleTy, nElements); + } + } else { + fieldAlignSize = getScalarSize(fieldTy, nElements); + } + align = std::max(align, fieldAlignSize); + } + return align; +} + +static mlir::Type +classifyAArch64ArgumentType(Fortran::lower::AbstractConverter &converter, + mlir::Type ty) { + assert(fir::isa_derived(ty)); + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + + mlir::Type base; + std::uint64_t members = 0; + if (isAArch64HomogeneousAggregate(ty, base, members)) { + fir::SequenceType::Shape shape(1, members); + return fir::SequenceType::get(shape, base); + } + + std::uint64_t align = 0; + std::uint64_t size = getAArch64DerivedTypeSize(ty, align); + if (size <= 128) { + std::uint64_t align = 
getAArch64UnjustedTypeAlign(ty); + // FIXME: Use max(align, pointerwidth()) for non-AAPCS. + align = align < 128 ? 64 : 128; + size = (size + align - 1) / align * align; + + mlir::Type baseTy = builder.getIntegerType(align); + if (size == align) + return baseTy; + fir::SequenceType::Shape shape(1, size / align); + return fir::SequenceType::get(shape, baseTy); + } + + return fir::ReferenceType::get(ty); +} + +mlir::Type Fortran::lower::classifyDerivedValueArgumentType( + Fortran::lower::AbstractConverter &converter, mlir::Type ty) { + assert(fir::isa_derived(ty)); + mlir::ModuleOp module = converter.getModuleOp(); + llvm::Triple triple = fir::getTargetTriple(module); + mlir::Location loc = converter.getCurrentLocation(); + + auto recTy = ty.dyn_cast(); + if (recTy.getNumFields() == 0) + TODO(loc, "derived type without component is not interoperable with C"); + + if (triple.isAArch64()) { + if (triple.isOSWindows()) + TODO(loc, "derived type passed by value on aarch64 windows"); + return classifyAArch64ArgumentType(converter, ty); + } + + TODO(loc, "derived type passed by value on non-aarch64"); + return {}; +} Index: flang/lib/Lower/ConvertExpr.cpp =================================================================== --- flang/lib/Lower/ConvertExpr.cpp +++ flang/lib/Lower/ConvertExpr.cpp @@ -2451,6 +2451,18 @@ return res; } + static mlir::Value + genDerivedValueArg(Fortran::lower::AbstractConverter &converter, + mlir::Value val, mlir::Type ty) { + if (ty.isa()) + return val; + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + mlir::Location loc = converter.getCurrentLocation(); + mlir::Value cast = + builder.createConvert(loc, fir::ReferenceType::get(ty), val); + return builder.create(loc, cast); + } + /// Given a call site for which the arguments were already lowered, generate /// the call and return the result. This function deals with explicit result /// allocation and lowering if needed. 
It also deals with passing the host @@ -2659,14 +2671,17 @@ cast = builder.create(loc, boxProcTy, fst); } } else { - if (fir::isa_derived(snd)) { - // FIXME: This seems like a serious bug elsewhere in lowering. Paper - // over the problem for now. - TODO(loc, "derived type argument passed by value"); + mlir::Type fromTy = fst.getType(); + if (fromTy.isa() && + fir::isa_derived(fir::unwrapRefType(fromTy)) && + ((snd.isa() && + fir::isa_derived(fir::unwrapRefType(snd))) || + fir::isa_integer(snd) || snd.isa())) { + cast = genDerivedValueArg(converter, fst, snd); + } else { + cast = builder.convertWithSemantics(loc, snd, fst, + callingImplicitInterface); } - assert(!fir::isa_derived(snd)); - cast = builder.convertWithSemantics(loc, snd, fst, - callingImplicitInterface); } operands.push_back(cast); } Index: flang/test/Lower/structure-pass-by-value.f90 =================================================================== --- /dev/null +++ flang/test/Lower/structure-pass-by-value.f90 @@ -0,0 +1,94 @@ +! Test lowering of derived type passed by value. +! RUN: flang-new -fc1 -emit-fir --target=aarch64-unknown-linux-gnu %s -o - | FileCheck %s +! REQUIRES: shell + +! CHECK-LABEL: func.func @_QPderived_pass_by_value1() { +! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.type<_QFderived_pass_by_value1Tt1{i:f32,j:i32}> {bindc_name = "my_t1", uniq_name = "_QFderived_pass_by_value1Emy_t1"} +! CHECK: %[[VAL_7:.*]] = fir.alloca !fir.type<_QFderived_pass_by_value1Tt2{t1_obj1:!fir.type<_QFderived_pass_by_value1Tt1{i:f32,j:i32}>,x:i64}> {bindc_name = "my_t2", uniq_name = "_QFderived_pass_by_value1Emy_t2"} +! CHECK: %[[VAL_14:.*]] = fir.alloca !fir.type<_QFderived_pass_by_value1Tt3{t1_obj2:!fir.type<_QFderived_pass_by_value1Tt1{i:f32,j:i32}>,x:!fir.array<2xi64>}> {bindc_name = "my_t3", uniq_name = "_QFderived_pass_by_value1Emy_t3"} +! CHECK: %[[VAL_21:.*]] = fir.convert %[[VAL_0]] : (!fir.ref>) -> !fir.ref +! CHECK: %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref +! 
CHECK: %[[VAL_23:.*]] = fir.convert %[[VAL_7]] : (!fir.ref,x:i64}>>) -> !fir.ref> +! CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref> +! CHECK: fir.call @c_func_(%[[VAL_22]], %[[VAL_24]], %[[VAL_14]]) : (i64, !fir.array<2xi64>, !fir.ref,x:!fir.array<2xi64>}>>) -> () +! CHECK: return +! CHECK: } + +subroutine derived_pass_by_value1() ! my_t1 is expected to be passed as i64, my_t2 as a 2 x i64 array, and my_t3 by reference + type, bind(c) :: t1 + real :: i = -1. + integer :: j = -2 + end type t1 + type, bind(c) :: t2 + type(t1) :: t1_obj1 + integer(8) :: x = 1 + end type t2 + type, bind(c) :: t3 + type(t1) :: t1_obj2 + integer(8) :: x(2) = [1, 2] + end type t3 + type(t1) :: my_t1 + type(t2) :: my_t2 + type(t3) :: my_t3 + + INTERFACE + subroutine c_func(c_t1, c_t2, c_t3) BIND(C, NAME='c_func_') + import :: t1, t2, t3 + type(t1), value :: c_t1 + type(t2), value :: c_t2 + type(t3), value :: c_t3 + END + END INTERFACE + + call c_func(my_t1, my_t2, my_t3) +end + +! CHECK-LABEL: func.func @_QPderived_pass_by_value2() { +! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.type<_QFderived_pass_by_value2Tt1{i:!fir.array<4xf64>}> {bindc_name = "my_t1", uniq_name = "_QFderived_pass_by_value2Emy_t1"} +! CHECK: %[[VAL_7:.*]] = fir.alloca !fir.type<_QFderived_pass_by_value2Tt2{x:!fir.array<4xi32>}> {bindc_name = "my_t2", uniq_name = "_QFderived_pass_by_value2Emy_t2"} +! CHECK: %[[VAL_14:.*]] = fir.alloca !fir.type<_QFderived_pass_by_value2Tt3{y:!fir.array<2xf32>,z:!fir.complex<4>}> {bindc_name = "my_t3", uniq_name = "_QFderived_pass_by_value2Emy_t3"} +! CHECK: %[[VAL_21:.*]] = fir.alloca !fir.type<_QFderived_pass_by_value2Tt4{w:!fir.array<5xf32>}> {bindc_name = "my_t4", uniq_name = "_QFderived_pass_by_value2Emy_t4"} +! CHECK: %[[VAL_28:.*]] = fir.convert %[[VAL_0]] : (!fir.ref}>>) -> !fir.ref> +! CHECK: %[[VAL_29:.*]] = fir.load %[[VAL_28]] : !fir.ref> +! CHECK: %[[VAL_30:.*]] = fir.convert %[[VAL_7]] : (!fir.ref}>>) -> !fir.ref> +! CHECK: %[[VAL_31:.*]] = fir.load %[[VAL_30]] : !fir.ref> +! 
CHECK: %[[VAL_32:.*]] = fir.convert %[[VAL_14]] : (!fir.ref,z:!fir.complex<4>}>>) -> !fir.ref> +! CHECK: %[[VAL_33:.*]] = fir.load %[[VAL_32]] : !fir.ref> +! CHECK: fir.call @c_func2_(%[[VAL_29]], %[[VAL_31]], %[[VAL_33]], %[[VAL_21]]) : (!fir.array<4xf64>, !fir.array<2xi64>, !fir.array<4xf32>, !fir.ref}>>) -> () +! CHECK: return +! CHECK: } + +subroutine derived_pass_by_value2() ! my_t1 is expected to be passed as a 4 x f64 array, my_t2 as a 2 x i64 array, my_t3 as a 4 x f32 array, and my_t4 by reference + type, bind(c) :: t1 + real(8) :: i(4) = [-1., -2., -3., -4.] + end type t1 + type, bind(c) :: t2 + integer :: x(4) = [1, 2, 3, 4] + end type t2 + type, bind(c) :: t3 + real :: y(2) = [1., 2.] + complex :: z = (3., 4.) + end type t3 + type, bind(c) :: t4 + real :: w(5) = [1., 2., 3., 4., 5.] + end type t4 + type(t1) :: my_t1 + type(t2) :: my_t2 + type(t3) :: my_t3 + type(t4) :: my_t4 + + INTERFACE + subroutine c_func2(c_t1, c_t2, c_t3, c_t4) BIND(C, NAME='c_func2_') + import :: t1, t2, t3, t4 + type(t1), value :: c_t1 + type(t2), value :: c_t2 + type(t3), value :: c_t3 + type(t4), value :: c_t4 + END + END INTERFACE + + call c_func2(my_t1, my_t2, my_t3, my_t4) +end + +! CHECK: func.func private @c_func_(i64, !fir.array<2xi64>, !fir.ref,x:!fir.array<2xi64>}>>) attributes {fir.bindc_name = "c_func_"} +! CHECK: func.func private @c_func2_(!fir.array<4xf64>, !fir.array<2xi64>, !fir.array<4xf32>, !fir.ref}>>) attributes {fir.bindc_name = "c_func2_"}