diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -261,6 +261,24 @@ } } +/// Unwrap the referential and sequential outer types (if any) of the element +/// type of `boxTy`; return the result if it is a fir::RecordType, else null. +inline fir::RecordType unwrapIfDerived(fir::BaseBoxType boxTy) { + return fir::unwrapSequenceType(fir::unwrapRefType(boxTy.getEleTy())) + .template dyn_cast(); +} + +/// Return true iff `boxTy` wraps a fir::RecordType with any length parameter. +inline bool isDerivedTypeWithLenParams(fir::BaseBoxType boxTy) { + auto recTy = unwrapIfDerived(boxTy); + return recTy && recTy.getNumLenParams() > 0; +} + +/// Return true iff `boxTy` wraps a fir::RecordType (see unwrapIfDerived). +inline bool isDerivedType(fir::BaseBoxType boxTy) { + return static_cast(unwrapIfDerived(boxTy)); +} + #ifndef NDEBUG // !fir.ptr and !fir.heap where X is !fir.ptr, !fir.heap, or !fir.ref // is undefined and disallowed. @@ -300,6 +318,13 @@ /// value. bool isUnlimitedPolymorphicType(mlir::Type ty); +/// Return true iff `boxTy` wraps a record type or an unlimited polymorphic +/// entity. Polymorphic entities with an intrinsic type spec have no addendum. +inline bool boxHasAddendum(fir::BaseBoxType boxTy) { + return static_cast(unwrapIfDerived(boxTy)) || + fir::isUnlimitedPolymorphicType(boxTy); +} + /// Return the inner type of the given type. 
mlir::Type unwrapInnerType(mlir::Type ty); diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h --- a/flang/include/flang/Optimizer/Support/Utils.h +++ b/flang/include/flang/Optimizer/Support/Utils.h @@ -15,8 +15,12 @@ #include "flang/Common/default-kinds.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" namespace fir { /// Return the integer value of a arith::ConstantOp. @@ -24,6 +28,11 @@ return cop.getValue().cast().getValue().getSExtValue(); } +// Reconstruct dispatch binding tables (method name -> index) from `mod`. +using BindingTable = llvm::DenseMap; +using BindingTables = llvm::DenseMap; +void buildBindingTables(BindingTables &, mlir::ModuleOp mod); + // Translate front-end KINDs for use in the IR and code gen. inline std::vector fromDefaultKinds(const Fortran::common::IntrinsicTypeDefaultKinds &defKinds) { diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -16,8 +16,10 @@ #include "flang/ISO_Fortran_binding.h" #include "flang/Optimizer/Dialect/FIRAttr.h" #include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Support/InternalNames.h" #include "flang/Optimizer/Support/TypeCode.h" +#include "flang/Optimizer/Support/Utils.h" #include "flang/Semantics/runtime-type-info.h" #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" @@ -50,9 +52,6 @@ // fir::LLVMTypeConverter for converting to LLVM IR dialect types. 
#include "TypeConverter.h" -using BindingTable = llvm::DenseMap; -using BindingTables = llvm::DenseMap; - // TODO: This should really be recovered from the specified target. static constexpr unsigned defaultAlign = 8; @@ -106,7 +105,7 @@ public: explicit FIROpConversion(fir::LLVMTypeConverter &lowering, const fir::FIRToLLVMPassOptions &options, - const BindingTables &bindingTables) + const fir::BindingTables &bindingTables) : mlir::ConvertOpToLLVMPattern(lowering), options(options), bindingTables(bindingTables) {} @@ -359,7 +358,7 @@ } const fir::FIRToLLVMPassOptions &options; - const BindingTables &bindingTables; + const fir::BindingTables &bindingTables; }; /// FIR conversion pattern template @@ -993,7 +992,7 @@ << "cannot find binding table for " << recordType.getName(); // Lookup for the binding. - const BindingTable &bindingTable = bindingsIter->second; + const fir::BindingTable &bindingTable = bindingsIter->second; auto bindingIter = bindingTable.find(dispatch.getMethod()); if (bindingIter == bindingTable.end()) return emitError(loc) @@ -1336,22 +1335,6 @@ return CFI_attribute_other; } - static fir::RecordType unwrapIfDerived(fir::BaseBoxType boxTy) { - return fir::unwrapSequenceType(fir::dyn_cast_ptrOrBoxEleTy(boxTy)) - .template dyn_cast(); - } - static bool isDerivedTypeWithLenParams(fir::BaseBoxType boxTy) { - auto recTy = unwrapIfDerived(boxTy); - return recTy && recTy.getNumLenParams() > 0; - } - static bool isDerivedType(fir::BaseBoxType boxTy) { - return static_cast(unwrapIfDerived(boxTy)); - } - static bool hasAddendum(fir::BaseBoxType boxTy) { - return static_cast(unwrapIfDerived(boxTy)) || - fir::isUnlimitedPolymorphicType(boxTy); - } - // Get the element size and CFI type code of the boxed value. 
std::tuple getSizeAndTypeCode( mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, @@ -1571,7 +1554,7 @@ descriptor = insertField(rewriter, loc, descriptor, {kAttributePosInBox}, this->genI32Constant(loc, rewriter, getCFIAttr(boxTy))); - const bool hasAddendum = isDerivedType(boxTy) || isUnlimitedPolymorphic; + const bool hasAddendum = fir::boxHasAddendum(boxTy); descriptor = insertField(rewriter, loc, descriptor, {kF18AddendumPosInBox}, this->genI32Constant(loc, rewriter, hasAddendum ? 1 : 0)); @@ -1591,8 +1574,8 @@ loc, ::getVoidPtrType(mod.getContext())); } } else { - typeDesc = - getTypeDescriptor(mod, rewriter, loc, unwrapIfDerived(boxTy)); + typeDesc = getTypeDescriptor(mod, rewriter, loc, + fir::unwrapIfDerived(boxTy)); } } if (typeDesc) @@ -1674,7 +1657,7 @@ // TODO: For initial box that are unlimited polymorphic entities, this // code must be made conditional because unlimited polymorphic entities // with intrinsic type spec does not have addendum. - if (hasAddendum(inputBoxTy)) + if (fir::boxHasAddendum(inputBoxTy)) typeDesc = this->loadTypeDescAddress(loc, box.getBox().getType(), loweredBox, rewriter); } @@ -1826,7 +1809,7 @@ /*rank=*/0, /*lenParams=*/operands.drop_front(1), sourceBox, sourceBoxType); dest = insertBaseAddress(rewriter, embox.getLoc(), dest, operands[0]); - if (isDerivedTypeWithLenParams(boxTy)) { + if (fir::isDerivedTypeWithLenParams(boxTy)) { TODO(embox.getLoc(), "fir.embox codegen of derived with length parameters"); return mlir::failure(); @@ -2010,7 +1993,7 @@ fieldIndices, substringOffset); } dest = insertBaseAddress(rewriter, loc, dest, base); - if (isDerivedTypeWithLenParams(boxTy)) + if (fir::isDerivedTypeWithLenParams(boxTy)) TODO(loc, "fir.embox codegen of derived with length parameters"); mlir::Value result = @@ -3670,7 +3653,7 @@ struct MustBeDeadConversion : public FIROpConversion { explicit MustBeDeadConversion(fir::LLVMTypeConverter &lowering, const fir::FIRToLLVMPassOptions &options, - const BindingTables 
&bindingTables) + const fir::BindingTables &bindingTables) : FIROpConversion(lowering, options, bindingTables) {} using OpAdaptor = typename FromOp::Adaptor; @@ -3781,24 +3764,8 @@ if (mlir::failed(runPipeline(mathConvertionPM, mod))) return signalPassFailure(); - // Reconstruct binding tables for dynamic dispatch. The binding tables - // are defined in FIR from lowering as fir.dispatch_table operation. - // Go through each binding tables and store the procedure name - // and binding index for later use by the fir.dispatch conversion pattern. - BindingTables bindingTables; - for (auto dispatchTableOp : mod.getOps()) { - unsigned bindingIdx = 0; - BindingTable bindings; - if (dispatchTableOp.getRegion().empty()) { - bindingTables[dispatchTableOp.getSymName()] = bindings; - continue; - } - for (auto dtEntry : dispatchTableOp.getBlock().getOps()) { - bindings[dtEntry.getMethod()] = bindingIdx; - ++bindingIdx; - } - bindingTables[dispatchTableOp.getSymName()] = bindings; - } + fir::BindingTables bindingTables; + fir::buildBindingTables(bindingTables, mod); auto *context = getModule().getContext(); fir::LLVMTypeConverter typeConverter{getModule(), diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -219,12 +219,8 @@ return llvm::TypeSwitch(t) .Case([](auto p) { return p.getEleTy(); }) - .Case([](auto p) { - auto eleTy = p.getEleTy(); - if (auto ty = fir::dyn_cast_ptrEleTy(eleTy)) - return ty; - return eleTy; - }) + .Case( + [](auto p) { return unwrapRefType(p.getEleTy()); }) .Default([](mlir::Type) { return mlir::Type{}; }); } diff --git a/flang/lib/Optimizer/Support/CMakeLists.txt b/flang/lib/Optimizer/Support/CMakeLists.txt --- a/flang/lib/Optimizer/Support/CMakeLists.txt +++ b/flang/lib/Optimizer/Support/CMakeLists.txt @@ -5,6 +5,7 @@ InitFIR.cpp InternalNames.cpp KindMapping.cpp + Utils.cpp DEPENDS FIROpsIncGen diff --git 
a/flang/lib/Optimizer/Support/Utils.cpp b/flang/lib/Optimizer/Support/Utils.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/Support/Utils.cpp @@ -0,0 +1,36 @@ +//===-- Utils.cpp ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Support/Utils.h" +#include "flang/Optimizer/Dialect/FIROps.h" + +namespace fir { +void buildBindingTables(BindingTables &bindingTables, mlir::ModuleOp mod) { + // Note: `bindingTables` is not cleared; same-named entries are overwritten. + // The binding tables are defined in FIR by lowering as fir.dispatch_table + // operations. Go through each binding table and store the procedure name and + // binding index for later use by the fir.dispatch conversion pattern. + for (auto dispatchTableOp : mod.getOps()) { + unsigned bindingIdx = 0; + BindingTable bindings; + if (dispatchTableOp.getRegion().empty()) { + bindingTables[dispatchTableOp.getSymName()] = bindings; + continue; + } + for (auto dtEntry : dispatchTableOp.getBlock().getOps()) { + bindings[dtEntry.getMethod()] = bindingIdx; + ++bindingIdx; + } + bindingTables[dispatchTableOp.getSymName()] = bindings; + } +} +} // namespace fir