diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -204,6 +204,12 @@
   virtual void registerRuntimeTypeInfo(mlir::Location loc,
                                        SymbolRef typeInfoSym) = 0;
 
+  /// Register \p typeSpec so that a fir.dispatch_table operation listing its
+  /// type-bound procedures is created for it.
+  virtual void registerDispatchTableInfo(
+      mlir::Location loc,
+      const Fortran::semantics::DerivedTypeSpec *typeSpec) = 0;
+
   //===--------------------------------------------------------------------===//
   // Locations
   //===--------------------------------------------------------------------===//
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -212,6 +212,13 @@
                         bodyBuilder, linkage);
   }
 
+  /// Create a fir::DispatchTableOp named \p name at the end of the module, or
+  /// return the existing one with that name. \p parentName is the mangled name
+  /// of the parent type, or empty when the type does not extend another type.
+  fir::DispatchTableOp createDispatchTableOp(mlir::Location loc,
+                                             llvm::StringRef name,
+                                             llvm::StringRef parentName);
+
   /// Convert a StringRef string into a fir::StringLitOp.
   fir::StringLitOp createStringLitOp(mlir::Location loc,
                                      llvm::StringRef string);
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2804,26 +2804,32 @@
     ```
   }];
 
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    OptionalAttr<StrAttr>:$parent
+  );
+
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 
-  let regions = (region SizedRegion<1>:$region);
+  let regions = (region AnyRegion:$region);
 
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "llvm::StringRef":$name, "mlir::Type":$type,
-      CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs),
-    [{
-      $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
-        $_builder.getStringAttr(name));
-      $_state.addAttributes(attrs);
-    }]>
+      "llvm::StringRef":$parent,
+      CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs)>
   ];
 
   let extraClassDeclaration = [{
     /// Append a dispatch table entry to the table.
     void appendTableEntry(mlir::Operation *op);
 
+    /// Name of the optional attribute holding the mangled parent type name.
+    static constexpr llvm::StringRef getParentAttrNameStr() { return "parent"; }
+    /// Keyword introducing the parent in the custom assembly format.
+    static constexpr llvm::StringRef getExtendsKeyword() { return "extends"; }
+
     mlir::Block &getBlock() {
       return getRegion().front();
     }
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -45,11 +45,13 @@
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Runtime/iostat.h"
+#include "flang/Semantics/runtime-type-info.h"
 #include "flang/Semantics/tools.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -192,6 +194,91 @@
   llvm::SmallSetVector<Fortran::semantics::SymbolRef, 64> seen;
 };
 
+/// Collects the derived types met during lowering so that one
+/// fir.dispatch_table operation, listing the type-bound procedure bindings of
+/// the type, can later be created for each of them.
+class DispatchTableConverter {
+  struct DispatchTableInfo {
+    const Fortran::semantics::DerivedTypeSpec *typeSpec;
+    mlir::Location loc;
+    /// Mangled name of the type, also used as the dispatch table name.
+    std::string name;
+  };
+
+public:
+  /// Register \p typeSpec so that createDispatchTableOps() creates its
+  /// dispatch table. A type is registered at most once (keyed by its mangled
+  /// name) and types whose mangled name contains "__fortran" are ignored.
+  void registerTypeSpec(mlir::Location loc,
+                        const Fortran::semantics::DerivedTypeSpec *typeSpec) {
+    assert(typeSpec && "type spec is null");
+    std::string dtName = Fortran::lower::mangle::mangleName(*typeSpec);
+    // NOTE(review): presumably excludes the types of the compiler provided
+    // __fortran_* modules; confirm.
+    if (seen.contains(dtName) || dtName.find("__fortran") != std::string::npos)
+      return;
+    seen.insert(dtName);
+    registeredDispatchTableInfo.emplace_back(
+        DispatchTableInfo{typeSpec, loc, std::move(dtName)});
+  }
+
+  /// Create the fir.dispatch_table operations of all the registered types and
+  /// clear the registration list. The table of an extended type records the
+  /// mangled name of its parent type.
+  void createDispatchTableOps(Fortran::lower::AbstractConverter &converter) {
+    fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+    for (const DispatchTableInfo &info : registeredDispatchTableInfo) {
+      const Fortran::semantics::DerivedTypeSpec *parent =
+          Fortran::evaluate::GetParentTypeSpec(*info.typeSpec);
+      fir::DispatchTableOp dt = builder.createDispatchTableOp(
+          info.loc, info.name,
+          parent ? Fortran::lower::mangle::mangleName(*parent) : "");
+      // createDispatchTableOp returns the existing table when one with the
+      // same name is already in the module: never give it a second body.
+      if (!dt.getRegion().empty())
+        continue;
+
+      // DerivedTypeSpec::scope() may be null: fall back to the scope of the
+      // type symbol, as done when the derived type is converted.
+      const Fortran::semantics::Scope *scope =
+          info.typeSpec->scope() ? info.typeSpec->scope()
+                                 : info.typeSpec->typeSymbol().scope();
+      if (!scope)
+        continue;
+      std::vector<const Fortran::semantics::Symbol *> bindings =
+          Fortran::semantics::CollectBindings(*scope);
+      if (bindings.empty())
+        continue;
+
+      // Restores the builder insertion point at the end of the iteration.
+      mlir::OpBuilder::InsertionGuard guard(builder);
+      builder.createBlock(&dt.getRegion());
+      for (const Fortran::semantics::Symbol *binding : bindings) {
+        const auto *details =
+            binding->detailsIf<Fortran::semantics::ProcBindingDetails>();
+        assert(details && "expected a procedure binding");
+        std::string bindingName =
+            Fortran::lower::mangle::mangleName(details->symbol());
+        builder.create<fir::DTEntryOp>(
+            info.loc,
+            mlir::StringAttr::get(builder.getContext(),
+                                  binding->name().ToString()),
+            mlir::SymbolRefAttr::get(builder.getContext(), bindingName));
+      }
+      builder.create<fir::FirEndOp>(info.loc);
+    }
+    registeredDispatchTableInfo.clear();
+  }
+
+private:
+  /// Types registered for dispatch table creation, in registration order.
+  llvm::SmallVector<DispatchTableInfo> registeredDispatchTableInfo;
+
+  /// Mangled names of the registered types, to avoid creating the same
+  /// dispatch table twice.
+  llvm::StringSet<> seen;
+};
+
 using IncrementLoopNestInfo = llvm::SmallVector<IncrementLoopInfo>;
 } // namespace
 
@@ -269,6 +356,10 @@
     createGlobalOutsideOfFunctionLowering(
         [&]() { runtimeTypeInfoConverter.createTypeInfoGlobals(*this); });
 
+    // Create the dispatch tables for derived types.
+    createGlobalOutsideOfFunctionLowering(
+        [&]() { dispatchTableConverter.createDispatchTableOps(*this); });
+
     // Create the list of any environment defaults for the runtime to set. The
     // runtime default list is only created if there is a main program to ensure
     // it only happens once and to provide consistent results if multiple files
@@ -744,6 +835,12 @@
     runtimeTypeInfoConverter.registerTypeInfoSymbol(*this, loc, typeInfoSym);
   }
 
+  void registerDispatchTableInfo(
+      mlir::Location loc,
+      const Fortran::semantics::DerivedTypeSpec *typeSpec) override final {
+    dispatchTableConverter.registerTypeSpec(loc, typeSpec);
+  }
+
 private:
   FirConverter() = delete;
   FirConverter(const FirConverter &) = delete;
@@ -3553,6 +3650,7 @@
   Fortran::lower::SymMap localSymbols;
   Fortran::parser::CharBlock currentPosition;
   RuntimeTypeInfoConverter runtimeTypeInfoConverter;
+  DispatchTableConverter dispatchTableConverter;
 
   /// WHERE statement/construct mask expression stack.
   Fortran::lower::ImplicitIterSpace implicitIterSpace;
diff --git a/flang/lib/Lower/ConvertType.cpp b/flang/lib/Lower/ConvertType.cpp
--- a/flang/lib/Lower/ConvertType.cpp
+++ b/flang/lib/Lower/ConvertType.cpp
@@ -340,6 +340,8 @@
     }
     LLVM_DEBUG(llvm::dbgs() << "derived type: " << rec << '\n');
 
+    converter.registerDispatchTableInfo(loc, &tySpec);
+
     // Generate the type descriptor object if any
     if (const Fortran::semantics::Scope *derivedScope =
             tySpec.scope() ? tySpec.scope() : tySpec.typeSymbol().scope())
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -271,6 +271,19 @@
   return glob;
 }
 
+fir::DispatchTableOp fir::FirOpBuilder::createDispatchTableOp(
+    mlir::Location loc, llvm::StringRef name, llvm::StringRef parentName) {
+  auto module = getModule();
+  // Reuse the dispatch table if one with this name was already created.
+  if (auto dt = module.lookupSymbol<fir::DispatchTableOp>(name))
+    return dt;
+  // Dispatch tables are module level symbols: create the new one at the end
+  // of the module body and restore the caller's insertion point on exit.
+  mlir::OpBuilder::InsertionGuard guard(*this);
+  setInsertionPointToEnd(module.getBody());
+  return create<fir::DispatchTableOp>(loc, name, mlir::Type{}, parentName);
+}
+
 mlir::Value fir::FirOpBuilder::convertWithSemantics(mlir::Location loc,
                                                     mlir::Type toTy,
                                                     mlir::Value val,
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -1051,30 +1051,30 @@
   }
 };
 
-/// Lower `fir.dispatch_table` operation. The dispatch table for a Fortran
-/// derived type.
+/// `fir.dispatch_table` operation has no specific CodeGen: it only carries
+/// information for FIR to FIR passes and is erased here.
 struct DispatchTableOpConversion
     : public FIROpConversion<fir::DispatchTableOp> {
   using FIROpConversion::FIROpConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(fir::DispatchTableOp dispTab, OpAdaptor adaptor,
+  matchAndRewrite(fir::DispatchTableOp op, OpAdaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    TODO(dispTab.getLoc(), "fir.dispatch_table codegen");
-    return mlir::failure();
+    rewriter.eraseOp(op);
+    return mlir::success();
   }
 };
 
-/// Lower `fir.dt_entry` operation. An entry in a dispatch table; binds a
-/// method-name to a function.
+/// `fir.dt_entry` operation has no specific CodeGen: it only carries
+/// information for FIR to FIR passes and is erased here.
 struct DTEntryOpConversion : public FIROpConversion<fir::DTEntryOp> {
   using FIROpConversion::FIROpConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(fir::DTEntryOp dtEnt, OpAdaptor adaptor,
+  matchAndRewrite(fir::DTEntryOp op, OpAdaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    TODO(dtEnt.getLoc(), "fir.dt_entry codegen");
-    return mlir::failure();
+    rewriter.eraseOp(op);
+    return mlir::success();
   }
 };
 
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1092,14 +1092,20 @@
 mlir::ParseResult fir::DispatchTableOp::parse(mlir::OpAsmParser &parser,
                                               mlir::OperationState &result) {
-  // Parse the name as a symbol reference attribute.
-  mlir::SymbolRefAttr nameAttr;
-  if (parser.parseAttribute(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
-                            result.attributes))
+  // Parse the name of the table as a symbol name.
+  mlir::StringAttr nameAttr;
+  if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
+                             result.attributes))
     return mlir::failure();
 
-  // Convert the parsed name attr into a string attr.
-  result.attributes.set(mlir::SymbolTable::getSymbolAttrName(),
-                        nameAttr.getRootReference());
+  // Parse the optional `extends("parent")` clause.
+  if (mlir::succeeded(parser.parseOptionalKeyword(getExtendsKeyword()))) {
+    mlir::StringAttr parent;
+    if (parser.parseLParen() ||
+        parser.parseAttribute(parent, getParentAttrNameStr(),
+                              result.attributes) ||
+        parser.parseRParen())
+      return mlir::failure();
+  }
 
   // Parse the optional table body.
   mlir::Region *body = result.addRegion();
@@ -1113,11 +1119,11 @@
 }
 
 void fir::DispatchTableOp::print(mlir::OpAsmPrinter &p) {
-  auto tableName = getOperation()
-                       ->getAttrOfType<mlir::StringAttr>(
-                           mlir::SymbolTable::getSymbolAttrName())
-                       .getValue();
-  p << " @" << tableName;
+  p << ' ';
+  p.printSymbolName(getSymName());
+  if (getParent())
+    p << ' ' << getExtendsKeyword() << '('
+      << (*this)->getAttr(getParentAttrNameStr()) << ')';
 
   mlir::Region &body = getOperation()->getRegion(0);
   if (!body.empty()) {
@@ -1128,12 +1134,28 @@
 }
 
 mlir::LogicalResult fir::DispatchTableOp::verify() {
+  if (getRegion().empty())
+    return mlir::success();
   for (auto &op : getBlock())
     if (!mlir::isa<fir::DTEntryOp, fir::FirEndOp>(op))
       return op.emitOpError("dispatch table must contain dt_entry");
   return mlir::success();
 }
 
+void fir::DispatchTableOp::build(mlir::OpBuilder &builder,
+                                 mlir::OperationState &result,
+                                 llvm::StringRef name, mlir::Type type,
+                                 llvm::StringRef parent,
+                                 llvm::ArrayRef<mlir::NamedAttribute> attrs) {
+  result.addRegion();
+  result.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
+                      builder.getStringAttr(name));
+  // The parent attribute is optional: only set it when a parent is provided.
+  if (!parent.empty())
+    result.addAttribute(getParentAttrNameStr(), builder.getStringAttr(parent));
+  result.addAttributes(attrs);
+}
+
 //===----------------------------------------------------------------------===//
 // EmboxOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/Todo/dispatch_table.fir b/flang/test/Fir/Todo/dispatch_table.fir
deleted file mode 100644
--- a/flang/test/Fir/Todo/dispatch_table.fir
+++ /dev/null
@@ -1,9 +0,0 @@
-// RUN: %not_todo_cmd fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s 2>&1 | FileCheck %s
-
-// Test fir.dispatch_table conversion to llvm.
-// Not implemented yet.
-
-// CHECK: not yet implemented: fir.dispatch_table codegen
-fir.dispatch_table @dispatch_tbl {
-  fir.dt_entry "method", @method_impl
-}
diff --git a/flang/test/Lower/dispatch-table.f90 b/flang/test/Lower/dispatch-table.f90
new file mode 100644
--- /dev/null
+++ b/flang/test/Lower/dispatch-table.f90
@@ -0,0 +1,75 @@
+! RUN: bbc -polymorphic-type -emit-fir %s -o - | FileCheck %s
+
+! Tests the generation of fir.dispatch_table operations.
+
+module polymorphic_types
+  type p1
+    integer :: a
+    integer :: b
+  contains
+    procedure :: proc1 => proc1_p1
+    procedure :: aproc
+    procedure :: zproc
+  end type
+
+  type, extends(p1) :: p2
+    integer :: c
+  contains
+    procedure :: proc1 => proc1_p2
+    procedure :: aproc2
+  end type
+
+  type, extends(p2) :: p3
+    integer :: d
+  contains
+    procedure :: aproc3
+  end type
+contains
+
+
+  subroutine proc1_p1(p)
+    class(p1) :: p
+  end subroutine
+
+  subroutine aproc(p)
+    class(p1) :: p
+  end subroutine
+
+  subroutine zproc(p)
+    class(p1) :: p
+  end subroutine
+
+  subroutine proc1_p2(p)
+    class(p2) :: p
+  end subroutine
+
+  subroutine aproc2(p)
+    class(p2) :: p
+  end subroutine
+
+  subroutine aproc3(p)
+    class(p3) :: p
+  end subroutine
+
+end module
+
+! CHECK-LABEL: fir.dispatch_table @_QMpolymorphic_typesTp1 {
+! CHECK: fir.dt_entry "aproc", @_QMpolymorphic_typesPaproc
+! CHECK: fir.dt_entry "proc1", @_QMpolymorphic_typesPproc1_p1
+! CHECK: fir.dt_entry "zproc", @_QMpolymorphic_typesPzproc
+! CHECK: }
+
+! CHECK-LABEL: fir.dispatch_table @_QMpolymorphic_typesTp2 extends("_QMpolymorphic_typesTp1") {
+! CHECK: fir.dt_entry "aproc", @_QMpolymorphic_typesPaproc
+! CHECK: fir.dt_entry "proc1", @_QMpolymorphic_typesPproc1_p2
+! CHECK: fir.dt_entry "zproc", @_QMpolymorphic_typesPzproc
+! CHECK: fir.dt_entry "aproc2", @_QMpolymorphic_typesPaproc2
+! CHECK: }
+
+! CHECK-LABEL: fir.dispatch_table @_QMpolymorphic_typesTp3 extends("_QMpolymorphic_typesTp2") {
+! CHECK: fir.dt_entry "aproc", @_QMpolymorphic_typesPaproc
+! CHECK: fir.dt_entry "proc1", @_QMpolymorphic_typesPproc1_p2
+! CHECK: fir.dt_entry "zproc", @_QMpolymorphic_typesPzproc
+! CHECK: fir.dt_entry "aproc2", @_QMpolymorphic_typesPaproc2
+! CHECK: fir.dt_entry "aproc3", @_QMpolymorphic_typesPaproc3
+! CHECK: }
diff --git a/flang/test/Fir/dispatch-table-codegen.fir b/flang/test/Fir/dispatch-table-codegen.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/Fir/dispatch-table-codegen.fir
@@ -0,0 +1,21 @@
+// RUN: fir-opt %s | FileCheck %s
+// RUN: fir-opt --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" %s | FileCheck %s --check-prefix=LLVM
+
+// Test the parsing and printing of fir.dispatch_table (including the optional
+// `extends` clause) and its removal by the conversion to the LLVM dialect.
+
+fir.dispatch_table @base {
+  fir.dt_entry "method", @method_impl
+}
+
+fir.dispatch_table @derived extends("base") {
+  fir.dt_entry "method", @method_impl
+}
+
+// CHECK-LABEL: fir.dispatch_table @base {
+// CHECK:         fir.dt_entry "method", @method_impl
+// CHECK-LABEL: fir.dispatch_table @derived extends("base") {
+// CHECK:         fir.dt_entry "method", @method_impl
+
+// LLVM-NOT: fir.dispatch_table
+// LLVM-NOT: fir.dt_entry