diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -46,6 +46,9 @@
 // fir::LLVMTypeConverter for converting to LLVM IR dialect types.
 #include "TypeConverter.h"
 
+using BindingTable = llvm::DenseMap<llvm::StringRef, unsigned>;
+using BindingTables = llvm::DenseMap<llvm::StringRef, BindingTable>;
+
 // TODO: This should really be recovered from the specified target.
 static constexpr unsigned defaultAlign = 8;
 
@@ -93,8 +96,10 @@
 class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
 public:
   explicit FIROpConversion(fir::LLVMTypeConverter &lowering,
-                           const fir::FIRToLLVMPassOptions &options)
-      : mlir::ConvertOpToLLVMPattern<FromOp>(lowering), options(options) {}
+                           const fir::FIRToLLVMPassOptions &options,
+                           const BindingTables &bindingTables)
+      : mlir::ConvertOpToLLVMPattern<FromOp>(lowering), options(options),
+        bindingTables(bindingTables) {}
 
 protected:
   mlir::Type convertType(mlir::Type ty) const {
@@ -293,6 +298,7 @@
   }
 
   const fir::FIRToLLVMPassOptions &options;
+  const BindingTables &bindingTables;
 };
 
 /// FIR conversion pattern template
@@ -3293,8 +3299,9 @@
 template <typename FromOp>
 struct MustBeDeadConversion : public FIROpConversion<FromOp> {
   explicit MustBeDeadConversion(fir::LLVMTypeConverter &lowering,
-                                const fir::FIRToLLVMPassOptions &options)
-      : FIROpConversion<FromOp>(lowering, options) {}
+                                const fir::FIRToLLVMPassOptions &options,
+                                const BindingTables &bindingTables)
+      : FIROpConversion<FromOp>(lowering, options, bindingTables) {}
   using OpAdaptor = typename FromOp::Adaptor;
 
   mlir::LogicalResult
@@ -3354,6 +3361,40 @@
     if (mlir::failed(runPipeline(mathConvertionPM, mod)))
       return signalPassFailure();
 
+    // Reconstruct binding tables for dynamic dispatch. The binding tables
+    // are defined in FIR from semantics as fir.global operations with region
+    // initializer. Go through each binding table and store the procedure name
+    // and binding index for later use by the fir.dispatch conversion pattern.
+    BindingTables bindingTables;
+    for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
+      if (!globalOp.getSymName().contains(".v."))
+        continue;
+      unsigned bindingIdx = 0;
+      BindingTable bindings;
+      for (auto addrOp : globalOp.getRegion().getOps<fir::AddrOfOp>()) {
+        // Only addresses of character globals can name a binding.
+        if (!fir::isa_char(fir::unwrapRefType(addrOp.getType())))
+          continue;
+        auto nameGlobal = mod.lookupSymbol<fir::GlobalOp>(addrOp.getSymbol());
+        if (!nameGlobal)
+          continue;
+        // The name global is expected to hold a fir.string_lit with a string
+        // attribute; fail the pass with a diagnostic instead of crashing.
+        auto stringLits = nameGlobal.getRegion().getOps<fir::StringLitOp>();
+        mlir::StringAttr procName;
+        if (!stringLits.empty())
+          procName =
+              (*stringLits.begin()).getValue().dyn_cast<mlir::StringAttr>();
+        if (!procName) {
+          nameGlobal.emitError("expected binding name global to be "
+                               "initialized with a character string literal");
+          return signalPassFailure();
+        }
+        bindings[procName.getValue()] = bindingIdx++;
+      }
+      bindingTables[globalOp.getSymName()] = std::move(bindings);
+    }
+
     auto *context = getModule().getContext();
     fir::LLVMTypeConverter typeConverter{getModule()};
     mlir::RewritePatternSet pattern(context);
@@ -3378,8 +3419,8 @@
         SliceOpConversion, StoreOpConversion, StringLitOpConversion,
         SubcOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
         UndefOpConversion, UnreachableOpConversion, XArrayCoorOpConversion,
-        XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(typeConverter,
-                                                                  options);
+        XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(
+        typeConverter, options, bindingTables);
     mlir::populateFuncToLLVMConversionPatterns(typeConverter, pattern);
     mlir::populateOpenMPToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern);