Index: flang/include/flang/Optimizer/CodeGen/CGOps.td =================================================================== --- /dev/null +++ flang/include/flang/Optimizer/CodeGen/CGOps.td @@ -0,0 +1,118 @@ +//===-- CGOps.td - FIR operation definitions ---------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Definition of the FIRCG dialect operations +/// +//===----------------------------------------------------------------------===// + +#ifndef FIR_DIALECT_CG_OPS +#define FIR_DIALECT_CG_OPS + +include "mlir/IR/SymbolInterfaces.td" +include "flang/Optimizer/TypePredicates.td" + +def fircg_Dialect : Dialect { + let name = "fircg"; + let cppNamespace = "::fir::cg"; +} + +// Base class for FIR CG operations. +// All operations automatically get a prefix of "fircg.". +class fircg_Op traits> + : Op; + +// Extended embox operation. +def fircg_XEmboxOp : fircg_Op<"ext_embox", [AttrSizedOperandSegments]> { + let summary = "for internal conversion only"; + + let description = [{ + Prior to lowering to LLVM IR dialect, a non-scalar non-trivial embox op will + be converted to an extended embox. This op will have the following sets of + arguments. + + - memref: The memory reference being emboxed. + - shape: A vector that is the runtime shape of the underlying array. + - shift: A vector that is the runtime origin of the first element. + The default is a vector of the value 1. + - slice: A vector of triples that describe an array slice. + - subcomponent: A vector of indices for subobject slicing. + - LEN type parameters: A vector of runtime LEN type parameters that + describe an correspond to the elemental derived type. + + The memref and shape arguments are mandatory. The rest are optional. 
+ }]; + + let arguments = (ins + AnyReferenceLike:$memref, + Variadic:$shape, + Variadic:$shift, + Variadic:$slice, + Variadic:$subcomponent, + Variadic:$lenParams + ); + let results = (outs fir_BoxType); + + let assemblyFormat = [{ + $memref (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)? (`path` $subcomponent^)? (`typeparams` $lenParams^)? attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + unsigned getRank() { return shape().size(); } + unsigned lenParamOffset() { + return 1 + shape().size() + shift().size() + slice().size() + + subcomponent().size(); + } + }]; +} + +// Extended array coordinate operation. +def fircg_XArrayCoorOp : fircg_Op<"ext_array_coor", [AttrSizedOperandSegments]> { + let summary = "for internal conversion only"; + + let description = [{ + Prior to lowering to LLVM IR dialect, a non-scalar non-trivial embox op will + be converted to an extended embox. This op will have the following sets of + arguments. + + - memref: The memory reference of the array's data. + - shape: A vector that is the runtime shape of the underlying array. + - shift: A vector that is the runtime origin of the first element. + The default is a vector of the value 1. + - slice: A vector of triples that describe an array slice. + - subcomponent: A vector of indices that describe subobject slicing. + - indices: A vector of runtime values that describe the coordinate of + the element of the array to be computed. + - LEN type parameters: A vector of runtime LEN type parameters that + describe an correspond to the elemental derived type. + + The memref, shape, and indices arguments are mandatory. The rest are + optional. + }]; + + let arguments = (ins + AnyReferenceLike:$memref, + Variadic:$shape, + Variadic:$shift, + Variadic:$slice, + Variadic:$subcomponent, + Variadic:$indices, + Variadic:$lenParams + ); + let results = (outs fir_ReferenceType); + + let assemblyFormat = [{ + $memref (`(`$shape^`)`)? (`origin` $shift^)? 
(`[`$slice^`]`)? (`path` $subcomponent^)? `<`$indices`>` (`typeparams` $lenParams^)? attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + unsigned getRank() { return shape().size(); } + }]; +} + +#endif Index: flang/include/flang/Optimizer/CodeGen/CGPasses.td =================================================================== --- flang/include/flang/Optimizer/CodeGen/CGPasses.td +++ flang/include/flang/Optimizer/CodeGen/CGPasses.td @@ -16,9 +16,30 @@ include "mlir/Pass/PassBase.td" -def CodeGenRewrite : FunctionPass<"cg-rewrite"> { +def CodeGenRewrite : Pass<"cg-rewrite", "mlir::ModuleOp"> { let summary = "Rewrite some FIR ops into their code-gen forms."; + let description = [{ + Fuse specific subgraphs into single Ops for code generation. + }]; let constructor = "fir::createFirCodeGenRewritePass()"; + let dependentDialects = ["fir::FIROpsDialect"]; +} + +def TargetRewrite : Pass<"target-rewrite", "mlir::ModuleOp"> { + let summary = "Rewrite some FIR dialect into target specific forms. 
" + "Certain abstractions in the FIR dialect need to be rewritten " + "to reflect representations that may differ based on the " + "target machine."; + let constructor = "fir::createFirTargetRewritePass()"; + let dependentDialects = ["fir::FIROpsDialect"]; + let options = [ + Option<"noCharacterConversion", "no-character-conversion", + "bool", /*default=*/"false", + "Disable target-specific conversion of CHARACTER.">, + Option<"noComplexConversion", "no-complex-conversion", + "bool", /*default=*/"false", + "Disable target-specific conversion of COMPLEX."> + ]; } #endif // FLANG_OPTIMIZER_CODEGEN_PASSES Index: flang/include/flang/Optimizer/CodeGen/CMakeLists.txt =================================================================== --- flang/include/flang/Optimizer/CodeGen/CMakeLists.txt +++ flang/include/flang/Optimizer/CodeGen/CMakeLists.txt @@ -1,3 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS CGOps.td) +mlir_tablegen(CGOps.h.inc -gen-op-decls) +mlir_tablegen(CGOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(CGOpsIncGen) set(LLVM_TARGET_DEFINITIONS CGPasses.td) mlir_tablegen(CGPasses.h.inc -gen-pass-decls -name OptCodeGen) Index: flang/include/flang/Optimizer/CodeGen/CodeGen.h =================================================================== --- flang/include/flang/Optimizer/CodeGen/CodeGen.h +++ flang/include/flang/Optimizer/CodeGen/CodeGen.h @@ -9,6 +9,7 @@ #ifndef OPTIMIZER_CODEGEN_CODEGEN_H #define OPTIMIZER_CODEGEN_CODEGEN_H +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" #include @@ -21,6 +22,17 @@ /// the code gen (to LLVM-IR dialect) conversion. std::unique_ptr createFirCodeGenRewritePass(); +/// FirTargetRewritePass options. +struct TargetRewriteOptions { + bool noCharacterConversion{}; + bool noComplexConversion{}; +}; + +/// Prerequiste pass for code gen. Perform intermediate rewrites to tailor the +/// IR for the chosen target. 
+std::unique_ptr> createFirTargetRewritePass( + const TargetRewriteOptions &options = TargetRewriteOptions()); + /// Convert FIR to the LLVM IR dialect std::unique_ptr createFIRToLLVMPass(NameUniquer &uniquer); Index: flang/include/flang/Optimizer/Transforms/CMakeLists.txt =================================================================== --- flang/include/flang/Optimizer/Transforms/CMakeLists.txt +++ flang/include/flang/Optimizer/Transforms/CMakeLists.txt @@ -1,4 +1,9 @@ + +set(LLVM_TARGET_DEFINITIONS RewritePatterns.td) +mlir_tablegen(RewritePatterns.inc -gen-rewriters) +add_public_tablegen_target(RewritePatternsIncGen) + set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name OptTransform) add_public_tablegen_target(FIROptTransformsPassIncGen) Index: flang/include/flang/Optimizer/Transforms/Passes.h =================================================================== --- flang/include/flang/Optimizer/Transforms/Passes.h +++ flang/include/flang/Optimizer/Transforms/Passes.h @@ -16,39 +16,38 @@ namespace mlir { class BlockAndValueMapping; class Operation; -class Pass; class Region; } // namespace mlir namespace fir { -/// Convert fir.select_type to the standard dialect -std::unique_ptr createControlFlowLoweringPass(); +//===----------------------------------------------------------------------===// +// Passes defined in Passes.td +//===----------------------------------------------------------------------===// -/// Effects aware CSE pass +std::unique_ptr createControlFlowLoweringPass(); std::unique_ptr createCSEPass(); - -/// Convert FIR loop constructs to the Affine dialect std::unique_ptr createPromoteToAffinePass(); - -/// Convert `fir.do_loop` and `fir.if` to a CFG. This -/// conversion enables the `createLowerToCFGPass` to transform these to CFG -/// form. 
+std::unique_ptr createAffineDemotionPass(); +std::unique_ptr createFirLoopResultOptPass(); +std::unique_ptr createMemDataFlowOptPass(); std::unique_ptr createFirToCfgPass(); +std::unique_ptr createArrayValueCopyPass(); -/// A pass to convert the FIR dialect from "Mem-SSA" form to "Reg-SSA" -/// form. This pass is a port of LLVM's mem2reg pass, but modified for the FIR -/// dialect as well as the restructuring of MLIR's representation to present PHI -/// nodes as block arguments. +/// A pass to convert the FIR dialect from "Mem-SSA" form to "Reg-SSA" form. +/// This pass is a port of LLVM's mem2reg pass, but modified for the FIR dialect +/// as well as the restructuring of MLIR's representation to present PHI nodes +/// as block arguments. +/// TODO: This pass needs some additional work. std::unique_ptr createMemToRegPass(); /// Support for inlining on FIR. -bool canLegallyInline(mlir::Operation *op, mlir::Region *reg, +bool canLegallyInline(mlir::Operation *op, mlir::Region *reg, bool, mlir::BlockAndValueMapping &map); +bool canLegallyInline(mlir::Operation *, mlir::Operation *, bool); -// declarative passes -#define GEN_PASS_REGISTRATION -#include "flang/Optimizer/Transforms/Passes.h.inc" +/// Optionally force the body of a DO to execute at least once. +bool isAlwaysExecuteLoopBody(); } // namespace fir Index: flang/include/flang/Optimizer/Transforms/Passes.td =================================================================== --- flang/include/flang/Optimizer/Transforms/Passes.td +++ flang/include/flang/Optimizer/Transforms/Passes.td @@ -17,35 +17,105 @@ include "mlir/Pass/PassBase.td" def AffineDialectPromotion : FunctionPass<"promote-to-affine"> { - let summary = "Promotes fir.loop and fir.where to affine.for and affine.if where possible"; + let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`."; let description = [{ - TODO + Convert fir operations which satisfy affine constraints to the affine + dialect. 
+ + `fir.do_loop` will be converted to `affine.for` if the loops inside the body + can be converted and the indices for memory loads and stores satisfy + `affine.apply` criteria for symbols and dimensions. + + `fir.if` will be converted to `affine.if` where possible. `affine.if`'s + condition uses an integer set (==, >=) and an analysis is done to determine + the fir condition's parent operations to construct the integer set. + + `fir.load` (`fir.store`) will be converted to `affine.load` (`affine.store`) + where possible. This conversion includes adding a dummy `fir.convert` cast + to adapt values of type `!fir.ref` to `memref`. This is done + because the affine dialect presently only understands the `memref` type. + }]; + let constructor = "::fir::createPromoteToAffinePass()"; +} + +def AffineDialectDemotion : FunctionPass<"demote-affine"> { + let summary = "Converts `affine.{load,store}` back to fir operations"; + let description = [{ + Affine dialect's default lowering for loads and stores is different from + fir as it uses the `memref` type. The `memref` type is not compatible with + the Fortran runtime. Therefore, conversion of memory operations back to + `fir.load` and `fir.store` with `!fir.ref` types is required. + }]; + let constructor = "::fir::createAffineDemotionPass()"; +} + +def FirLoopResultOpt : FunctionPass<"fir-loop-result-opt"> { + let summary = "Optimizes `fir.do_loop` by removing unused final iteration values."; + let description = [{ + TODO - do we need this if we overhaul fir.do_loop a bit? }]; - let constructor = "fir::createPromoteToAffinePass()"; + let constructor = "::fir::createFirLoopResultOptPass()"; +} + +def MemRefDataFlowOpt : FunctionPass<"fir-memref-dataflow-opt"> { + let summary = "Perform store/load forwarding and potentially removing dead stores."; + let description = [{ + This pass performs store to load forwarding to eliminate memory accesses and + potentially the entire allocation if all the accesses are forwarded. 
+ }]; + let constructor = "::fir::createMemDataFlowOptPass()"; } def BasicCSE : FunctionPass<"basic-cse"> { - let summary = "Basic common sub-expression elimination"; + let summary = "Basic common sub-expression elimination."; let description = [{ - TODO + Perform common subexpression elimination on FIR operations. This pass + differs from the MLIR CSE pass in that it is FIR/Fortran semantics aware. }]; - let constructor = "fir::createCSEPass()"; + let constructor = "::fir::createCSEPass()"; } def ControlFlowLowering : FunctionPass<"lower-control-flow"> { - let summary = "Convert affine dialect, fir.select_type to standard dialect"; + let summary = "Convert affine dialect, fir.select_type to standard dialect."; let description = [{ - TODO + This converts the affine dialect back to standard dialect. It also converts + `fir.select_type` to more primitive operations. This pass is required before + code gen to the LLVM IR dialect. + + TODO: Should the affine rewriting by moved to AffineDialectDemotion? }]; - let constructor = "fir::createControlFlowLoweringPass()"; + let constructor = "::fir::createControlFlowLoweringPass()"; } def CFGConversion : FunctionPass<"cfg-conversion"> { let summary = "Convert FIR structured control flow ops to CFG ops."; let description = [{ - TODO + Transform the `fir.do_loop`, `fir.if`, and `fir.iterate_while` ops into + plain old test and branch operations. Removing the high-level control + structures can enable other optimizations. + + This pass is required before code gen to the LLVM IR dialect. + }]; + let constructor = "::fir::createFirToCfgPass()"; +} + +def ArrayValueCopy : FunctionPass<"array-value-copy"> { + let summary = "Convert array value operations to memory operations."; + let description = [{ + Transform the set of array value primitives to a memory-based array + representation. + + The Ops `array_load`, `array_store`, `array_fetch`, and `array_update` are + used to manage abstract aggregate array values. 
A simple analysis is done + to determine if there are potential dependences between these operations. + If not, these array operations can be lowered to work directly on the memory + representation. If there is a potential conflict, a temporary is created + along with appropriate copy-in/copy-out operations. Here, a more refined + analysis might be deployed, such as using the affine framework. + + This pass is required before code gen to the LLVM IR dialect. }]; - let constructor = "fir::createFirToCfgPass()"; + let constructor = "::fir::createArrayValueCopyPass()"; } #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES Index: flang/include/flang/Optimizer/Transforms/RewritePatterns.td =================================================================== --- /dev/null +++ flang/include/flang/Optimizer/Transforms/RewritePatterns.td @@ -0,0 +1,59 @@ +//===-- RewritePatterns.td - FIR Rewrite Patterns -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Defines pattern rewrites for fir optimizations +/// +//===----------------------------------------------------------------------===// + +#ifndef FIR_REWRITE_PATTERNS +#define FIR_REWRITE_PATTERNS + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" +include "flang/Optimizer/Dialect/FIROps.td" + +def IdenticalTypePred : Constraint>; +def IntegerTypePred : Constraint>; +def IndexTypePred : Constraint()">>; + +def SmallerWidthPred + : Constraint>; + +def ConvertConvertOptPattern + : Pat<(fir_ConvertOp (fir_ConvertOp $arg)), + (fir_ConvertOp $arg), + [(IntegerTypePred $arg)]>; + +def RedundantConvertOptPattern + : Pat<(fir_ConvertOp:$res $arg), + (replaceWithValue $arg), + [(IdenticalTypePred $res, $arg) + ,(IntegerTypePred $arg)]>; + +def CombineConvertOptPattern + : Pat<(fir_ConvertOp:$res(fir_ConvertOp:$irm $arg)), + (replaceWithValue $arg), + [(IdenticalTypePred $res, $arg) + ,(IntegerTypePred $arg) + ,(IntegerTypePred $irm) + ,(SmallerWidthPred $arg, $irm)]>; + +def createConstantOp + : NativeCodeCall<"$_builder.create" + "($_loc, $_builder.getIndexType(), " + "rewriter.getIndexAttr($1.dyn_cast().getInt()))">; + +def ForwardConstantConvertPattern + : Pat<(fir_ConvertOp:$res (ConstantOp:$cnt $attr)), + (createConstantOp $res, $attr), + [(IndexTypePred $res) + ,(IntegerTypePred $cnt)]>; + +#endif // FIR_REWRITE_PATTERNS Index: flang/lib/Optimizer/CMakeLists.txt =================================================================== --- flang/lib/Optimizer/CMakeLists.txt +++ flang/lib/Optimizer/CMakeLists.txt @@ -6,14 +6,22 @@ Dialect/FIROps.cpp Dialect/FIRType.cpp + Support/FIRContext.cpp Support/InternalNames.cpp Support/KindMapping.cpp + CodeGen/CGOps.cpp + CodeGen/CodeGen.cpp + CodeGen/PreCGRewrite.cpp + CodeGen/Target.cpp + CodeGen/TargetRewrite.cpp + Transforms/Inliner.cpp 
DEPENDS FIROpsIncGen - FIROptTransformsPassIncGen + FIROptCodeGenPassIncGen + CGOpsIncGen ${dialect_libs} LINK_LIBS Index: flang/lib/Optimizer/CodeGen/CGOps.h =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/CGOps.h @@ -0,0 +1,23 @@ +//===-- CGOps.h -------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#ifndef OPTIMIZER_CODEGEN_CGOPS_H +#define OPTIMIZER_CODEGEN_CGOPS_H + +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +using namespace mlir; + +#define GET_OP_CLASSES +#include "flang/Optimizer/CodeGen/CGOps.h.inc" + +#endif Index: flang/lib/Optimizer/CodeGen/CGOps.cpp =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/CGOps.cpp @@ -0,0 +1,32 @@ +//===-- CGOps.cpp -- FIR codegen operations -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "CGOps.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIRType.h" + +/// FIR codegen dialect constructor. 
+fir::FIRCodeGenDialect::FIRCodeGenDialect(mlir::MLIRContext *ctx) + : mlir::Dialect("fircg", ctx, mlir::TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "flang/Optimizer/CodeGen/CGOps.cpp.inc" + >(); +} + +// anchor the class vtable to this compilation unit +fir::FIRCodeGenDialect::~FIRCodeGenDialect() { + // do nothing +} + +#define GET_OP_CLASSES +#include "flang/Optimizer/CodeGen/CGOps.cpp.inc" Index: flang/lib/Optimizer/CodeGen/CodeGen.cpp =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -0,0 +1,2655 @@ +//===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/CodeGen/CodeGen.h" +#include "CGOps.h" +#include "DescriptorModel.h" +#include "Target.h" +#include "flang/Lower/Todo.h" // remove when TODO's are done +#include "flang/Optimizer/Dialect/FIRAttr.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Support/FIRContext.h" +#include "flang/Optimizer/Support/InternalNames.h" +#include "flang/Optimizer/Support/KindMapping.h" +#include "flang/Optimizer/Support/TypeCode.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include 
"mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Target/LLVMIR.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Config/abi-breaking.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "flang-codegen" + +//===----------------------------------------------------------------------===// +/// \file +/// +/// The Tilikum bridge performs the conversion of operations from both the FIR +/// and standard dialects to the LLVM-IR dialect. +/// +/// Some FIR operations may be lowered to other dialects, such as standard, but +/// some FIR operations will pass through to the Tilikum bridge. This may be +/// necessary to preserve the semantics of the Fortran program. +//===----------------------------------------------------------------------===// + +using namespace llvm; + +using OperandTy = ArrayRef; + +static cl::opt + disableFirToLLVMIR("disable-fir2llvmir", + cl::desc("disable FIR to LLVM-IR dialect pass"), + cl::init(false), cl::Hidden); + +static cl::opt disableLLVM("disable-llvm", cl::desc("disable LLVM pass"), + cl::init(false), cl::Hidden); + +namespace fir { +/// return true if all `Value`s in `operands` are `ConstantOp`s +bool allConstants(OperandTy operands) { + for (auto opnd : operands) { + if (auto defop = opnd.getDefiningOp()) + if (isa(defop) || isa(defop)) + continue; + return false; + } + return true; +} +} // namespace fir + +using SmallVecResult = SmallVector; +using AttributeTy = ArrayRef; + +// FIXME: This should really be recovered from the specified target. +static constexpr unsigned defaultAlign = 8; + +// fir::LLVMTypeConverter for converting to LLVM IR dialect types. +#include "TypeConverter.h" + +// Instantiate static data member of the type converter. 
+StringMap fir::LLVMTypeConverter::identStructCache; + +inline mlir::Type getVoidPtrType(mlir::MLIRContext *context) { + return mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(context, 8)); +} + +namespace { +/// FIR conversion pattern template +template +class FIROpConversion : public mlir::OpConversionPattern { +public: + explicit FIROpConversion(mlir::MLIRContext *ctx, + fir::LLVMTypeConverter &lowering) + : mlir::OpConversionPattern(lowering, ctx, 1) {} + +protected: + mlir::Type convertType(mlir::Type ty) const { + return lowerTy().convertType(ty); + } + mlir::Type unwrap(mlir::Type ty) const { return lowerTy().unwrap(ty); } + mlir::Type voidPtrTy() const { + return getVoidPtrType(&lowerTy().getContext()); + } + + mlir::LLVM::ConstantOp + genConstantOffset(mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, + int offset) const { + auto ity = lowerTy().offsetType(); + auto cattr = rewriter.getI32IntegerAttr(offset); + return rewriter.create(loc, ity, cattr); + } + + /// Method to construct code sequence to get the rank from a box. + mlir::Value getRankFromBox(mlir::Location loc, mlir::Value box, + mlir::Type resultTy, + mlir::ConversionPatternRewriter &rewriter) const { + auto c0 = genConstantOffset(loc, rewriter, 0); + auto c3 = genConstantOffset(loc, rewriter, 3); + auto pty = mlir::LLVM::LLVMPointerType::get(unwrap(resultTy)); + auto p = rewriter.create(loc, pty, + mlir::ValueRange{box, c0, c3}); + return rewriter.create(loc, resultTy, p); + } + + /// Method to construct code sequence to get the triple for dimension `dim` + /// from a box. 
+ SmallVector + getDimsFromBox(mlir::Location loc, ArrayRef retTys, + mlir::Value box, mlir::Value dim, + mlir::ConversionPatternRewriter &rewriter) const { + auto c0 = genConstantOffset(loc, rewriter, 0); + auto c7 = genConstantOffset(loc, rewriter, 7); + auto l0 = loadFromOffset(loc, box, c0, c7, dim, 0, retTys[0], rewriter); + auto l1 = loadFromOffset(loc, box, c0, c7, dim, 1, retTys[1], rewriter); + auto l2 = loadFromOffset(loc, box, c0, c7, dim, 2, retTys[2], rewriter); + return {l0.getResult(), l1.getResult(), l2.getResult()}; + } + + mlir::LLVM::LoadOp + loadFromOffset(mlir::Location loc, mlir::Value a, mlir::LLVM::ConstantOp c0, + mlir::LLVM::ConstantOp c7, mlir::Value dim, int off, + mlir::Type ty, + mlir::ConversionPatternRewriter &rewriter) const { + auto pty = mlir::LLVM::LLVMPointerType::get(unwrap(ty)); + auto c = genConstantOffset(loc, rewriter, off); + auto p = genGEP(loc, pty, rewriter, a, c0, c7, dim, c); + return rewriter.create(loc, ty, p); + } + + template + mlir::LLVM::GEPOp genGEP(mlir::Location loc, mlir::Type ty, + mlir::ConversionPatternRewriter &rewriter, + mlir::Value base, ARGS... 
args) const { + llvm::SmallVector cv{args...}; + return rewriter.create(loc, ty, base, cv); + } + + fir::LLVMTypeConverter &lowerTy() const { + return *static_cast(this->getTypeConverter()); + } +}; + +/// FIR conversion pattern template +template +class FIROpAndTypeConversion : public FIROpConversion { +public: + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(FromOp op, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const final { + mlir::Type ty = this->convertType(op.getType()); + return doRewrite(op, ty, operands, rewriter); + } + + virtual mlir::LogicalResult + doRewrite(FromOp addr, mlir::Type ty, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const { + llvm_unreachable("derived class must override"); + } +}; +} // namespace + +static Block *createBlock(mlir::ConversionPatternRewriter &rewriter, + mlir::Block *insertBefore) { + assert(insertBefore && "expected valid insertion block"); + return rewriter.createBlock(insertBefore->getParent(), + mlir::Region::iterator(insertBefore)); +} + +/// Create an LLVM dialect global +static void createGlobal(mlir::Location loc, mlir::ModuleOp mod, StringRef name, + mlir::Type type, + mlir::ConversionPatternRewriter &rewriter) { + if (mod.lookupSymbol(name)) + return; + mlir::OpBuilder modBuilder(mod.getBodyRegion()); + modBuilder.create(loc, type, /*isConstant=*/true, + mlir::LLVM::Linkage::Weak, name, + mlir::Attribute{}); +} + +namespace { +struct AddrOfOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::AddrOfOp addr, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto ty = unwrap(convertType(addr.getType())); + rewriter.replaceOpWithNewOp( + addr, ty, addr.symbol().getRootReference()); + return success(); + } +}; +} // namespace + +static mlir::LLVM::ConstantOp +genConstantIndex(mlir::Location loc, mlir::Type ity, + 
mlir::ConversionPatternRewriter &rewriter, int offset) { + auto cattr = rewriter.getI64IntegerAttr(offset); + return rewriter.create(loc, ity, cattr); +} + +namespace { +/// convert to LLVM IR dialect `alloca` +struct AllocaOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::AllocaOp alloc, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto loc = alloc.getLoc(); + auto ity = lowerTy().indexType(); + auto c1 = genConstantIndex(loc, ity, rewriter, 1); + auto size = c1.getResult(); + for (auto opnd : operands) + size = rewriter.create(loc, ity, size, opnd); + auto ty = convertType(alloc.getType()); + rewriter.replaceOpWithNewOp(alloc, ty, size, + alloc.getAttrs()); + return success(); + } +}; +} // namespace + +/// Return the LLVMFuncOp corresponding to the standard malloc call. +static mlir::LLVM::LLVMFuncOp +getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) { + auto module = op->getParentOfType(); + if (auto mallocFunc = module.lookupSymbol("malloc")) + return mallocFunc; + mlir::OpBuilder moduleBuilder( + op->getParentOfType().getBodyRegion()); + auto indexType = mlir::IntegerType::get(op.getContext(), 64); + return moduleBuilder.create( + rewriter.getUnknownLoc(), "malloc", + mlir::LLVM::LLVMFunctionType::get(getVoidPtrType(op.getContext()), + indexType, + /*isVarArg=*/false)); +} + +namespace { +/// convert to `call` to the runtime to `malloc` memory +struct AllocMemOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::AllocMemOp heap, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto ty = convertType(heap.getType()); + auto mallocFunc = getMalloc(heap, rewriter); + auto loc = heap.getLoc(); + auto ity = lowerTy().indexType(); + auto size = genTypeSizeInBytes(loc, ity, rewriter, unwrap(ty)); + for (auto opnd : 
operands) + size = rewriter.create(loc, ity, size, opnd); + heap->setAttr("callee", rewriter.getSymbolRefAttr(mallocFunc)); + auto malloc = rewriter.create( + loc, getVoidPtrType(heap.getContext()), size, heap.getAttrs()); + rewriter.replaceOpWithNewOp(heap, ty, + malloc.getResult(0)); + return success(); + } + + // Compute the (allocation) size of the allocmem type in bytes. + mlir::Value genTypeSizeInBytes(mlir::Location loc, mlir::Type idxTy, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type llTy) const { + // Use the primitive size, if available. + auto ptrTy = llTy.dyn_cast(); + if (auto size = + mlir::LLVM::getPrimitiveTypeSizeInBits(ptrTy.getElementType())) + return genConstantIndex(loc, idxTy, rewriter, size / 8); + + // Otherwise, generate the GEP trick in LLVM IR to compute the size, since + // mlir::Type doesn't provide a sufficiently complete implementation. + auto nullPtr = rewriter.create(loc, ptrTy); + auto one = genConstantIndex(loc, lowerTy().offsetType(), rewriter, 1); + auto gep = rewriter.create( + loc, ptrTy, mlir::ValueRange{nullPtr, one}); + return rewriter.create(loc, idxTy, gep); + } +}; +} // namespace + +/// obtain the free() function +static mlir::LLVM::LLVMFuncOp +getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) { + auto module = op->getParentOfType(); + if (auto freeFunc = module.lookupSymbol("free")) + return freeFunc; + mlir::OpBuilder moduleBuilder(module.getBodyRegion()); + auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext()); + return moduleBuilder.create( + rewriter.getUnknownLoc(), "free", + mlir::LLVM::LLVMFunctionType::get(voidType, + getVoidPtrType(op.getContext()), + /*isVarArg=*/false)); +} + +namespace { +/// lower a freemem instruction into a call to free() +struct FreeMemOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::FreeMemOp freemem, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) 
const override { + auto freeFunc = getFree(freemem, rewriter); + auto loc = freemem.getLoc(); + auto bitcast = rewriter.create( + freemem.getLoc(), voidPtrTy(), operands[0]); + freemem->setAttr("callee", rewriter.getSymbolRefAttr(freeFunc)); + rewriter.create( + loc, mlir::LLVM::LLVMVoidType::get(freemem.getContext()), + mlir::ValueRange{bitcast}, freemem.getAttrs()); + rewriter.eraseOp(freemem); + return success(); + } +}; + +/// convert to returning the first element of the box (any flavor) +struct BoxAddrOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::BoxAddrOp boxaddr, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto a = operands[0]; + auto loc = boxaddr.getLoc(); + auto ty = convertType(boxaddr.getType()); + if (auto argty = boxaddr.val().getType().dyn_cast()) { + auto c0 = genConstantOffset(loc, rewriter, 0); + auto pty = mlir::LLVM::LLVMPointerType::get(unwrap(ty)); + auto p = genGEP(loc, unwrap(pty), rewriter, a, c0, c0); + // load the pointer from the buffer + rewriter.replaceOpWithNewOp(boxaddr, ty, p); + } else { + auto c0attr = rewriter.getI32IntegerAttr(0); + auto c0 = mlir::ArrayAttr::get(c0attr, boxaddr.getContext()); + rewriter.replaceOpWithNewOp(boxaddr, ty, a, + c0); + } + return success(); + } +}; + +/// convert to an extractvalue for the 2nd part of the boxchar +struct BoxCharLenOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::BoxCharLenOp boxchar, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto a = operands[0]; + auto ty = convertType(boxchar.getType()); + auto ctx = boxchar.getContext(); + auto c1 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(1), ctx); + rewriter.replaceOpWithNewOp(boxchar, ty, a, c1); + return success(); + } +}; + +/// convert to a triple set of GEPs and loads +struct 
BoxDimsOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::BoxDimsOp boxdims, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + SmallVector resultTypes = { + convertType(boxdims.getResult(0).getType()), + convertType(boxdims.getResult(1).getType()), + convertType(boxdims.getResult(2).getType()), + }; + auto results = getDimsFromBox(boxdims.getLoc(), resultTypes, operands[0], + operands[1], rewriter); + rewriter.replaceOp(boxdims, results); + return success(); + } +}; + +struct BoxEleSizeOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::BoxEleSizeOp boxelesz, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto a = operands[0]; + auto loc = boxelesz.getLoc(); + auto c0 = genConstantOffset(loc, rewriter, 0); + auto c1 = genConstantOffset(loc, rewriter, 1); + auto ty = convertType(boxelesz.getType()); + auto pty = mlir::LLVM::LLVMPointerType::get(unwrap(ty)); + auto p = genGEP(loc, unwrap(pty), rewriter, a, c0, c1); + rewriter.replaceOpWithNewOp(boxelesz, ty, p); + return success(); + } +}; + +struct BoxIsAllocOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::BoxIsAllocOp boxisalloc, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto a = operands[0]; + auto loc = boxisalloc.getLoc(); + auto ity = lowerTy().offsetType(); + auto c0 = genConstantOffset(loc, rewriter, 0); + auto c5 = genConstantOffset(loc, rewriter, 5); + auto ty = convertType(boxisalloc.getType()); + auto p = genGEP(loc, unwrap(ty), rewriter, a, c0, c5); + auto ld = rewriter.create(loc, ty, p); + auto ab = genConstantOffset(loc, rewriter, 2); + auto bit = rewriter.create(loc, ity, ld, ab); + rewriter.replaceOpWithNewOp( + boxisalloc, mlir::LLVM::ICmpPredicate::ne, 
bit, c0); + return success(); + } +}; + +struct BoxIsArrayOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::BoxIsArrayOp boxisarray, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto a = operands[0]; + auto loc = boxisarray.getLoc(); + auto c0 = genConstantOffset(loc, rewriter, 0); + auto c3 = genConstantOffset(loc, rewriter, 3); + auto ty = convertType(boxisarray.getType()); + auto p = genGEP(loc, unwrap(ty), rewriter, a, c0, c3); + auto ld = rewriter.create(loc, ty, p); + rewriter.replaceOpWithNewOp( + boxisarray, mlir::LLVM::ICmpPredicate::ne, ld, c0); + return success(); + } +}; + +struct BoxIsPtrOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::BoxIsPtrOp boxisptr, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto a = operands[0]; + auto loc = boxisptr.getLoc(); + auto ty = convertType(boxisptr.getType()); + auto ity = lowerTy().offsetType(); + auto c0 = genConstantOffset(loc, rewriter, 0); + auto c5 = genConstantOffset(loc, rewriter, 5); + auto p = rewriter.create(loc, ty, + mlir::ValueRange{a, c0, c5}); + auto ld = rewriter.create(loc, ty, p); + auto ab = genConstantOffset(loc, rewriter, 1); + auto bit = rewriter.create(loc, ity, ld, ab); + rewriter.replaceOpWithNewOp( + boxisptr, mlir::LLVM::ICmpPredicate::ne, bit, c0); + return success(); + } +}; + +struct BoxProcHostOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::BoxProcHostOp boxprochost, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto a = operands[0]; + auto ty = convertType(boxprochost.getType()); + auto ctx = boxprochost.getContext(); + auto c1 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(1), ctx); + rewriter.replaceOpWithNewOp(boxprochost, ty, 
a, + c1); + return success(); + } +}; + +struct BoxRankOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::BoxRankOp boxrank, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto a = operands[0]; + auto loc = boxrank.getLoc(); + auto ty = convertType(boxrank.getType()); + auto result = getRankFromBox(loc, a, ty, rewriter); + rewriter.replaceOp(boxrank, result); + return success(); + } +}; + +struct BoxTypeDescOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::BoxTypeDescOp boxtypedesc, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto a = operands[0]; + auto loc = boxtypedesc.getLoc(); + auto ty = convertType(boxtypedesc.getType()); + auto c0 = genConstantOffset(loc, rewriter, 0); + auto c4 = genConstantOffset(loc, rewriter, 4); + auto pty = mlir::LLVM::LLVMPointerType::get(unwrap(ty)); + auto p = rewriter.create(loc, pty, + mlir::ValueRange{a, c0, c4}); + auto ld = rewriter.create(loc, ty, p); + auto i8ptr = mlir::LLVM::LLVMPointerType::get( + mlir::IntegerType::get(boxtypedesc.getContext(), 8)); + rewriter.replaceOpWithNewOp(boxtypedesc, i8ptr, ld); + return success(); + } +}; + +struct StringLitOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::StringLitOp constop, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto ty = convertType(constop.getType()); + auto attr = constop.getValue(); + if (attr.isa()) { + rewriter.replaceOpWithNewOp(constop, ty, attr); + } else { + // convert the array attr to a dense elements attr + // LLVMIR dialect knows how to lower the latter to LLVM IR + auto arr = attr.cast(); + auto size = constop.getSize().cast().getInt(); + auto charTy = constop.getType().cast(); + auto bits = 
lowerTy().characterBitsize(charTy); + auto intTy = rewriter.getIntegerType(bits); + auto det = mlir::VectorType::get({size}, intTy); + // convert each character to a precise bitsize + SmallVector vec; + for (auto a : arr.getValue()) + vec.push_back(mlir::IntegerAttr::get( + intTy, a.cast().getValue().sextOrTrunc(bits))); + auto dea = mlir::DenseElementsAttr::get(det, vec); + rewriter.replaceOpWithNewOp(constop, ty, dea); + } + return success(); + } +}; + +/// direct call LLVM function +struct CallOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::CallOp call, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + SmallVector resultTys; + for (auto r : call.getResults()) + resultTys.push_back(convertType(r.getType())); + rewriter.replaceOpWithNewOp(call, resultTys, operands, + call.getAttrs()); + return success(); + } +}; + +/// Compare complex values +/// +/// Per 10.1, the only comparisons available are .EQ. (oeq) and .NE. (une). +/// +/// For completeness, all other comparison are done on the real component only. 
+struct CmpcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::CmpcOp cmp, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto ctxt = cmp.getContext(); + auto kind = cmp.lhs().getType().cast().getFKind(); + auto ty = convertType(fir::RealType::get(ctxt, kind)); + auto resTy = convertType(cmp.getType()); + auto loc = cmp.getLoc(); + auto pos0 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(0), ctxt); + SmallVector rp{ + rewriter.create(loc, ty, operands[0], pos0), + rewriter.create(loc, ty, operands[1], + pos0)}; + auto rcp = + rewriter.create(loc, resTy, rp, cmp.getAttrs()); + auto pos1 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(1), ctxt); + SmallVector ip{ + rewriter.create(loc, ty, operands[0], pos1), + rewriter.create(loc, ty, operands[1], + pos1)}; + auto icp = + rewriter.create(loc, resTy, ip, cmp.getAttrs()); + SmallVector cp{rcp, icp}; + switch (cmp.getPredicate()) { + case mlir::CmpFPredicate::OEQ: // .EQ. + rewriter.replaceOpWithNewOp(cmp, resTy, cp); + break; + case mlir::CmpFPredicate::UNE: // .NE. 
+ rewriter.replaceOpWithNewOp(cmp, resTy, cp); + break; + default: + rewriter.replaceOp(cmp, rcp.getResult()); + break; + } + return success(); + } +}; + +struct CmpfOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::CmpfOp cmp, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto type = convertType(cmp.getType()); + rewriter.replaceOpWithNewOp(cmp, type, operands, + cmp.getAttrs()); + return success(); + } +}; + +struct ConstcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::ConstcOp conc, OperandTy, + mlir::ConversionPatternRewriter &rewriter) const override { + auto loc = conc.getLoc(); + auto ctx = conc.getContext(); + auto ty = convertType(conc.getType()); + auto ct = conc.getType().cast(); + auto ety = lowerTy().convertComplexPartType(ct.getFKind()); + auto ri = mlir::FloatAttr::get(ety, getValue(conc.getReal())); + auto rp = rewriter.create(loc, ety, ri); + auto ii = mlir::FloatAttr::get(ety, getValue(conc.getImaginary())); + auto ip = rewriter.create(loc, ety, ii); + auto c0 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(0), ctx); + auto c1 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(1), ctx); + auto r = rewriter.create(loc, ty); + auto rr = rewriter.create(loc, ty, r, rp, c0); + rewriter.replaceOpWithNewOp(conc, ty, rr, ip, + c1); + return success(); + } + + inline APFloat getValue(mlir::Attribute attr) const { + return attr.cast().getValue(); + } +}; + +struct ConstfOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::ConstfOp conf, OperandTy, + mlir::ConversionPatternRewriter &rewriter) const override { + auto ty = convertType(conf.getType()); + auto val = conf.constantAttr(); + rewriter.replaceOpWithNewOp(conf, ty, val); + return success(); + } +}; + +static mlir::Type 
getComplexEleTy(mlir::Type complex) { + if (auto cc = complex.dyn_cast()) + return cc.getElementType(); + return complex.cast().getElementType(); +} + +/// convert value of from-type to value of to-type +struct ConvertOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + static bool isFloatingPointTy(mlir::Type ty) { + return ty.isa(); + } + + mlir::LogicalResult + matchAndRewrite(fir::ConvertOp convert, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto fromTy_ = convertType(convert.value().getType()); + auto fromTy = unwrap(fromTy_); + auto toTy_ = convertType(convert.res().getType()); + auto toTy = unwrap(toTy_); + auto &op0 = operands[0]; + if (fromTy == toTy) { + rewriter.replaceOp(convert, op0); + return success(); + } + auto loc = convert.getLoc(); + auto convertFpToFp = [&](mlir::Value val, unsigned fromBits, + unsigned toBits, mlir::Type toTy) -> mlir::Value { + if (fromBits == toBits) { + // TODO: Converting between two floating-point representations with the + // same bitwidth is not allowed for now. + mlir::emitError(loc, + "cannot implicitly convert between two floating-point " + "representations of the same bitwidth"); + return {}; + } + if (fromBits > toBits) + return rewriter.create(loc, toTy, val); + return rewriter.create(loc, toTy, val); + }; + if (fir::isa_complex(convert.value().getType()) && + fir::isa_complex(convert.res().getType())) { + // Special case: handle the conversion of a complex such that both the + // real and imaginary parts are converted together. 
+ auto zero = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(0), + convert.getContext()); + auto one = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(1), + convert.getContext()); + auto ty = convertType(getComplexEleTy(convert.value().getType())); + auto rp = rewriter.create(loc, ty, op0, zero); + auto ip = rewriter.create(loc, ty, op0, one); + auto nt = convertType(getComplexEleTy(convert.res().getType())); + auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(unwrap(ty)); + auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(unwrap(nt)); + auto rc = convertFpToFp(rp, fromBits, toBits, nt); + auto ic = convertFpToFp(ip, fromBits, toBits, nt); + auto un = rewriter.create(loc, toTy_); + auto i1 = + rewriter.create(loc, toTy_, un, rc, zero); + rewriter.replaceOpWithNewOp(convert, toTy_, i1, + ic, one); + return mlir::success(); + } + if (isFloatingPointTy(fromTy)) { + if (isFloatingPointTy(toTy)) { + auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy); + auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy); + auto v = convertFpToFp(op0, fromBits, toBits, toTy); + rewriter.replaceOp(convert, v); + return mlir::success(); + } + if (toTy.isa()) { + rewriter.replaceOpWithNewOp(convert, toTy, op0); + return mlir::success(); + } + } else if (fromTy.isa()) { + if (toTy.isa()) { + auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy); + auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy); + assert(fromBits != toBits); + if (fromBits > toBits) { + rewriter.replaceOpWithNewOp(convert, toTy, op0); + return mlir::success(); + } + rewriter.replaceOpWithNewOp(convert, toTy, op0); + return mlir::success(); + } + if (isFloatingPointTy(toTy)) { + rewriter.replaceOpWithNewOp(convert, toTy, op0); + return mlir::success(); + } + if (toTy.isa()) { + rewriter.replaceOpWithNewOp(convert, toTy, op0); + return mlir::success(); + } + } else if (fromTy.isa()) { + if (toTy.isa()) { + rewriter.replaceOpWithNewOp(convert, toTy, op0); + return 
mlir::success(); + } + if (toTy.isa()) { + rewriter.replaceOpWithNewOp(convert, toTy, op0); + return mlir::success(); + } + } + return emitError(loc) << "cannot convert " << fromTy_ << " to " << toTy_; + } +}; + +/// virtual call to a method in a dispatch table +struct DispatchOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::DispatchOp dispatch, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto ty = convertType(dispatch.getFunctionType()); + // get the table, lookup the method, fetch the func-ptr + rewriter.replaceOpWithNewOp(dispatch, ty, operands, + None); + TODO(""); + return success(); + } +}; + +/// dispatch table for a Fortran derived type +struct DispatchTableOpConversion + : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::DispatchTableOp dispTab, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + TODO(""); + return success(); + } +}; + +/// entry in a dispatch table; binds a method-name to a function +struct DTEntryOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::DTEntryOp dtEnt, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + TODO(""); + return success(); + } +}; + +/// Perform an extension or truncation as needed on an integer value. Lowering +/// to the specific target may involve some sign-extending or truncation of +/// values, particularly to fit them from abstract box types to the appropriate +/// reified structures. 
+static mlir::Value integerCast(mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type ty, mlir::Value val) { + auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits(ty); + auto fromSize = + mlir::LLVM::getPrimitiveTypeSizeInBits(val.getType().cast()); + if (toSize < fromSize) + return rewriter.create(loc, ty, val); + if (toSize > fromSize) + return rewriter.create(loc, ty, val); + return val; +} + +/// create a CHARACTER box +struct EmboxCharOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::EmboxCharOp emboxChar, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto a = operands[0]; + auto b1 = operands[1]; + auto loc = emboxChar.getLoc(); + auto ctx = emboxChar.getContext(); + auto ty = convertType(emboxChar.getType()); + auto c0 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(0), ctx); + auto c1 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(1), ctx); + auto un = rewriter.create(loc, ty); + auto lenTy = unwrap(ty).cast().getBody()[1]; + auto b = integerCast(loc, rewriter, lenTy, b1); + auto r = rewriter.create(loc, ty, un, a, c0); + rewriter.replaceOpWithNewOp(emboxChar, ty, r, b, + c1); + return success(); + } +}; + +// Common base class for lowering of embox to descriptor creation. 
+template +struct EmboxCommonConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + /// Generate an alloca of size `size` and cast it to type `toTy` + mlir::LLVM::AllocaOp + genAllocaWithType(mlir::Location loc, mlir::Type toTy, unsigned alignment, + mlir::ConversionPatternRewriter &rewriter) const { + auto thisPt = rewriter.saveInsertionPoint(); + auto *thisBlock = rewriter.getInsertionBlock(); + auto func = mlir::cast(thisBlock->getParentOp()); + rewriter.setInsertionPointToStart(&func.front()); + auto sz = this->genConstantOffset(loc, rewriter, 1); + auto al = rewriter.create(loc, toTy, sz, alignment); + rewriter.restoreInsertionPoint(thisPt); + return al; + } + + // Get the element type given an LLVM type that is of the form + // [llvm.ptr](llvm.array|llvm.struct)+ and the provided indexes. + static mlir::Type getBoxEleTy(mlir::Type type, + llvm::ArrayRef indexes) { + if (auto t = type.dyn_cast()) + type = t.getElementType(); + for (auto i : indexes) { + if (auto t = type.dyn_cast()) { + assert(!t.isOpaque() && i < t.getBody().size()); + type = t.getBody()[i]; + } else if (auto t = type.dyn_cast()) { + type = t.getElementType(); + } else if (auto t = type.dyn_cast()) { + type = t.getElementType(); + } + } + return type; + } + + int getCFIAttr(fir::BoxType boxTy) const { + auto eleTy = boxTy.getEleTy(); + if (eleTy.isa()) + return CFI_attribute_pointer; + if (eleTy.isa()) + return CFI_attribute_allocatable; + return CFI_attribute_other; + } + + bool isDerivedType(fir::BoxType boxTy) const { + return boxTy.getEleTy().isa(); + } + + // Get the element size and CFI type code of the boxed value. 
+ std::tuple getSizeAndTypeCode( + mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, + mlir::Type boxEleTy, mlir::ValueRange lenParams = {}) const { + auto doInteger = + [&](unsigned width) -> std::tuple { + int typeCode = fir::integerBitsToTypeCode(width); + return {this->genConstantOffset(loc, rewriter, width / 8), + this->genConstantOffset(loc, rewriter, typeCode)}; + }; + auto doLogical = + [&](unsigned width) -> std::tuple { + int typeCode = fir::logicalBitsToTypeCode(width); + return {this->genConstantOffset(loc, rewriter, width / 8), + this->genConstantOffset(loc, rewriter, typeCode)}; + }; + auto doFloat = [&](unsigned width) -> std::tuple { + int typeCode = fir::realBitsToTypeCode(width); + return {this->genConstantOffset(loc, rewriter, width / 8), + this->genConstantOffset(loc, rewriter, typeCode)}; + }; + auto doComplex = + [&](unsigned width) -> std::tuple { + auto typeCode = fir::complexBitsToTypeCode(width); + return {this->genConstantOffset(loc, rewriter, width / 8 * 2), + this->genConstantOffset(loc, rewriter, typeCode)}; + }; + auto doCharacter = + [&](unsigned width, + mlir::Value len) -> std::tuple { + auto typeCode = fir::characterBitsToTypeCode(width); + auto typeCodeVal = this->genConstantOffset(loc, rewriter, typeCode); + if (width == 8) + return {len, typeCodeVal}; + auto byteWidth = this->genConstantOffset(loc, rewriter, width / 8); + auto i64Ty = mlir::IntegerType::get(&this->lowerTy().getContext(), 64); + auto size = + rewriter.create(loc, i64Ty, byteWidth, len); + return {size, typeCodeVal}; + }; + auto getKindMap = [&]() -> fir::KindMapping & { + return this->lowerTy().getKindMap(); + }; + + if (auto eleTy = fir::dyn_cast_ptrEleTy(boxEleTy)) + boxEleTy = eleTy; + if (fir::isa_integer(boxEleTy)) { + if (auto ty = boxEleTy.dyn_cast()) + return doInteger(ty.getWidth()); + auto ty = boxEleTy.cast(); + return doInteger(getKindMap().getIntegerBitsize(ty.getFKind())); + } + if (fir::isa_real(boxEleTy)) { + if (auto ty = 
boxEleTy.dyn_cast()) + return doFloat(ty.getWidth()); + auto ty = boxEleTy.cast(); + return doFloat(getKindMap().getRealBitsize(ty.getFKind())); + } + if (fir::isa_complex(boxEleTy)) { + if (auto ty = boxEleTy.dyn_cast()) + return doComplex( + ty.getElementType().cast().getWidth()); + auto ty = boxEleTy.cast(); + return doComplex(getKindMap().getRealBitsize(ty.getFKind())); + } + if (auto ty = boxEleTy.dyn_cast()) { + auto charWidth = getKindMap().getCharacterBitsize(ty.getFKind()); + if (ty.getLen() != fir::CharacterType::unknownLen()) { + auto len = this->genConstantOffset(loc, rewriter, ty.getLen()); + return doCharacter(charWidth, len); + } + assert(!lenParams.empty()); + return doCharacter(charWidth, lenParams[0]); + } + if (auto ty = boxEleTy.dyn_cast()) + return doLogical(getKindMap().getLogicalBitsize(ty.getFKind())); + if (auto seqTy = boxEleTy.dyn_cast()) { + return getSizeAndTypeCode(loc, rewriter, seqTy.getEleTy(), lenParams); + } + if (boxEleTy.isa()) { + TODO(""); + } + if (fir::isa_ref_type(boxEleTy)) { + // FIXME: use the target pointer size rather than sizeof(void*) + return {this->genConstantOffset(loc, rewriter, sizeof(void *)), + this->genConstantOffset(loc, rewriter, CFI_type_cptr)}; + } + // fail: unhandled case + TODO(""); + } + + /// Basic pattern to write a field in the descriptor + mlir::Value insertField(mlir::ConversionPatternRewriter &rewriter, + mlir::Location loc, mlir::Value dest, + llvm::ArrayRef fldIndexes, + mlir::Value value, bool bitcast = false) const { + auto boxTy = this->unwrap(dest.getType()); + auto fldTy = this->getBoxEleTy(boxTy, fldIndexes); + if (bitcast) + value = rewriter.create(loc, fldTy, value); + else + value = integerCast(loc, rewriter, fldTy, value); + llvm::SmallVector attrs; + for (auto i : fldIndexes) + attrs.push_back(rewriter.getI32IntegerAttr(i)); + auto indexesAttr = mlir::ArrayAttr::get(attrs, rewriter.getContext()); + return rewriter.create(loc, boxTy, dest, value, + indexesAttr); + } + + template + 
std::tuple + consDescriptorPrefix(BOX box, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter, unsigned rank, + unsigned dropFront) const { + auto loc = box.getLoc(); + auto boxTy = box.getType().template dyn_cast(); + auto convTy = this->lowerTy().convertBoxType(boxTy, rank); + auto llvmBoxPtrTy = + this->unwrap(convTy).template cast(); + auto llvmBoxTy = llvmBoxPtrTy.getElementType(); + mlir::Value dest = rewriter.create(loc, llvmBoxTy); + + // Write each of the fields with the appropriate values + dest = insertField(rewriter, loc, dest, {0}, operands[0], /*bitCast=*/true); + auto [eleSize, cfiTy] = getSizeAndTypeCode(loc, rewriter, boxTy.getEleTy(), + operands.drop_front(dropFront)); + dest = insertField(rewriter, loc, dest, {1}, eleSize); + dest = insertField(rewriter, loc, dest, {2}, + this->genConstantOffset(loc, rewriter, CFI_VERSION)); + dest = insertField(rewriter, loc, dest, {3}, + this->genConstantOffset(loc, rewriter, rank)); + dest = insertField(rewriter, loc, dest, {4}, cfiTy); + dest = + insertField(rewriter, loc, dest, {5}, + this->genConstantOffset(loc, rewriter, getCFIAttr(boxTy))); + dest = insertField( + rewriter, loc, dest, {6}, + this->genConstantOffset(loc, rewriter, isDerivedType(boxTy))); + return {boxTy, dest, eleSize}; + } + + /// If the embox is not in a globalOp body, allocate storage for the box and + /// store the value inside. Return the input value otherwise. + mlir::Value + placeInMemoryIfNotGlobalInit(mlir::ConversionPatternRewriter &rewriter, + mlir::Location loc, mlir::Value boxValue) const { + auto *thisBlock = rewriter.getInsertionBlock(); + if (thisBlock && mlir::isa(thisBlock->getParentOp())) + return boxValue; + auto boxPtrTy = + mlir::LLVM::LLVMPointerType::get(this->unwrap(boxValue.getType())); + auto alloca = genAllocaWithType(loc, boxPtrTy, defaultAlign, rewriter); + rewriter.create(loc, boxValue, alloca); + return alloca; + } +}; + +/// Create a generic box on a memory reference. 
This conversions lowers the +/// abstract box to the appropriate, initialized descriptor. +struct EmboxOpConversion : public EmboxCommonConversion { + using EmboxCommonConversion::EmboxCommonConversion; + + mlir::LogicalResult + matchAndRewrite(fir::EmboxOp embox, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + // There should be no dims on this embox op + assert(!embox.getShape()); + auto [boxTy, dest, eleSize] = consDescriptorPrefix( + embox, operands, rewriter, /*rank=*/0, /*dropFront=*/1); + if (isDerivedType(boxTy)) + TODO("derived type"); + auto result = placeInMemoryIfNotGlobalInit(rewriter, embox.getLoc(), dest); + rewriter.replaceOp(embox, result); + return success(); + } +}; + +/// create a generic box on a memory reference +struct XEmboxOpConversion : public EmboxCommonConversion { + using EmboxCommonConversion::EmboxCommonConversion; + + mlir::LogicalResult + matchAndRewrite(fir::cg::XEmboxOp xbox, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto rank = xbox.getRank(); + auto [boxTy, dest, eleSize] = consDescriptorPrefix( + xbox, operands, rewriter, rank, xbox.lenParamOffset()); + // Generate the triples in the dims field of the descriptor + auto i64Ty = mlir::IntegerType::get(xbox.getContext(), 64); + assert(xbox.shape().size() && "must have a shape"); + unsigned shapeOff = 1; + bool hasShift = xbox.shift().size(); + unsigned shiftOff = shapeOff + xbox.shape().size(); + bool hasSlice = xbox.slice().size(); + unsigned sliceOff = shiftOff + xbox.shift().size(); + auto loc = xbox.getLoc(); + mlir::Value zero = genConstantIndex(loc, i64Ty, rewriter, 0); + mlir::Value one = genConstantIndex(loc, i64Ty, rewriter, 1); + mlir::Value prevDim = integerCast(loc, rewriter, i64Ty, eleSize); + auto eleTy = boxTy.getEleTy(); + for (unsigned d = 0; d < rank; ++d) { + // store lower bound (normally 0) + mlir::Value lb = zero; + if (eleTy.isa() || eleTy.isa() || + hasSlice) { + lb = one; + if 
(hasShift) + lb = operands[shiftOff]; + if (hasSlice) + lb = rewriter.create(loc, i64Ty, lb, + operands[sliceOff]); + } + dest = insertField(rewriter, loc, dest, {7, d, 0}, lb); + + // store extent + mlir::Value extent = operands[shapeOff]; + mlir::Value outerExtent = extent; + if (hasSlice) { + extent = rewriter.create( + loc, i64Ty, operands[sliceOff + 1], operands[sliceOff]); + extent = rewriter.create(loc, i64Ty, extent, + operands[sliceOff + 2]); + extent = rewriter.create(loc, i64Ty, extent, + operands[sliceOff + 2]); + } + dest = insertField(rewriter, loc, dest, {7, d, 1}, extent); + + // store step (scaled by shaped extent) + mlir::Value step = prevDim; + if (hasSlice) + step = rewriter.create(loc, i64Ty, step, + operands[sliceOff + 2]); + dest = insertField(rewriter, loc, dest, {7, d, 2}, step); + // compute the stride for the next natural dimension + prevDim = + rewriter.create(loc, i64Ty, prevDim, outerExtent); + + // increment iterators + shapeOff++; + if (hasShift) + shiftOff++; + if (hasSlice) + sliceOff += 3; + } + if (isDerivedType(boxTy)) + TODO("derived type"); + + auto result = placeInMemoryIfNotGlobalInit(rewriter, xbox.getLoc(), dest); + rewriter.replaceOp(xbox, result); + return success(); + } +}; + +/// create a procedure pointer box +struct EmboxProcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::EmboxProcOp emboxproc, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto loc = emboxproc.getLoc(); + auto ctx = emboxproc.getContext(); + auto ty = convertType(emboxproc.getType()); + auto c0 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(0), ctx); + auto c1 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(1), ctx); + auto un = rewriter.create(loc, ty); + auto r = rewriter.create(loc, ty, un, + operands[0], c0); + rewriter.replaceOpWithNewOp(emboxproc, ty, r, + operands[1], c1); + return success(); + } +}; + +// Code shared 
between insert_value and extract_value Ops. +struct ValueOpCommon { + static mlir::Attribute getValue(mlir::Value value) { + auto defOp = value.getDefiningOp(); + if (auto v = dyn_cast(defOp)) + return v.value(); + if (auto v = dyn_cast(defOp)) + return v.value(); + llvm_unreachable("must be a constant op"); + return {}; + } + + // Translate the arguments pertaining to any multidimensional array to + // row-major order for LLVM-IR. + static void toRowMajor(SmallVectorImpl &attrs, + mlir::Type ty) { + assert(ty && "type is null"); + const auto end = attrs.size(); + for (std::remove_const_t i = 0; i < end; ++i) { + if (auto seq = ty.dyn_cast()) { + const auto dim = getDimension(seq); + if (dim > 1) { + auto ub = std::min(i + dim, end); + std::reverse(attrs.begin() + i, attrs.begin() + ub); + i += dim - 1; + } + ty = getArrayElementType(seq); + } else if (auto st = ty.dyn_cast()) { + ty = st.getBody()[attrs[i].cast().getInt()]; + } else { + llvm_unreachable("index into invalid type"); + } + } + } + +private: + static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) { + unsigned result = 1; + for (auto eleTy = ty.getElementType().dyn_cast(); + eleTy; + eleTy = eleTy.getElementType().dyn_cast()) + ++result; + return result; + } + + static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) { + auto eleTy = ty.getElementType(); + while (auto arrTy = eleTy.dyn_cast()) + eleTy = arrTy.getElementType(); + return eleTy; + } +}; + +/// Extract a subobject value from an ssa-value of aggregate type +struct ExtractValueOpConversion + : public FIROpAndTypeConversion, + public ValueOpCommon { + using FIROpAndTypeConversion::FIROpAndTypeConversion; + + mlir::LogicalResult + doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + if (!fir::allConstants(operands.drop_front(1))) + llvm_unreachable("fir.extract_value incorrectly formed"); + // since all indices are constants use LLVM's 
extractvalue instruction + SmallVector attrs; + for (std::size_t i = 1, end{operands.size()}; i < end; ++i) + attrs.push_back(getValue(operands[i])); + toRowMajor(attrs, lowerTy().unwrap(operands[0].getType())); + auto position = mlir::ArrayAttr::get(attrs, extractVal.getContext()); + rewriter.replaceOpWithNewOp( + extractVal, ty, operands[0], position); + return success(); + } +}; + +/// InsertValue is the generalized instruction for the composition of new +/// aggregate type values. +struct InsertValueOpConversion + : public FIROpAndTypeConversion, + public ValueOpCommon { + using FIROpAndTypeConversion::FIROpAndTypeConversion; + + mlir::LogicalResult + doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + assert(fir::allConstants(operands.drop_front(2))); + // since all indices must be constants use LLVM's insertvalue instruction + SmallVector attrs; + for (std::size_t i = 2, end{operands.size()}; i < end; ++i) + attrs.push_back(getValue(operands[i])); + toRowMajor(attrs, lowerTy().unwrap(operands[0].getType())); + auto position = mlir::ArrayAttr::get(attrs, insertVal.getContext()); + rewriter.replaceOpWithNewOp( + insertVal, ty, operands[0], operands[1], position); + return success(); + } +}; + +/// InsertOnRange inserts a value into a sequence over a range of offsets. +struct InsertOnRangeOpConversion + : public FIROpAndTypeConversion, + public ValueOpCommon { + using FIROpAndTypeConversion::FIROpAndTypeConversion; + + // Increments an array of subscripts in a row major fasion. 
+ void incrementSubscripts(const SmallVector &dims, + SmallVector &subscripts) const { + for (size_t i = dims.size(); i > 0; --i) { + if (++subscripts[i - 1] < dims[i - 1]) { + return; + } + subscripts[i - 1] = 0; + } + } + + mlir::LogicalResult + doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + assert(fir::allConstants(operands.drop_front(2))); + + llvm::SmallVector lowerBound; + llvm::SmallVector upperBound; + llvm::SmallVector dims; + auto type = operands[0].getType(); + + // Iterativly extract the array dimensions from it's type. + while (auto t = type.dyn_cast()) { + dims.push_back(t.getNumElements()); + type = t.getElementType(); + } + + // Unzip the upper and lower bound subscripts. + for (std::size_t i = 2; i + 1 < operands.size(); i += 2) { + lowerBound.push_back(ExtractValueOpConversion::getValue(operands[i])); + upperBound.push_back(ExtractValueOpConversion::getValue(operands[i + 1])); + } + + llvm::SmallVector lBounds; + llvm::SmallVector uBounds; + + // Extract the integer value from the attribute bounds and convert to row + // major format. + for (size_t i = lowerBound.size(); i > 0; --i) { + lBounds.push_back(lowerBound[i - 1].cast().getInt()); + uBounds.push_back(upperBound[i - 1].cast().getInt()); + } + + auto subscripts(lBounds); + auto loc = range.getLoc(); + mlir::Value lastOp = operands[0]; + mlir::Value insertVal = operands[1]; + + while (subscripts != uBounds) { + // Convert uint64_t's to Attribute's. + llvm::SmallVector subscriptAttrs; + for (const auto &subscript : subscripts) + subscriptAttrs.push_back( + IntegerAttr::get(rewriter.getI64Type(), subscript)); + mlir::ArrayRef arrayRef(subscriptAttrs); + lastOp = rewriter.create( + loc, ty, lastOp, insertVal, + ArrayAttr::get(arrayRef, range.getContext())); + + incrementSubscripts(dims, subscripts); + } + + // Convert uint64_t's to Attribute's. 
+ llvm::SmallVector subscriptAttrs; + for (const auto &subscript : subscripts) + subscriptAttrs.push_back( + IntegerAttr::get(rewriter.getI64Type(), subscript)); + mlir::ArrayRef arrayRef(subscriptAttrs); + + rewriter.replaceOpWithNewOp( + range, ty, lastOp, insertVal, + ArrayAttr::get(arrayRef, range.getContext())); + + return success(); + } +}; + +/// XArrayCoor is the address arithmetic on a dynamically shaped, etc. array. +/// (See the static restriction on coordinate_of.) array_coor determines the +/// coordinate (location) of a specific element. +struct XArrayCoorOpConversion + : public FIROpAndTypeConversion { + using FIROpAndTypeConversion::FIROpAndTypeConversion; + + mlir::LogicalResult + doRewrite(fir::cg::XArrayCoorOp coor, mlir::Type ty, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto loc = coor.getLoc(); + auto rank = coor.getRank(); + assert(coor.indices().size() == rank); + assert(coor.shape().size() == 0 || coor.shape().size() == rank); + assert(coor.shift().size() == 0 || coor.shift().size() == rank); + assert(coor.slice().size() == 0 || coor.slice().size() == 3 * rank); + auto indexOps = coor.indices().begin(); + auto shapeOps = coor.shape().begin(); + auto shiftOps = coor.shift().begin(); + auto sliceOps = coor.slice().begin(); + auto idxTy = lowerTy().indexType(); + mlir::Value base; + if (coor.subcomponent().empty()) { + // Cast the base address to a pointer to T + base = rewriter.create(loc, ty, operands[0]); + } else { + // operands[0] must have a pointer type. For subcomponent slicing, we + // want to cast away the array type and have a plain struct type. 
+ auto ty0 = unwrap(operands[0].getType()); + auto ptrTy = ty0.dyn_cast(); + assert(ptrTy && "expected pointer type"); + auto eleTy = ptrTy.getElementType(); + if (auto arrTy = eleTy.dyn_cast()) + eleTy = arrTy.getElementType(); + auto newTy = mlir::LLVM::LLVMPointerType::get(eleTy); + base = rewriter.create(loc, newTy, operands[0]); + } + mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1); + auto prevExt = one; + mlir::Value off = genConstantIndex(loc, idxTy, rewriter, 0); + const bool isShifted = coor.shift().size() != 0; + const bool isSliced = coor.slice().size() != 0; + for (unsigned i = 0; i < rank; + ++i, ++indexOps, ++shapeOps, ++shiftOps, ++sliceOps) { + auto index = asType(loc, rewriter, idxTy, *indexOps); + auto nextExt = asType(loc, rewriter, idxTy, *shapeOps); + mlir::Value lb = one; + if (isShifted) + lb = asType(loc, rewriter, idxTy, *shiftOps); + mlir::Value step = one; + if (isSliced) + step = asType(loc, rewriter, idxTy, *(sliceOps + 2)); + auto idx = rewriter.create(loc, idxTy, index, lb); + mlir::Value diff = + rewriter.create(loc, idxTy, idx, step); + if (isSliced) { + auto sliceLb = asType(loc, rewriter, idxTy, *sliceOps); + auto adj = rewriter.create(loc, idxTy, sliceLb, lb); + diff = rewriter.create(loc, idxTy, diff, adj); + } + auto sc = rewriter.create(loc, idxTy, diff, prevExt); + off = rewriter.create(loc, idxTy, sc, off); + prevExt = + rewriter.create(loc, idxTy, prevExt, nextExt); + } + SmallVector args{base, off}; + args.append(coor.subcomponent().begin(), coor.subcomponent().end()); + rewriter.replaceOpWithNewOp(coor, ty, args); + return success(); + } + + mlir::Value asType(mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, mlir::Type toTy, + mlir::Value val) const { + auto fromTy = unwrap(convertType(val.getType())); + auto fTy = fromTy.dyn_cast(); + auto tTy = toTy.dyn_cast(); + assert(fTy && tTy && "must be IntegerType"); + auto fbw = mlir::LLVM::getPrimitiveTypeSizeInBits(fTy); + auto tbw = 
mlir::LLVM::getPrimitiveTypeSizeInBits(tTy); + if (fbw < tbw) + return rewriter.create(loc, toTy, val); + if (fbw > tbw) + return rewriter.create(loc, toTy, val); + return val; + } +}; + +/// Convert to (memory) reference to a reference to a subobject. +/// The coordinate_of op is a Swiss army knife operation that can be used on +/// (memory) references to records, arrays, complex, etc. as well as boxes. +/// With unboxed arrays, there is the restriction that the array have a static +/// shape in all but the last column. +struct CoordinateOpConversion + : public FIROpAndTypeConversion { + using FIROpAndTypeConversion::FIROpAndTypeConversion; + + mlir::LogicalResult + doRewrite(fir::CoordinateOp coor, mlir::Type ty, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto loc = coor.getLoc(); + auto c0 = genConstantIndex(loc, lowerTy().indexType(), rewriter, 0); + mlir::Value base = operands[0]; + auto firTy = coor.getBaseType(); + mlir::Type cpnTy = getReferenceEleTy(firTy); + bool columnIsDeferred = false; + bool hasSubdimension = hasSubDimensions(cpnTy); + + // if argument 0 is complex, get the real or imaginary part + if (fir::isa_complex(cpnTy)) { + SmallVector offs = {c0}; + offs.append(std::next(operands.begin()), operands.end()); + mlir::Value gep = genGEP(loc, unwrap(ty), rewriter, base, offs); + rewriter.replaceOp(coor, gep); + return success(); + } + + // if argument 0 is boxed, get the base pointer from the box + if (auto boxTy = firTy.dyn_cast()) { + + // Special case: + // %lenp = len_param_index foo, type + // %addr = coordinate_of %box, %lenp + if (coor.getNumOperands() == 2) { + auto coorPtr = *coor.coor().begin(); + auto s = coorPtr.getDefiningOp(); + if (s && isa(s)) { + mlir::Value lenParam = operands[1]; // byte offset + auto bc = + rewriter.create(loc, voidPtrTy(), base); + auto gep = genGEP(loc, unwrap(ty), rewriter, bc, lenParam); + rewriter.replaceOpWithNewOp(coor, unwrap(ty), + gep); + return success(); + } + 
} + + auto c0_ = genConstantOffset(loc, rewriter, 0); + auto pty = mlir::LLVM::LLVMPointerType::get( + unwrap(convertType(boxTy.getEleTy()))); + // Extract the boxed reference + auto p = genGEP(loc, pty, rewriter, base, c0, c0_); + // base = box->data : ptr + base = rewriter.create(loc, pty, p); + + // If the base has dynamic shape, it has to be boxed as the dimension + // information is saved in the box. + if (fir::LLVMTypeConverter::dynamicallySized(cpnTy)) { + TODO(""); + return success(); + } + } else { + if (fir::LLVMTypeConverter::dynamicallySized(cpnTy)) + return mlir::emitError(loc, "bare reference to unknown shape"); + } + if (!hasSubdimension) + columnIsDeferred = true; + + if (!validCoordinate(cpnTy, operands.drop_front(1))) + return mlir::emitError(loc, "coordinate has incorrect dimension"); + + // if arrays has known shape + const bool hasKnownShape = + arraysHaveKnownShape(cpnTy, operands.drop_front(1)); + + // If only the column is `?`, then we can simply place the column value in + // the 0-th GEP position. 
+ if (auto arrTy = cpnTy.dyn_cast()) { + if (!hasKnownShape) { + const auto sz = arrTy.getDimension(); + if (arraysHaveKnownShape(arrTy.getEleTy(), + operands.drop_front(1 + sz))) { + auto shape = arrTy.getShape(); + bool allConst = true; + for (std::remove_const_t i = 0; i < sz - 1; ++i) + if (shape[i] < 0) { + allConst = false; + break; + } + if (allConst) + columnIsDeferred = true; + } + } + } + + if (hasKnownShape || columnIsDeferred) { + SmallVector offs; + if (hasKnownShape && hasSubdimension) + offs.push_back(c0); + const auto sz = operands.size(); + llvm::Optional dims; + SmallVector arrIdx; + for (std::remove_const_t i = 1; i < sz; ++i) { + auto nxtOpnd = operands[i]; + + if (!cpnTy) + return mlir::emitError(loc, "invalid coordinate/check failed"); + + // check if the i-th coordinate relates to an array + if (dims.hasValue()) { + arrIdx.push_back(nxtOpnd); + int dimsLeft = *dims; + if (dimsLeft > 1) { + dims = dimsLeft - 1; + continue; + } + cpnTy = cpnTy.cast().getEleTy(); + // append array range in reverse (FIR arrays are column-major) + offs.append(arrIdx.rbegin(), arrIdx.rend()); + arrIdx.clear(); + dims.reset(); + continue; + } else if (auto arrTy = cpnTy.dyn_cast()) { + int d = arrTy.getDimension() - 1; + if (d > 0) { + dims = d; + arrIdx.push_back(nxtOpnd); + continue; + } + cpnTy = cpnTy.cast().getEleTy(); + offs.push_back(nxtOpnd); + continue; + } + + // check if the i-th coordinate relates to a field + if (auto strTy = cpnTy.dyn_cast()) { + cpnTy = strTy.getType(getIntValue(nxtOpnd)); + } else if (auto strTy = cpnTy.dyn_cast()) { + cpnTy = strTy.getType(getIntValue(nxtOpnd)); + } else { + cpnTy = nullptr; + } + offs.push_back(nxtOpnd); + } + if (dims.hasValue()) + offs.append(arrIdx.rbegin(), arrIdx.rend()); + mlir::Value retval = genGEP(loc, unwrap(ty), rewriter, base, offs); + rewriter.replaceOp(coor, retval); + return success(); + } + + // Taking a coordinate of an array with deferred shape. In this case, the + // array must be boxed. 
We need to retrieve the array triples from the box. + // + // Given: + // + // %box ... : box> + // %addr = coordinate_of %box, %0, %1, %2 + // + // We want to lower this into an llvm GEP as: + // + // %i1 = (%0 - %box.dims(0).lo) * %box.dims(0).str + // %i2 = (%1 - %box.dims(1).lo) * %box.dims(1).str * %box.dims(0).ext + // %scale_by = %box.dims(1).ext * %box.dims(0).ext + // %i3 = (%2 - %box.dims(2).lo) * %box.dims(2).str * %scale_by + // %offset = %i3 + %i2 + %i1 + // %addr = getelementptr i32, i32* %box.ref, i64 %offset + // + // Section 18.5.3 para 3 specifies when and how to interpret the `lo` + // value(s) of the triple. The implication is that they must always be + // zero for `coordinate_of`. This is because we do not use `coordinate_of` + // to compute the offset into a `box` or `box`. The coordinate + // is pointer arithmetic. Pointers along a path must be explicitly + // dereferenced with a `load`. + + if (!firTy.isa()) + return mlir::emitError(loc, "base must have box type"); + if (!cpnTy.isa()) + return mlir::emitError(loc, "base element must be reference to array"); + auto baseTy = cpnTy.cast(); + const auto baseDim = baseTy.getDimension(); + if (!arraysHaveKnownShape(baseTy.getEleTy(), + operands.drop_front(1 + baseDim))) + return mlir::emitError(loc, "base element has deferred shapes"); + + // Generate offset computation. + TODO(""); + + return failure(); + } + + bool hasSubDimensions(mlir::Type type) const { + return type.isa() || type.isa() || + type.isa(); + } + + /// Walk the abstract memory layout and determine if the path traverses any + /// array types with unknown shape. Return true iff all the array types have a + /// constant shape along the path. 
+ bool arraysHaveKnownShape(mlir::Type type, OperandTy coors) const { + const auto sz = coors.size(); + std::remove_const_t i = 0; + for (; i < sz; ++i) { + auto nxtOpnd = coors[i]; + if (auto arrTy = type.dyn_cast()) { + if (fir::LLVMTypeConverter::unknownShape(arrTy.getShape())) + return false; + i += arrTy.getDimension() - 1; + type = arrTy.getEleTy(); + } else if (auto strTy = type.dyn_cast()) { + type = strTy.getType(getIntValue(nxtOpnd)); + } else if (auto strTy = type.dyn_cast()) { + type = strTy.getType(getIntValue(nxtOpnd)); + } else { + return true; + } + } + return true; + } + + bool validCoordinate(mlir::Type type, OperandTy coors) const { + const auto sz = coors.size(); + std::remove_const_t i = 0; + bool subEle = false; + bool ptrEle = false; + for (; i < sz; ++i) { + auto nxtOpnd = coors[i]; + if (auto arrTy = type.dyn_cast()) { + subEle = true; + i += arrTy.getDimension() - 1; + type = arrTy.getEleTy(); + } else if (auto strTy = type.dyn_cast()) { + subEle = true; + type = strTy.getType(getIntValue(nxtOpnd)); + } else if (auto strTy = type.dyn_cast()) { + subEle = true; + type = strTy.getType(getIntValue(nxtOpnd)); + } else { + ptrEle = true; + } + } + if (ptrEle) + return (!subEle) && (sz == 1); + return subEle && (i >= sz); + } + + /// Returns the element type of the reference `refTy`. 
+ static mlir::Type getReferenceEleTy(mlir::Type refTy) { + if (auto boxTy = refTy.dyn_cast()) + return boxTy.getEleTy(); + if (auto ptrTy = refTy.dyn_cast()) + return ptrTy.getEleTy(); + if (auto ptrTy = refTy.dyn_cast()) + return ptrTy.getEleTy(); + if (auto ptrTy = refTy.dyn_cast()) + return ptrTy.getEleTy(); + llvm_unreachable("not a reference type"); + } + + /// return true if all `Value`s in `operands` are not `FieldIndexOp`s + static bool noFieldIndexOps(mlir::Operation::operand_range operands) { + for (auto opnd : operands) { + if (auto defop = opnd.getDefiningOp()) + if (dyn_cast(defop)) + return false; + } + return true; + } + + SmallVector arguments(OperandTy vec, unsigned s, + unsigned e) const { + return {vec.begin() + s, vec.begin() + e}; + } + + int64_t getIntValue(mlir::Value val) const { + if (val) + if (auto defop = val.getDefiningOp()) + if (auto constOp = dyn_cast(defop)) + return constOp.getValue(); + llvm_unreachable("must be a constant"); + } +}; + +/// convert a field index to a runtime function that computes the byte offset +/// of the dynamic field +struct FieldIndexOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + // NB: most field references should be resolved by this point + mlir::LogicalResult + matchAndRewrite(fir::FieldIndexOp field, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + // call the compiler generated function to determine the byte offset of + // the field at runtime + auto symAttr = + mlir::SymbolRefAttr::get(methodName(field), field.getContext()); + SmallVector attrs{ + rewriter.getNamedAttr("callee", symAttr)}; + auto ty = lowerTy().offsetType(); + rewriter.replaceOpWithNewOp(field, ty, operands, attrs); + return success(); + } + + // constructing the name of the method + inline static std::string methodName(fir::FieldIndexOp field) { + auto fldName = field.field_id(); + auto type = field.on_type().cast(); + // note: using std::string to dodge a bug 
in g++ 7.4.0 + std::string tyName = type.getName().str(); + Twine methodName = "_QQOFFSETOF_" + tyName + "_" + fldName; + return methodName.str(); + } +}; + +struct LenParamIndexOpConversion + : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + // FIXME: this should be specialized by the runtime target + mlir::LogicalResult + matchAndRewrite(fir::LenParamIndexOp lenp, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto ity = lowerTy().indexType(); + auto onty = lenp.getOnType(); + // size of portable descriptor + const unsigned boxsize = 24; // FIXME + unsigned offset = boxsize; + // add the size of the rows of triples + if (auto arr = onty.dyn_cast()) { + offset += 3 * arr.getDimension(); + } + // advance over some addendum fields + const unsigned addendumOffset{sizeof(void *) + sizeof(uint64_t)}; + offset += addendumOffset; + // add the offset into the LENs + offset += 0; // FIXME + auto attr = rewriter.getI64IntegerAttr(offset); + rewriter.replaceOpWithNewOp(lenp, ity, attr); + return success(); + } +}; + +/// lower the fir.end operation to a null (erasing it) +struct FirEndOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::FirEndOp op, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, {}); + return success(); + } +}; + +/// lower a type descriptor to a global constant +struct GenTypeDescOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::GenTypeDescOp gentypedesc, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto loc = gentypedesc.getLoc(); + auto inTy = gentypedesc.getInType(); + auto name = consName(rewriter, inTy); + auto gty = unwrap(convertType(inTy)); + auto pty = mlir::LLVM::LLVMPointerType::get(gty); + auto module = gentypedesc->getParentOfType(); + 
createGlobal(loc, module, name, gty, rewriter); + rewriter.replaceOpWithNewOp(gentypedesc, pty, + name); + return success(); + } + + std::string consName(mlir::ConversionPatternRewriter &rewriter, + mlir::Type type) const { + if (auto d = type.dyn_cast()) { + auto name = d.getName(); + auto pair = fir::NameUniquer::deconstruct(name); + return lowerTy().getUniquer().doTypeDescriptor( + pair.second.modules, pair.second.host, pair.second.name, + pair.second.kinds); + } + llvm_unreachable("no name found"); + } +}; + +struct GlobalLenOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::GlobalLenOp globalLen, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + TODO(""); + return success(); + } +}; + +struct HasValueOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::HasValueOp op, OperandTy operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, operands); + return success(); + } +}; + +struct GlobalOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::GlobalOp global, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto tyAttr = unwrap(convertType(global.getType())); + if (global.getType().isa()) + tyAttr = tyAttr.cast().getElementType(); + auto loc = global.getLoc(); + mlir::Attribute initAttr{}; + if (global.initVal()) + initAttr = global.initVal().getValue(); + auto linkage = convertLinkage(global.linkName()); + auto isConst = global.constant().hasValue(); + auto g = rewriter.create( + loc, tyAttr, isConst, linkage, global.sym_name(), initAttr); + auto &gr = g.getInitializerRegion(); + rewriter.inlineRegionBefore(global.region(), gr, gr.end()); + rewriter.eraseOp(global); + return success(); + } + + mlir::LLVM::Linkage 
convertLinkage(Optional optLinkage) const { + if (optLinkage.hasValue()) { + auto name = optLinkage.getValue(); + if (name == "internal") + return mlir::LLVM::Linkage::Internal; + if (name == "linkonce") + return mlir::LLVM::Linkage::Linkonce; + if (name == "common") + return mlir::LLVM::Linkage::Common; + if (name == "weak") + return mlir::LLVM::Linkage::Weak; + } + return mlir::LLVM::Linkage::External; + } +}; + +// convert to LLVM IR dialect `load` +struct LoadOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::LoadOp load, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + // fir.box is a special case because it is considered as an ssa values in + // fir, but it is lowered as a pointer to a descriptor. So fir.ref + // and fir.box end up being the same llvm types and loading a fir.ref + // is actually a no op in LLVM. + if (load.getType().isa()) { + rewriter.replaceOp(load, operands[0]); + } else { + auto ty = convertType(load.getType()); + auto at = load.getAttrs(); + rewriter.replaceOpWithNewOp(load, ty, operands, at); + } + return success(); + } +}; + +// FIXME: how do we want to enforce this in LLVM-IR? Can we manipulate the fast +// math flags? 
+struct NoReassocOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::NoReassocOp noreassoc, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + noreassoc.replaceAllUsesWith(operands[0]); + rewriter.eraseOp(noreassoc); + return success(); + } +}; + +void genCondBrOp(mlir::Location loc, mlir::Value cmp, mlir::Block *dest, + Optional destOps, + mlir::ConversionPatternRewriter &rewriter, + mlir::Block *newBlock) { + if (destOps.hasValue()) + rewriter.create(loc, cmp, dest, destOps.getValue(), + newBlock, mlir::ValueRange()); + else + rewriter.create(loc, cmp, dest, newBlock); +} + +template +void genBrOp(A caseOp, mlir::Block *dest, llvm::Optional destOps, + mlir::ConversionPatternRewriter &rewriter) { + if (destOps.hasValue()) + rewriter.replaceOpWithNewOp(caseOp, destOps.getValue(), + dest); + else + rewriter.replaceOpWithNewOp(caseOp, llvm::None, dest); +} + +void genCaseLadderStep(mlir::Location loc, mlir::Value cmp, mlir::Block *dest, + Optional destOps, + mlir::ConversionPatternRewriter &rewriter) { + auto *thisBlock = rewriter.getInsertionBlock(); + auto *newBlock = createBlock(rewriter, dest); + rewriter.setInsertionPointToEnd(thisBlock); + genCondBrOp(loc, cmp, dest, destOps, rewriter, newBlock); + rewriter.setInsertionPointToEnd(newBlock); +} + +/// Conversion of `fir.select_case` +/// +/// TODO: lowering of CHARACTER type cases +struct SelectCaseOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SelectCaseOp caseOp, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + const auto conds = caseOp.getNumConditions(); + auto attrName = fir::SelectCaseOp::getCasesAttr(); + auto cases = caseOp->getAttrOfType(attrName).getValue(); + // Type can be CHARACTER, INTEGER, or LOGICAL (C1145) + LLVM_ATTRIBUTE_UNUSED auto ty = caseOp.getSelector().getType(); + 
auto selector = caseOp.getSelector(operands); + auto loc = caseOp.getLoc(); + assert(conds > 0 && "fir.selectcase must have cases"); + for (std::remove_const_t t = 0; t != conds; ++t) { + mlir::Block *dest = caseOp.getSuccessor(t); + auto destOps = caseOp.getSuccessorOperands(operands, t); + auto cmpOps = *caseOp.getCompareOperands(operands, t); + auto caseArg = *cmpOps.begin(); + auto &attr = cases[t]; + if (attr.isa()) { + auto cmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::eq, selector, caseArg); + genCaseLadderStep(loc, cmp, dest, destOps, rewriter); + continue; + } + if (attr.isa()) { + auto cmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::sle, caseArg, selector); + genCaseLadderStep(loc, cmp, dest, destOps, rewriter); + continue; + } + if (attr.isa()) { + auto cmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::sle, selector, caseArg); + genCaseLadderStep(loc, cmp, dest, destOps, rewriter); + continue; + } + if (attr.isa()) { + auto cmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::sle, caseArg, selector); + auto *thisBlock = rewriter.getInsertionBlock(); + auto *newBlock1 = createBlock(rewriter, dest); + auto *newBlock2 = createBlock(rewriter, dest); + rewriter.setInsertionPointToEnd(thisBlock); + rewriter.create(loc, cmp, newBlock1, newBlock2); + rewriter.setInsertionPointToEnd(newBlock1); + auto caseArg_ = *(cmpOps.begin() + 1); + auto cmp_ = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::sle, selector, caseArg_); + genCondBrOp(loc, cmp_, dest, destOps, rewriter, newBlock2); + rewriter.setInsertionPointToEnd(newBlock2); + continue; + } + assert(attr.isa()); + assert((t + 1 == conds) && "unit must be last"); + genBrOp(caseOp, dest, destOps, rewriter); + } + return success(); + } +}; + +template +void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select, + OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) { + // We could target the LLVM switch instruction, but it isn't part of the + // LLVM IR 
dialect. Create an if-then-else ladder instead. + auto conds = select.getNumConditions(); + auto attrName = OP::getCasesAttr(); + auto caseAttr = select->template getAttrOfType(attrName); + auto cases = caseAttr.getValue(); + auto ty = select.getSelector().getType(); + auto ity = lowering.convertType(ty); + auto selector = select.getSelector(operands); + auto loc = select.getLoc(); + assert(conds > 0 && "select must have cases"); + for (decltype(conds) t = 0; t != conds; ++t) { + mlir::Block *dest = select.getSuccessor(t); + auto destOps = select.getSuccessorOperands(operands, t); + auto &attr = cases[t]; + if (auto intAttr = attr.template dyn_cast()) { + auto ci = rewriter.create( + loc, ity, rewriter.getIntegerAttr(ty, intAttr.getInt())); + auto cmp = rewriter.create( + loc, mlir::LLVM::ICmpPredicate::eq, selector, ci); + genCaseLadderStep(loc, cmp, dest, destOps, rewriter); + continue; + } + assert(attr.template dyn_cast_or_null()); + assert((t + 1 == conds) && "unit must be last"); + genBrOp(select, dest, destOps, rewriter); + } +} + +/// conversion of fir::SelectOp to an if-then-else ladder +struct SelectOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SelectOp op, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + selectMatchAndRewrite(lowerTy(), op, operands, rewriter); + return success(); + } +}; + +/// conversion of fir::SelectRankOp to an if-then-else ladder +struct SelectRankOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SelectRankOp op, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + selectMatchAndRewrite(lowerTy(), op, operands, rewriter); + return success(); + } +}; + +struct SelectTypeOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SelectTypeOp 
select, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::emitError(select.getLoc(), + "fir.select_type should have already been converted"); + return failure(); + } +}; + +// convert to LLVM IR dialect `store` +struct StoreOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::StoreOp store, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + if (store.value().getType().isa()) { + // fir.box value is actually in memory, load it first before storing it. + auto loc = store.getLoc(); + auto boxPtrTy = unwrap(operands[0].getType()); + auto val = rewriter.create( + loc, boxPtrTy.cast().getElementType(), + operands[0]); + rewriter.replaceOpWithNewOp(store, val, operands[1]); + } else { + rewriter.replaceOpWithNewOp(store, operands[0], + operands[1]); + } + return success(); + } +}; + +// cons an extractvalue on a tuple value, returning value at element `x` +mlir::LLVM::ExtractValueOp +genExtractValueWithIndex(mlir::Location loc, mlir::Value tuple, mlir::Type ty, + mlir::ConversionPatternRewriter &rewriter, + mlir::MLIRContext *ctx, int x) { + auto cx = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(x), ctx); + auto xty = ty.cast().getBody()[x]; + return rewriter.create(loc, xty, tuple, cx); +} + +// unbox a CHARACTER box value, yielding its components +struct UnboxCharOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::UnboxCharOp unboxchar, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto *ctx = unboxchar.getContext(); + auto lenTy = unwrap(convertType(unboxchar.getType(1))); + auto loc = unboxchar.getLoc(); + auto tuple = operands[0]; + auto ty = unwrap(tuple.getType()); + mlir::Value ptr = + genExtractValueWithIndex(loc, tuple, ty, rewriter, ctx, 0); + auto len1 = genExtractValueWithIndex(loc, tuple, ty, 
rewriter, ctx, 1); + auto len = integerCast(loc, rewriter, lenTy, len1); + unboxchar.replaceAllUsesWith(llvm::ArrayRef{ptr, len}); + rewriter.eraseOp(unboxchar); + return success(); + } +}; + +// unbox a generic box reference, yielding its components +struct UnboxOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::UnboxOp unbox, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto loc = unbox.getLoc(); + auto tuple = operands[0]; + auto ty = unwrap(tuple.getType()); + auto oty = lowerTy().offsetType(); + auto c0 = rewriter.create( + loc, oty, rewriter.getI32IntegerAttr(0)); + mlir::Value ptr = genLoadWithIndex(loc, tuple, ty, rewriter, oty, c0, 0); + mlir::Value len = genLoadWithIndex(loc, tuple, ty, rewriter, oty, c0, 1); + mlir::Value ver = genLoadWithIndex(loc, tuple, ty, rewriter, oty, c0, 2); + mlir::Value rank = genLoadWithIndex(loc, tuple, ty, rewriter, oty, c0, 3); + mlir::Value type = genLoadWithIndex(loc, tuple, ty, rewriter, oty, c0, 4); + mlir::Value attr = genLoadWithIndex(loc, tuple, ty, rewriter, oty, c0, 5); + mlir::Value xtra = genLoadWithIndex(loc, tuple, ty, rewriter, oty, c0, 6); + // FIXME: add dims, etc. 
+ std::vector repls{ptr, len, ver, rank, type, attr, xtra}; + unbox.replaceAllUsesWith(repls); + rewriter.eraseOp(unbox); + return success(); + } + + // generate a GEP into a structure and load the element at position `x` + mlir::LLVM::LoadOp genLoadWithIndex(mlir::Location loc, mlir::Value tuple, + mlir::Type ty, + mlir::ConversionPatternRewriter &rewriter, + mlir::Type oty, mlir::LLVM::ConstantOp c0, + int x) const { + auto ax = rewriter.getI32IntegerAttr(x); + auto cx = rewriter.create(loc, oty, ax); + auto sty = ty.dyn_cast(); + assert(sty); + auto xty = sty.getBody()[x]; + auto gep = genGEP(loc, mlir::LLVM::LLVMPointerType::get(xty), rewriter, + tuple, c0, cx); + return rewriter.create(loc, xty, gep); + } +}; + +// unbox a procedure box value, yielding its components +struct UnboxProcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::UnboxProcOp unboxproc, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto *ctx = unboxproc.getContext(); + auto loc = unboxproc.getLoc(); + auto tuple = operands[0]; + auto ty = unwrap(tuple.getType()); + mlir::Value ptr = + genExtractValueWithIndex(loc, tuple, ty, rewriter, ctx, 0); + mlir::Value host = + genExtractValueWithIndex(loc, tuple, ty, rewriter, ctx, 1); + std::vector repls{ptr, host}; + unboxproc.replaceAllUsesWith(repls); + rewriter.eraseOp(unboxproc); + return success(); + } +}; + +// convert to LLVM IR dialect `undef` +struct UndefOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::UndefOp undef, OperandTy, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + undef, convertType(undef.getType())); + return success(); + } +}; + +struct ZeroOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::ZeroOp zero, 
OperandTy, + mlir::ConversionPatternRewriter &rewriter) const override { + auto ty = convertType(zero.getType()); + auto llTy = unwrap(ty); + if (llTy.isa()) { + rewriter.replaceOpWithNewOp(zero, ty); + } else if (llTy.isa()) { + rewriter.replaceOpWithNewOp( + zero, ty, mlir::IntegerAttr::get(zero.getType(), 0)); + } else if (mlir::LLVM::isCompatibleFloatingPointType(llTy)) { + rewriter.replaceOpWithNewOp( + zero, ty, mlir::IntegerAttr::get(zero.getType(), 0.0)); + } else { + // FIXME/TODO: how do we create a ConstantAggregateZero? + rewriter.replaceOpWithNewOp(zero, ty); + } + return success(); + } +}; + +// convert to LLVM IR dialect `unreachable` +struct UnreachableOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::UnreachableOp unreach, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(unreach); + return success(); + } +}; + +// +// Primitive operations on Real (floating-point) types +// + +/// Convert a floating-point primitive +template +void lowerRealBinaryOp(BINOP binop, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter, + fir::LLVMTypeConverter &lowering) { + auto ty = lowering.convertType(binop.getType()); + rewriter.replaceOpWithNewOp(binop, ty, operands); +} + +struct AddfOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::AddfOp op, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + lowerRealBinaryOp(op, operands, rewriter, lowerTy()); + return success(); + } +}; +struct SubfOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SubfOp op, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + lowerRealBinaryOp(op, operands, rewriter, lowerTy()); + return success(); + } +}; +struct 
MulfOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::MulfOp op, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + lowerRealBinaryOp(op, operands, rewriter, lowerTy()); + return success(); + } +}; +struct DivfOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::DivfOp op, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + lowerRealBinaryOp(op, operands, rewriter, lowerTy()); + return success(); + } +}; +struct ModfOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::ModfOp op, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + lowerRealBinaryOp(op, operands, rewriter, lowerTy()); + return success(); + } +}; + +struct NegfOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::NegfOp neg, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + auto ty = convertType(neg.getType()); + rewriter.replaceOpWithNewOp(neg, ty, operands); + return success(); + } +}; + +// +// Primitive operations on Complex types +// + +/// Generate inline code for complex addition/subtraction +template +mlir::LLVM::InsertValueOp complexSum(OPTY sumop, OperandTy opnds, + mlir::ConversionPatternRewriter &rewriter, + fir::LLVMTypeConverter &lowering) { + auto a = opnds[0]; + auto b = opnds[1]; + auto loc = sumop.getLoc(); + auto ctx = sumop.getContext(); + auto c0 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(0), ctx); + auto c1 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(1), ctx); + auto eleTy = lowering.convertType(getComplexEleTy(sumop.getType())); + auto ty = lowering.convertType(sumop.getType()); + auto x = rewriter.create(loc, eleTy, a, c0); + auto y = 
rewriter.create(loc, eleTy, a, c1); + auto x_ = rewriter.create(loc, eleTy, b, c0); + auto y_ = rewriter.create(loc, eleTy, b, c1); + auto rx = rewriter.create(loc, eleTy, x, x_); + auto ry = rewriter.create(loc, eleTy, y, y_); + auto r = rewriter.create(loc, ty); + auto r_ = rewriter.create(loc, ty, r, rx, c0); + return rewriter.create(loc, ty, r_, ry, c1); +} + +struct AddcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::AddcOp addc, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + // given: (x + iy) + (x' + iy') + // result: (x + x') + i(y + y') + auto r = + complexSum(addc, operands, rewriter, lowerTy()); + addc.replaceAllUsesWith(r.getResult()); + rewriter.replaceOp(addc, r.getResult()); + return success(); + } +}; + +struct SubcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::SubcOp subc, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + // given: (x + iy) - (x' + iy') + // result: (x - x') + i(y - y') + auto r = + complexSum(subc, operands, rewriter, lowerTy()); + subc.replaceAllUsesWith(r.getResult()); + rewriter.replaceOp(subc, r.getResult()); + return success(); + } +}; + +/// Inlined complex multiply +struct MulcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::MulcOp mulc, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + // TODO: should this just call __muldc3 ?
+ // given: (x + iy) * (x' + iy') + // result: (xx'-yy')+i(xy'+yx') + auto a = operands[0]; + auto b = operands[1]; + auto loc = mulc.getLoc(); + auto ctx = mulc.getContext(); + auto c0 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(0), ctx); + auto c1 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(1), ctx); + auto eleTy = convertType(getComplexEleTy(mulc.getType())); + auto ty = convertType(mulc.getType()); + auto x = rewriter.create(loc, eleTy, a, c0); + auto y = rewriter.create(loc, eleTy, a, c1); + auto x_ = rewriter.create(loc, eleTy, b, c0); + auto y_ = rewriter.create(loc, eleTy, b, c1); + auto xx_ = rewriter.create(loc, eleTy, x, x_); + auto yx_ = rewriter.create(loc, eleTy, y, x_); + auto xy_ = rewriter.create(loc, eleTy, x, y_); + auto ri = rewriter.create(loc, eleTy, xy_, yx_); + auto yy_ = rewriter.create(loc, eleTy, y, y_); + auto rr = rewriter.create(loc, eleTy, xx_, yy_); + auto ra = rewriter.create(loc, ty); + auto r_ = rewriter.create(loc, ty, ra, rr, c0); + auto r = rewriter.create(loc, ty, r_, ri, c1); + mulc.replaceAllUsesWith(r.getResult()); + rewriter.replaceOp(mulc, r.getResult()); + return success(); + } +}; + +/// Inlined complex division +struct DivcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + // Should this just call __divdc3? Just generate inline code for now. 
+ mlir::LogicalResult + matchAndRewrite(fir::DivcOp divc, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + // given: (x + iy) / (x' + iy') + // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y' + auto a = operands[0]; + auto b = operands[1]; + auto loc = divc.getLoc(); + auto ctx = divc.getContext(); + auto c0 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(0), ctx); + auto c1 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(1), ctx); + auto eleTy = convertType(getComplexEleTy(divc.getType())); + auto ty = convertType(divc.getType()); + auto x = rewriter.create(loc, eleTy, a, c0); + auto y = rewriter.create(loc, eleTy, a, c1); + auto x_ = rewriter.create(loc, eleTy, b, c0); + auto y_ = rewriter.create(loc, eleTy, b, c1); + auto xx_ = rewriter.create(loc, eleTy, x, x_); + auto x_x_ = rewriter.create(loc, eleTy, x_, x_); + auto yx_ = rewriter.create(loc, eleTy, y, x_); + auto xy_ = rewriter.create(loc, eleTy, x, y_); + auto yy_ = rewriter.create(loc, eleTy, y, y_); + auto y_y_ = rewriter.create(loc, eleTy, y_, y_); + auto d = rewriter.create(loc, eleTy, x_x_, y_y_); + auto rrn = rewriter.create(loc, eleTy, xx_, yy_); + auto rin = rewriter.create(loc, eleTy, yx_, xy_); + auto rr = rewriter.create(loc, eleTy, rrn, d); + auto ri = rewriter.create(loc, eleTy, rin, d); + auto ra = rewriter.create(loc, ty); + auto r_ = rewriter.create(loc, ty, ra, rr, c0); + auto r = rewriter.create(loc, ty, r_, ri, c1); + divc.replaceAllUsesWith(r.getResult()); + rewriter.replaceOp(divc, r.getResult()); + return success(); + } +}; + +/// Inlined complex negation +struct NegcOpConversion : public FIROpConversion { + using FIROpConversion::FIROpConversion; + + mlir::LogicalResult + matchAndRewrite(fir::NegcOp neg, OperandTy operands, + mlir::ConversionPatternRewriter &rewriter) const override { + // given: -(x + iy) + // result: -x - iy + auto ctxt = neg.getContext(); + auto eleTy = convertType(getComplexEleTy(neg.getType())); + auto ty = 
convertType(neg.getType()); + auto loc = neg.getLoc(); + auto &o0 = operands[0]; + auto c0 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(0), ctxt); + auto c1 = mlir::ArrayAttr::get(rewriter.getI32IntegerAttr(1), ctxt); + auto rp = rewriter.create(loc, eleTy, o0, c0); + auto ip = rewriter.create(loc, eleTy, o0, c1); + auto nrp = rewriter.create(loc, eleTy, rp); + auto nip = rewriter.create(loc, eleTy, ip); + auto r = rewriter.create(loc, ty, o0, nrp, c0); + rewriter.replaceOpWithNewOp(neg, ty, r, nip, c1); + return success(); + } +}; + +// Lower a SELECT operation into a cascade of conditional branches. The last +// case must be the `true` condition. +/// Convert FIR dialect to LLVM dialect +/// +/// This pass lowers all FIR dialect operations to LLVM IR dialect. An +/// MLIR pass is used to lower residual Std dialect to LLVM IR dialect. +struct FIRToLLVMLoweringPass + : public mlir::PassWrapper> { + FIRToLLVMLoweringPass(fir::NameUniquer &) {} + + mlir::ModuleOp getModule() { return getOperation(); } + + void runOnOperation() override final { + if (disableFirToLLVMIR) + return; + + auto *context = getModule().getContext(); + fir::LLVMTypeConverter typeConverter{getModule()}; + auto loc = mlir::UnknownLoc::get(context); + mlir::OwningRewritePatternList pattern; + pattern.insert< + AddcOpConversion, AddfOpConversion, AddrOfOpConversion, + AllocaOpConversion, AllocMemOpConversion, BoxAddrOpConversion, + BoxCharLenOpConversion, BoxDimsOpConversion, BoxEleSizeOpConversion, + BoxIsAllocOpConversion, BoxIsArrayOpConversion, BoxIsPtrOpConversion, + BoxProcHostOpConversion, BoxRankOpConversion, BoxTypeDescOpConversion, + CallOpConversion, CmpcOpConversion, CmpfOpConversion, + ConstcOpConversion, ConstfOpConversion, ConvertOpConversion, + CoordinateOpConversion, DispatchOpConversion, DispatchTableOpConversion, + DivcOpConversion, DivfOpConversion, DTEntryOpConversion, + EmboxOpConversion, EmboxCharOpConversion, EmboxProcOpConversion, + FieldIndexOpConversion, 
FirEndOpConversion, ExtractValueOpConversion, + FreeMemOpConversion, GenTypeDescOpConversion, GlobalLenOpConversion, + GlobalOpConversion, HasValueOpConversion, InsertOnRangeOpConversion, + InsertValueOpConversion, LenParamIndexOpConversion, LoadOpConversion, + ModfOpConversion, MulcOpConversion, MulfOpConversion, NegcOpConversion, + NegfOpConversion, NoReassocOpConversion, SelectCaseOpConversion, + SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion, + StoreOpConversion, StringLitOpConversion, SubcOpConversion, + SubfOpConversion, UnboxCharOpConversion, UnboxOpConversion, + UnboxProcOpConversion, UndefOpConversion, UnreachableOpConversion, + XArrayCoorOpConversion, XEmboxOpConversion>(context, typeConverter); + mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); + mlir::ConversionTarget target{*context}; + target.addLegalDialect(); + target.addLegalDialect(); + + // required NOPs for applying a full conversion + target.addLegalOp(); + + // apply the patterns + if (mlir::failed(mlir::applyFullConversion(getModule(), target, + std::move(pattern)))) { + mlir::emitError(loc, "error in converting to LLVM-IR dialect\n"); + signalPassFailure(); + } + } +}; + +/// Lower from LLVM IR dialect to proper LLVM-IR and dump the module +struct LLVMIRLoweringPass + : public mlir::PassWrapper> { + LLVMIRLoweringPass(raw_ostream &output) : output{output} {} + + mlir::ModuleOp getModule() { return getOperation(); } + + void runOnOperation() override final { + if (disableLLVM) + return; + + auto optName = getModule().getName(); + llvm::LLVMContext llvmCtx; + if (auto llvmModule = mlir::translateModuleToLLVMIR( + getModule(), llvmCtx, optName ? 
*optName : "FIRModule")) { + llvmModule->print(output, nullptr); + return; + } + + auto *ctx = getModule().getContext(); + mlir::emitError(mlir::UnknownLoc::get(ctx), "could not emit LLVM-IR\n"); + signalPassFailure(); + } + +private: + llvm::raw_ostream &output; +}; + +} // namespace + +std::unique_ptr +fir::createFIRToLLVMPass(fir::NameUniquer &nameUniquer) { + return std::make_unique(nameUniquer); +} + +std::unique_ptr +fir::createLLVMDialectToLLVMPass(llvm::raw_ostream &output) { + return std::make_unique(output); +} + +// Register the FIR to LLVM-IR pass +static mlir::PassRegistration + passLowFIR("fir-to-llvmir", + "Conversion of the FIR dialect to the LLVM-IR dialect", [] { + fir::NameUniquer dummy; + return std::make_unique(dummy); + }); Index: flang/lib/Optimizer/CodeGen/DescriptorModel.h =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/DescriptorModel.h @@ -0,0 +1,144 @@ +//===-- DescriptorModel.h -- model of descriptors for codegen ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef OPTIMIZER_DESCRIPTOR_MODEL_H +#define OPTIMIZER_DESCRIPTOR_MODEL_H + +#include "../runtime/descriptor.h" +#include "flang/ISO_Fortran_binding.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "llvm/Support/ErrorHandling.h" +#include + +namespace fir { + +//===----------------------------------------------------------------------===// +// LLVM IR dialect models of C++ types. +// +// This supplies a set of model builders to decompose the C declaration of a +// descriptor (as encoded in ISO_Fortran_binding.h and elsewhere) and +// reconstruct that type in the LLVM IR dialect. 
+// +// TODO: It is understood that this is deeply incorrect as far as building a +// portability layer for cross-compilation as these reflected types are those of +// the build machine and not necessarily that of either the host or the target. +// This assumption that build == host == target is actually pervasive across the +// compiler. +// +//===----------------------------------------------------------------------===// + +using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *); + +/// Get the LLVM IR dialect model for building a particular C++ type, `T`. +template +TypeBuilderFunc getModel(); + +template <> +TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(context, 8)); + }; +} +template <> +TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, sizeof(unsigned) * 8); + }; +} +template <> +TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, sizeof(int) * 8); + }; +} +template <> +TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, sizeof(unsigned long) * 8); + }; +} +template <> +TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8); + }; +} +template <> +TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, sizeof(long long) * 8); + }; +} +template <> +TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, + sizeof(Fortran::ISO::CFI_rank_t) * 8); + }; +} +template <> +TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, + 
sizeof(Fortran::ISO::CFI_type_t) * 8); + }; +} +template <> +TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + return mlir::IntegerType::get(context, + sizeof(Fortran::ISO::CFI_index_t) * 8); + }; +} +template <> +TypeBuilderFunc getModel() { + return [](mlir::MLIRContext *context) -> mlir::Type { + auto indexTy = getModel()(context); + return mlir::LLVM::LLVMArrayType::get(indexTy, 3); + }; +} +template <> +TypeBuilderFunc +getModel>() { + return getModel(); +} + +//===----------------------------------------------------------------------===// +// Descriptor reflection +//===----------------------------------------------------------------------===// + +/// Get the type model of the field number `Field` in an ISO descriptor. +template +static constexpr TypeBuilderFunc getDescFieldTypeModel() { + Fortran::ISO::Fortran_2018::CFI_cdesc_t dummyDesc{}; + // check that the descriptor is exactly 8 fields + auto [a, b, c, d, e, f, g, h] = dummyDesc; + auto tup = std::tie(a, b, c, d, e, f, g, h); + auto field = std::get(tup); + return getModel(); +} + +/// An extended descriptor is defined by a class in runtime/descriptor.h. The +/// three fields in the class are hard-coded here, unlike the reflection used on +/// the ISO parts, which are a POD. 
+template +static constexpr TypeBuilderFunc getExtendedDescFieldTypeModel() { + if constexpr (Field == 8) { + return getModel(); + } else if constexpr (Field == 9) { + return getModel(); + } else if constexpr (Field == 10) { + return getModel(); + } else { + llvm_unreachable("extended ISO descriptor only has 11 fields"); + } +} + +} // namespace fir + +#endif // OPTIMIZER_DESCRIPTOR_MODEL_H Index: flang/lib/Optimizer/CodeGen/PassDetail.h =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/PassDetail.h @@ -0,0 +1,23 @@ +//===- PassDetail.h - Optimizer code gen Pass class details -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef OPTMIZER_CODEGEN_PASSDETAIL_H +#define OPTMIZER_CODEGEN_PASSDETAIL_H + +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +namespace fir { + +#define GEN_PASS_CLASSES +#include "flang/Optimizer/CodeGen/CGPasses.h.inc" + +} // namespace fir + +#endif // OPTMIZER_CODEGEN_PASSDETAIL_H Index: flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp @@ -0,0 +1,247 @@ +//===-- PreCGRewrite.cpp --------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "CGOps.h" +#include "PassDetail.h" +#include "flang/Optimizer/CodeGen/CodeGen.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Support/FIRContext.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" + +//===----------------------------------------------------------------------===// +// Codegen rewrite: rewriting of subgraphs of ops +//===----------------------------------------------------------------------===// + +using namespace fir; + +#define DEBUG_TYPE "flang-codegen-rewrite" + +static void populateShape(llvm::SmallVectorImpl &vec, + ShapeOp shape) { + vec.append(shape.extents().begin(), shape.extents().end()); +} + +// Operands of fir.shape_shift split into two vectors. +static void populateShapeAndShift(llvm::SmallVectorImpl &shapeVec, + llvm::SmallVectorImpl &shiftVec, + ShapeShiftOp shift) { + auto endIter = shift.pairs().end(); + for (auto i = shift.pairs().begin(); i != endIter;) { + shiftVec.push_back(*i++); + shapeVec.push_back(*i++); + } +} + +namespace { + +/// Convert fir.embox to the extended form where necessary. 
+class EmboxConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(EmboxOp embox, + mlir::PatternRewriter &rewriter) const override { + auto shapeVal = embox.getShape(); + // If the embox does not include a shape, then do not convert it + if (shapeVal) + return rewriteDynamicShape(embox, rewriter, shapeVal); + if (auto boxTy = embox.getType().dyn_cast()) + if (auto seqTy = boxTy.getEleTy().dyn_cast()) + if (seqTy.hasConstantShape()) + return rewriteStaticShape(embox, rewriter, seqTy); + return mlir::failure(); + } + + mlir::LogicalResult rewriteStaticShape(EmboxOp embox, + mlir::PatternRewriter &rewriter, + SequenceType seqTy) const { + auto loc = embox.getLoc(); + llvm::SmallVector shapeOpers; + auto idxTy = rewriter.getIndexType(); + for (auto ext : seqTy.getShape()) { + auto iAttr = rewriter.getIndexAttr(ext); + auto extVal = rewriter.create(loc, idxTy, iAttr); + shapeOpers.push_back(extVal); + } + auto xbox = rewriter.create( + loc, embox.getType(), embox.memref(), shapeOpers, llvm::None, + llvm::None, llvm::None, embox.lenParams()); + LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); + rewriter.replaceOp(embox, xbox.getOperation()->getResults()); + return mlir::success(); + } + + mlir::LogicalResult rewriteDynamicShape(EmboxOp embox, + mlir::PatternRewriter &rewriter, + mlir::Value shapeVal) const { + auto loc = embox.getLoc(); + auto shapeOp = dyn_cast(shapeVal.getDefiningOp()); + llvm::SmallVector shapeOpers; + llvm::SmallVector shiftOpers; + if (shapeOp) { + populateShape(shapeOpers, shapeOp); + } else { + auto shiftOp = dyn_cast(shapeVal.getDefiningOp()); + assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift"); + populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); + } + llvm::SmallVector sliceOpers; + llvm::SmallVector subcompOpers; + if (auto s = embox.getSlice()) + if (auto sliceOp = dyn_cast_or_null(s.getDefiningOp())) { + 
sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); + subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); + } + auto xbox = rewriter.create( + loc, embox.getType(), embox.memref(), shapeOpers, shiftOpers, + sliceOpers, subcompOpers, embox.lenParams()); + LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); + rewriter.replaceOp(embox, xbox.getOperation()->getResults()); + return mlir::success(); + } +}; + +/// Convert all fir.array_coor to the extended form. +class ArrayCoorConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ArrayCoorOp arrCoor, + mlir::PatternRewriter &rewriter) const override { + auto loc = arrCoor.getLoc(); + auto shapeVal = arrCoor.shape(); + auto shapeOp = dyn_cast(shapeVal.getDefiningOp()); + llvm::SmallVector shapeOpers; + llvm::SmallVector shiftOpers; + if (shapeOp) { + populateShape(shapeOpers, shapeOp); + } else if (auto shiftOp = + dyn_cast(shapeVal.getDefiningOp())) { + populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); + } else { + return mlir::failure(); + } + llvm::SmallVector sliceOpers; + llvm::SmallVector subcompOpers; + if (auto s = arrCoor.slice()) + if (auto sliceOp = dyn_cast_or_null(s.getDefiningOp())) { + sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); + subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); + } + auto xArrCoor = rewriter.create( + loc, arrCoor.getType(), arrCoor.memref(), shapeOpers, shiftOpers, + sliceOpers, subcompOpers, arrCoor.indices(), arrCoor.lenParams()); + LLVM_DEBUG(llvm::dbgs() + << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); + rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); + return mlir::success(); + } +}; + +/// Rewrite fir.embox and fir.array_coor ops into their extended (fircg) forms +/// ahead of code generation, then erase the replaced ops and dead shape ops.
+class CodeGenRewrite : public CodeGenRewriteBase { +public: + void runOn(mlir::Operation *op, mlir::Region ®ion) { + auto &context = getContext(); + mlir::OpBuilder rewriter(&context); + mlir::ConversionTarget target(context); + target.addLegalDialect(); + target.addIllegalOp(); + target.addDynamicallyLegalOp([](EmboxOp embox) { + return !(embox.getShape() || + embox.getType().cast().getEleTy().isa()); + }); + mlir::OwningRewritePatternList patterns; + patterns.insert(&context); + if (mlir::failed( + mlir::applyPartialConversion(op, target, std::move(patterns)))) { + mlir::emitError(mlir::UnknownLoc::get(&context), + "error in running the pre-codegen conversions"); + signalPassFailure(); + } + // Erase any residual. + simplifyRegion(region); + } + + void runOnOperation() override final { + // Call runOn on all top level regions that may contain emboxOp/arrayCoorOp. + auto mod = getOperation(); + for (auto func : mod.getOps()) + runOn(func, func.getBody()); + for (auto global : mod.getOps()) + runOn(global, global.getRegion()); + } + + // Clean up the region. + void simplifyRegion(mlir::Region ®ion) { + for (auto &block : region.getBlocks()) + for (auto &op : block.getOperations()) { + if (op.getNumRegions() != 0) + for (auto ® : op.getRegions()) + simplifyRegion(reg); + maybeEraseOp(&op); + } + + for (auto *op : opsToErase) + op->erase(); + opsToErase.clear(); + } + + void maybeEraseOp(mlir::Operation *op) { + if (!op) + return; + + // Erase any embox that was replaced. + if (auto embox = dyn_cast(op)) + if (embox.getShape()) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + + // Erase all fir.array_coor. + if (isa(op)) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + + // Erase all fir.shape, fir.shape_shift, and fir.slice ops. 
+ if (isa(op)) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + if (isa(op)) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + if (isa(op)) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + } + +private: + std::vector opsToErase; +}; + +} // namespace + +/// Convert FIR's structured control flow ops to CFG ops. This conversion +/// enables the `createLowerToCFGPass` to transform these to CFG form. +std::unique_ptr fir::createFirCodeGenRewritePass() { + return std::make_unique(); +} Index: flang/lib/Optimizer/CodeGen/Target.h =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/Target.h @@ -0,0 +1,108 @@ +//===- Target.h - target specific details -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#ifndef OPTMIZER_CODEGEN_TARGET_H +#define OPTMIZER_CODEGEN_TARGET_H + +#include "mlir/IR/BuiltinTypes.h" +#include +#include +#include + +namespace llvm { +class Triple; +} // namespace llvm + +namespace fir { +class KindMapping; + +namespace details { +/// Extra information about how to marshal an argument or return value that +/// modifies a signature per a particular ABI's calling convention. +/// Note: llvm::Attribute is not used directly, because its use depends on an +/// LLVMContext. 
+class Attributes { +public: + Attributes() : alignment{0}, byval{false}, sret{false}, append{false} {} + Attributes(unsigned short alignment, bool byval = false, bool sret = false, + bool append = false) + : alignment{alignment}, byval{byval}, sret{sret}, append{append} {} + + unsigned getAlignment() const { return alignment; } + bool hasAlignment() const { return alignment != 0; } + bool isByVal() const { return byval; } + bool returnValueAsArgument() const { return isSRet(); } + bool isSRet() const { return sret; } + bool isAppend() const { return append; } + +private: + unsigned short alignment{}; + bool byval : 1; + bool sret : 1; + bool append : 1; +}; + +} // namespace details + +/// Some details of how to represent certain features depend on the target and +/// ABI that is being used. These specifics are captured here and guide the +/// lowering of FIR to LLVM-IR dialect. +class CodeGenSpecifics { +public: + using Attributes = details::Attributes; + using Marshalling = std::vector>; + + static std::unique_ptr + get(mlir::MLIRContext *ctx, llvm::Triple &trp, KindMapping &kindMap); + + CodeGenSpecifics(mlir::MLIRContext *ctx, llvm::Triple &trp, + KindMapping &kindMap) + : context{*ctx}, triple{trp}, kindMap{kindMap} {} + CodeGenSpecifics() = delete; + virtual ~CodeGenSpecifics() {} + + /// Type presentation of a `complex` type value in memory. + virtual mlir::Type complexMemoryType(mlir::Type eleTy) const = 0; + + /// Type presentation of a `complex` type argument when passed by + /// value. An argument value may need to be passed as a (safe) reference + /// argument. + virtual Marshalling complexArgumentType(mlir::Type eleTy) const = 0; + + /// Type presentation of a `complex` type return value. Such a return + /// value may need to be converted to a hidden reference argument. + virtual Marshalling complexReturnType(mlir::Type eleTy) const = 0; + + /// Type presentation of a `boxchar` type value in memory. 
+ virtual mlir::Type boxcharMemoryType(mlir::Type eleTy) const = 0; + + /// Type presentation of a `boxchar` type argument when passed by value. An + /// argument value may need to be passed as a (safe) reference argument. + /// + /// A function that returns a `boxchar` type value must already have + /// converted that return value to an sret argument. This requirement is in + /// keeping with Fortran semantics, which require the caller to allocate the + /// space for the return CHARACTER value and pass a pointer and the length of + /// that space (a boxchar) to the called function. Such functions should be + /// annotated with an Attribute to distinguish the sret argument. + virtual Marshalling boxcharArgumentType(mlir::Type eleTy, + bool sret = false) const = 0; + +protected: + mlir::MLIRContext &context; + llvm::Triple &triple; + KindMapping &kindMap; +}; + +} // namespace fir + +#endif // OPTMIZER_CODEGEN_TARGET_H Index: flang/lib/Optimizer/CodeGen/Target.cpp =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/Target.cpp @@ -0,0 +1,247 @@ +//===-- Target.cpp --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "Target.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Support/KindMapping.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeRange.h" +#include "llvm/ADT/Triple.h" + +#define DEBUG_TYPE "flang-codegen-target" + +using namespace fir; + +// Reduce a REAL/float type to the floating point semantics. +static const llvm::fltSemantics &floatToSemantics(KindMapping &kindMap, + mlir::Type type) { + assert(isa_real(type)); + if (auto ty = type.dyn_cast()) + return kindMap.getFloatSemantics(ty.getFKind()); + return type.cast().getFloatSemantics(); +} + +namespace { +template +struct GenericTarget : public CodeGenSpecifics { + using CodeGenSpecifics::CodeGenSpecifics; + using AT = CodeGenSpecifics::Attributes; + + mlir::Type complexMemoryType(mlir::Type eleTy) const override { + assert(fir::isa_real(eleTy)); + // { t, t } struct of 2 eleTy + mlir::TypeRange range = {eleTy, eleTy}; + return mlir::TupleType::get(eleTy.getContext(), range); + } + + mlir::Type boxcharMemoryType(mlir::Type eleTy) const override { + auto idxTy = mlir::IntegerType::get(eleTy.getContext(), S::defaultWidth ); + auto ptrTy = fir::ReferenceType::get(eleTy); + // { t*, index } + mlir::TypeRange range = {ptrTy, idxTy}; + return mlir::TupleType::get(eleTy.getContext(), range); + } + + Marshalling boxcharArgumentType(mlir::Type eleTy, bool sret) const override { + CodeGenSpecifics::Marshalling marshal; + auto idxTy = mlir::IntegerType::get(eleTy.getContext(), S::defaultWidth); + auto ptrTy = fir::ReferenceType::get(eleTy); + marshal.emplace_back(ptrTy, AT{}); + // Return value arguments are grouped as a pair. 
Others are passed in a + // split format with all pointers first (in the declared position) and all + // LEN arguments appended after all of the dummy arguments. + // NB: Other conventions/ABIs can/should be supported via options. + marshal.emplace_back(idxTy, AT{0, {}, {}, /*append=*/!sret}); + return marshal; + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// i386 (x86 32 bit) linux target specifics. +//===----------------------------------------------------------------------===// + +namespace { +struct TargetI386 : public GenericTarget { + using GenericTarget::GenericTarget; + + static constexpr int defaultWidth = 32; + + CodeGenSpecifics::Marshalling + complexArgumentType(mlir::Type eleTy) const override { + assert(fir::isa_real(eleTy)); + CodeGenSpecifics::Marshalling marshal; + // { t, t } struct of 2 eleTy, byval, align 4 + mlir::TypeRange range = {eleTy, eleTy}; + auto structTy = mlir::TupleType::get(eleTy.getContext(), range); + marshal.emplace_back(fir::ReferenceType::get(structTy), + AT{4, /*byval=*/true, {}}); + return marshal; + } + + CodeGenSpecifics::Marshalling + complexReturnType(mlir::Type eleTy) const override { + assert(fir::isa_real(eleTy)); + CodeGenSpecifics::Marshalling marshal; + const auto *sem = &floatToSemantics(kindMap, eleTy); + if (sem == &llvm::APFloat::IEEEsingle()) { + // i64 pack both floats in a 64-bit GPR + marshal.emplace_back(mlir::IntegerType::get(eleTy.getContext(), 64), + AT{}); + } else if (sem == &llvm::APFloat::IEEEdouble()) { + // { t, t } struct of 2 eleTy, sret, align 4 + mlir::TypeRange range = {eleTy, eleTy}; + auto structTy = mlir::TupleType::get(eleTy.getContext(), range); + marshal.emplace_back(fir::ReferenceType::get(structTy), + AT{4, {}, /*sret=*/true}); + } else { + llvm_unreachable("not implemented"); + } + return marshal; + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// x86_64 (x86 64 
bit) linux target specifics. +//===----------------------------------------------------------------------===// + +namespace { +struct TargetX86_64 : public GenericTarget { + using GenericTarget::GenericTarget; + + static constexpr int defaultWidth = 64; + + CodeGenSpecifics::Marshalling + complexArgumentType(mlir::Type eleTy) const override { + CodeGenSpecifics::Marshalling marshal; + const auto *sem = &floatToSemantics(kindMap, eleTy); + if (sem == &llvm::APFloat::IEEEsingle()) { + // <2 x t> vector of 2 eleTy + marshal.emplace_back(fir::VectorType::get(2, eleTy), AT{}); + } else if (sem == &llvm::APFloat::IEEEdouble()) { + // two distinct double arguments + marshal.emplace_back(eleTy, AT{}); + marshal.emplace_back(eleTy, AT{}); + } else { + llvm_unreachable("not implemented"); + } + return marshal; + } + + CodeGenSpecifics::Marshalling + complexReturnType(mlir::Type eleTy) const override { + CodeGenSpecifics::Marshalling marshal; + const auto *sem = &floatToSemantics(kindMap, eleTy); + if (sem == &llvm::APFloat::IEEEsingle()) { + // <2 x t> vector of 2 eleTy + marshal.emplace_back(fir::VectorType::get(2, eleTy), AT{}); + } else if (sem == &llvm::APFloat::IEEEdouble()) { + // { double, double } struct of 2 double + mlir::TypeRange range = {eleTy, eleTy}; + marshal.emplace_back(mlir::TupleType::get(eleTy.getContext(), range), + AT{}); + } else { + llvm_unreachable("not implemented"); + } + return marshal; + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// AArch64 (AArch64 bit) linux target specifics. 
+//===----------------------------------------------------------------------===// + +namespace { +struct TargetAArch64 : public GenericTarget { + using GenericTarget::GenericTarget; + + static constexpr int defaultWidth = 64; + + CodeGenSpecifics::Marshalling + complexArgumentType(mlir::Type eleTy) const override { + CodeGenSpecifics::Marshalling marshal; + const auto *sem = &floatToSemantics(kindMap, eleTy); + if (sem == &llvm::APFloat::IEEEsingle()) { + // <2 x t> vector of 2 eleTy + marshal.emplace_back(fir::VectorType::get(2, eleTy), AT{}); + } else if (sem == &llvm::APFloat::IEEEdouble()) { + // two distinct double arguments + marshal.emplace_back(eleTy, AT{}); + marshal.emplace_back(eleTy, AT{}); + } else { + llvm_unreachable("not implemented"); + } + return marshal; + } + + CodeGenSpecifics::Marshalling + complexReturnType(mlir::Type eleTy) const override { + CodeGenSpecifics::Marshalling marshal; + const auto *sem = &floatToSemantics(kindMap, eleTy); + if (sem == &llvm::APFloat::IEEEsingle()) { + // <2 x t> vector of 2 eleTy + marshal.emplace_back(fir::VectorType::get(2, eleTy), AT{}); + } else if (sem == &llvm::APFloat::IEEEdouble()) { + // { double, double } struct of 2 double + mlir::TypeRange range = {eleTy, eleTy}; + marshal.emplace_back(mlir::TupleType::get(eleTy.getContext(), range), + AT{}); + } else { + llvm_unreachable("not implemented"); + } + return marshal; + } +}; +} // namespace + +// Instantiate the overloaded target instance based on the triple value. +// Currently, the implementation only instantiates `i386-unknown-linux-gnu` and +// `x86_64-unknown-linux-gnu` like triples. Other targets should be added to +// this file as needed. 
+std::unique_ptr +fir::CodeGenSpecifics::get(mlir::MLIRContext *ctx, llvm::Triple &trp, + KindMapping &kindMap) { + switch (trp.getArch()) { + default: + break; + case llvm::Triple::ArchType::x86: + switch (trp.getOS()) { + default: + break; + case llvm::Triple::OSType::Linux: + case llvm::Triple::OSType::Darwin: + return std::make_unique(ctx, trp, kindMap); + } + break; + case llvm::Triple::ArchType::x86_64: + switch (trp.getOS()) { + default: + break; + case llvm::Triple::OSType::Linux: + case llvm::Triple::OSType::Darwin: + return std::make_unique(ctx, trp, kindMap); + } + break; + case llvm::Triple::ArchType::aarch64: + switch (trp.getOS()) { + default: + break; + case llvm::Triple::OSType::Linux: + case llvm::Triple::OSType::Darwin: + return std::make_unique(ctx, trp, kindMap); + } + break; + } + llvm::report_fatal_error("target not implemented"); +} Index: flang/lib/Optimizer/CodeGen/TargetRewrite.cpp =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -0,0 +1,727 @@ +//===-- TargetRewrite.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "Target.h" +#include "flang/Optimizer/CodeGen/CodeGen.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Support/FIRContext.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" + +//===----------------------------------------------------------------------===// +// Target rewrite: rewriting of ops to make target-specific lowerings manifest. +// LLVM expects different lowering idioms to be used for distinct target +// triples. These distinctions are handled by this pass. +//===----------------------------------------------------------------------===// + +using namespace fir; + +#define DEBUG_TYPE "flang-target-rewrite" + +namespace { + +/// Fixups for updating a FuncOp's arguments and return values. +struct FixupTy { + // clang-format off + enum class Codes { + ArgumentAsLoad, ArgumentType, CharPair, ReturnAsStore, ReturnType, + Split, Trailing + }; + // clang-format on + + FixupTy(Codes code, std::size_t index, std::size_t second = 0) + : code{code}, index{index}, second{second} {} + FixupTy(Codes code, std::size_t index, + std::function &&finalizer) + : code{code}, index{index}, finalizer{finalizer} {} + FixupTy(Codes code, std::size_t index, std::size_t second, + std::function &&finalizer) + : code{code}, index{index}, second{second}, finalizer{finalizer} {} + + Codes code; + std::size_t index; + std::size_t second{}; + llvm::Optional> finalizer{}; +}; // namespace + +/// Target-specific rewriting of the IR. 
This is a prerequisite pass to code
+/// generation that traverses the IR and modifies types and operations to a
+/// form that is appropriate for the specific target. LLVM IR has specific
+/// idioms that are used for distinct target processor and ABI combinations.
+class TargetRewrite : public TargetRewriteBase {
+public:
+  TargetRewrite(const TargetRewriteOptions &options) {
+    noCharacterConversion = options.noCharacterConversion;
+    noComplexConversion = options.noComplexConversion;
+  }
+
+  void runOnOperation() override final {
+    auto &context = getContext();
+    mlir::OpBuilder rewriter(&context);
+    auto mod = getModule();
+    auto specifics = CodeGenSpecifics::get(getOperation().getContext(),
+                                           *getTargetTriple(getOperation()),
+                                           *getKindMapping(getOperation()));
+    setMembers(specifics.get(), &rewriter);
+
+    // Perform type conversion on signatures and call sites.
+    if (mlir::failed(convertTypes(mod))) {
+      mlir::emitError(mlir::UnknownLoc::get(&context),
+                      "error in converting types to target abi");
+      signalPassFailure();
+    }
+
+    // Convert ops in target-specific patterns.
+    mod.walk([&](mlir::Operation *op) {
+      if (auto call = dyn_cast(op)) {
+        if (!hasPortableSignature(call.getFunctionType()))
+          convertCallOp(call);
+      } else if (auto dispatch = dyn_cast(op)) {
+        if (!hasPortableSignature(dispatch.getFunctionType()))
+          convertCallOp(dispatch);
+      } else if (auto addr = dyn_cast(op)) {
+        if (addr.getType().isa() &&
+            !hasPortableSignature(addr.getType()))
+          convertAddrOp(addr);
+      }
+    });
+
+    clearMembers();
+  }
+
+  mlir::ModuleOp getModule() { return getOperation(); }
+
+  template
+  std::function
+  rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) {
+    auto m = specifics->complexReturnType(ty.getElementType());
+    // Currently targets mandate COMPLEX is a single aggregate or packed
+    // scalar, including the sret case.
+ assert(m.size() == 1 && "target lowering of complex return not supported"); + auto resTy = std::get(m[0]); + auto attr = std::get(m[0]); + auto loc = mlir::UnknownLoc::get(resTy.getContext()); + if (attr.isSRet()) { + assert(isa_ref_type(resTy)); + mlir::Value stack = + rewriter->create(loc, dyn_cast_ptrEleTy(resTy)); + newInTys.push_back(resTy); + newOpers.push_back(stack); + return [=](mlir::Operation *) -> mlir::Value { + auto memTy = ReferenceType::get(ty); + auto cast = rewriter->create(loc, memTy, stack); + return rewriter->create(loc, cast); + }; + } + newResTys.push_back(resTy); + return [=](mlir::Operation *call) -> mlir::Value { + auto mem = rewriter->create(loc, resTy); + rewriter->create(loc, call->getResult(0), mem); + auto memTy = ReferenceType::get(ty); + auto cast = rewriter->create(loc, memTy, mem); + return rewriter->create(loc, cast); + }; + } + + template + void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, + C &newOpers) { + auto m = specifics->complexArgumentType(ty.getElementType()); + auto *ctx = ty.getContext(); + auto loc = mlir::UnknownLoc::get(ctx); + if (m.size() == 1) { + // COMPLEX is a single aggregate + auto resTy = std::get(m[0]); + auto attr = std::get(m[0]); + auto oldRefTy = ReferenceType::get(ty); + if (attr.isByVal()) { + auto mem = rewriter->create(loc, ty); + rewriter->create(loc, oper, mem); + newOpers.push_back(rewriter->create(loc, resTy, mem)); + } else { + auto mem = rewriter->create(loc, resTy); + auto cast = rewriter->create(loc, oldRefTy, mem); + rewriter->create(loc, oper, cast); + newOpers.push_back(rewriter->create(loc, mem)); + } + newInTys.push_back(resTy); + } else { + assert(m.size() == 2); + // COMPLEX is split into 2 separate arguments + auto iTy = rewriter->getIntegerType(32); + for (auto e : llvm::enumerate(m)) { + auto &tup = e.value(); + auto ty = std::get(tup); + auto index = e.index(); + mlir::Value idx = rewriter->create( + loc, iTy, mlir::IntegerAttr::get(iTy, index)); + auto val 
= rewriter->create(loc, ty, oper, idx); + newInTys.push_back(ty); + newOpers.push_back(val); + } + } + } + + // Convert fir.call and fir.dispatch Ops. + template + void convertCallOp(A callOp) { + auto fnTy = callOp.getFunctionType(); + auto loc = callOp.getLoc(); + rewriter->setInsertionPoint(callOp); + llvm::SmallVector newResTys; + llvm::SmallVector newInTys; + llvm::SmallVector newOpers; + + // If the call is indirect, the first argument must still be the function + // to call. + int dropFront = 0; + if constexpr (std::is_same_v, fir::CallOp>) { + if (!callOp.callee().hasValue()) { + newInTys.push_back(fnTy.getInput(0)); + newOpers.push_back(callOp.getOperand(0)); + dropFront = 1; + } + } + + // Determine the rewrite function, `wrap`, for the result value. + llvm::Optional> wrap; + if (fnTy.getResults().size() == 1) { + mlir::Type ty = fnTy.getResult(0); + llvm::TypeSwitch(ty) + .template Case([&](fir::ComplexType cmplx) { + wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, + newOpers); + }) + .template Case([&](mlir::ComplexType cmplx) { + wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, + newOpers); + }) + .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); + } else if (fnTy.getResults().size() > 1) { + // If the function is returning more than 1 result, do not perform any + // target-specific lowering. (FIXME?) This may need to be revisited. 
+ newResTys.insert(newResTys.end(), fnTy.getResults().begin(), + fnTy.getResults().end()); + } + + llvm::SmallVector trailingInTys; + llvm::SmallVector trailingOpers; + for (auto e : llvm::enumerate( + llvm::zip(fnTy.getInputs().drop_front(dropFront), + callOp.getOperands().drop_front(dropFront)))) { + mlir::Type ty = std::get<0>(e.value()); + mlir::Value oper = std::get<1>(e.value()); + unsigned index = e.index(); + llvm::TypeSwitch(ty) + .template Case([&](BoxCharType boxTy) { + bool sret; + if constexpr (std::is_same_v, fir::CallOp>) { + sret = callOp.callee() && + functionArgIsSRet(index, + getModule().lookupSymbol( + *callOp.callee())); + } else { + // TODO: dispatch case; how do we put arguments on a call? + // We cannot put both an sret and the dispatch object first. + sret = false; + llvm_unreachable("not implemented"); + } + auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); + auto unbox = + rewriter->create(loc, std::get(m[0]), + std::get(m[1]), oper); + // unboxed CHARACTER arguments + for (auto e : llvm::enumerate(m)) { + unsigned idx = e.index(); + auto attr = std::get(e.value()); + auto argTy = std::get(e.value()); + if (attr.isAppend()) { + trailingInTys.push_back(argTy); + trailingOpers.push_back(unbox.getResult(idx)); + } else { + newInTys.push_back(argTy); + newOpers.push_back(unbox.getResult(idx)); + } + } + }) + .template Case([&](fir::ComplexType cmplx) { + rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); + }) + .template Case([&](mlir::ComplexType cmplx) { + rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); + }) + .Default([&](mlir::Type ty) { + newInTys.push_back(ty); + newOpers.push_back(oper); + }); + } + newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); + newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); + if constexpr (std::is_same_v, fir::CallOp>) { + fir::CallOp newCall; + if (callOp.callee().hasValue()) { + newCall = rewriter->create(loc, 
callOp.callee().getValue(), + newResTys, newOpers); + } else { + // Force new type on the input operand. + newOpers[0].setType(mlir::FunctionType::get( + callOp.getContext(), + mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); + newCall = rewriter->create(loc, newResTys, newOpers); + } + LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); + if (wrap.hasValue()) + replaceOp(callOp, (*wrap)(newCall.getOperation())); + else + replaceOp(callOp, newCall.getResults()); + } else { + // A is fir::DispatchOp + llvm_unreachable("not implemented"); // TODO + } + } + + // Result type fixup for fir::ComplexType and mlir::ComplexType + template + void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) { + if (noComplexConversion) { + newResTys.push_back(cmplx); + } else { + for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) { + auto argTy = std::get(tup); + if (std::get(tup).isSRet()) + newInTys.push_back(argTy); + else + newResTys.push_back(argTy); + } + } + } + + // Argument type fixup for fir::ComplexType and mlir::ComplexType + template + void lowerComplexSignatureArg(A cmplx, B &newInTys) { + if (noComplexConversion) + newInTys.push_back(cmplx); + else + for (auto &tup : specifics->complexArgumentType(cmplx.getElementType())) + newInTys.push_back(std::get(tup)); + } + + /// Taking the address of a function. Modify the signature as needed. 
+ void convertAddrOp(AddrOfOp addrOp) { + rewriter->setInsertionPoint(addrOp); + auto addrTy = addrOp.getType().cast(); + llvm::SmallVector newResTys; + llvm::SmallVector newInTys; + for (mlir::Type ty : addrTy.getResults()) { + llvm::TypeSwitch(ty) + .Case([&](fir::ComplexType ty) { + lowerComplexSignatureRes(ty, newResTys, newInTys); + }) + .Case([&](mlir::ComplexType ty) { + lowerComplexSignatureRes(ty, newResTys, newInTys); + }) + .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); + } + llvm::SmallVector trailingInTys; + for (mlir::Type ty : addrTy.getInputs()) { + llvm::TypeSwitch(ty) + .Case([&](BoxCharType box) { + if (noCharacterConversion) { + newInTys.push_back(box); + } else { + for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) { + auto attr = std::get(tup); + auto argTy = std::get(tup); + auto &vec = attr.isAppend() ? trailingInTys : newInTys; + vec.push_back(argTy); + } + } + }) + .Case([&](fir::ComplexType ty) { + lowerComplexSignatureArg(ty, newInTys); + }) + .Case([&](mlir::ComplexType ty) { + lowerComplexSignatureArg(ty, newInTys); + }) + .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); + } + // append trailing input types + newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); + // replace this op with a new one with the updated signature + auto newTy = rewriter->getFunctionType(newInTys, newResTys); + auto newOp = + rewriter->create(addrOp.getLoc(), newTy, addrOp.symbol()); + LLVM_DEBUG(llvm::dbgs() + << "replacing " << addrOp << " with " << newOp << '\n'); + replaceOp(addrOp, newOp.getResult()); + } + + /// Convert the type signatures on all the functions present in the module. + /// As the type signature is being changed, this must also update the + /// function itself to use any new arguments, etc. 
+  mlir::LogicalResult convertTypes(mlir::ModuleOp mod) {
+    for (auto fn : mod.getOps())
+      convertSignature(fn);
+    return mlir::success();
+  }
+
+  /// If the signature does not need any special target-specific conversions,
+  /// then it is considered portable for any target, and this function will
+  /// return `true`. Otherwise, the signature is not portable and `false` is
+  /// returned.
+  bool hasPortableSignature(mlir::Type signature) {
+    assert(signature.isa());
+    auto func = signature.dyn_cast();
+    for (auto ty : func.getResults())
+      if ((ty.isa() && !noCharacterConversion) ||
+          (isa_complex(ty) && !noComplexConversion)) {
+        LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
+        return false;
+      }
+    for (auto ty : func.getInputs())
+      if ((ty.isa() && !noCharacterConversion) ||
+          (isa_complex(ty) && !noComplexConversion)) {
+        LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
+        return false;
+      }
+    return true;
+  }
+
+  /// Rewrite the signatures and body of the `FuncOp`s in the module for
+  /// the immediately subsequent target code gen.
+ void convertSignature(mlir::FuncOp func) { + auto funcTy = func.getType().cast(); + if (hasPortableSignature(funcTy)) + return; + llvm::SmallVector newResTys; + llvm::SmallVector newInTys; + llvm::SmallVector fixups; + + // Convert return value(s) + for (auto ty : funcTy.getResults()) + llvm::TypeSwitch(ty) + .Case([&](fir::ComplexType cmplx) { + if (noComplexConversion) + newResTys.push_back(cmplx); + else + doComplexReturn(func, cmplx, newResTys, newInTys, fixups); + }) + .Case([&](mlir::ComplexType cmplx) { + if (noComplexConversion) + newResTys.push_back(cmplx); + else + doComplexReturn(func, cmplx, newResTys, newInTys, fixups); + }) + .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); + + // Convert arguments + llvm::SmallVector trailingTys; + for (auto e : llvm::enumerate(funcTy.getInputs())) { + auto ty = e.value(); + unsigned index = e.index(); + llvm::TypeSwitch(ty) + .Case([&](BoxCharType boxTy) { + if (noCharacterConversion) { + newInTys.push_back(boxTy); + } else { + // Convert a CHARACTER argument type. This can involve separating + // the pointer and the LEN into two arguments and moving the LEN + // argument to the end of the arg list. 
+ bool sret = functionArgIsSRet(index, func); + for (auto e : llvm::enumerate(specifics->boxcharArgumentType( + boxTy.getEleTy(), sret))) { + auto &tup = e.value(); + auto index = e.index(); + auto attr = std::get(tup); + auto argTy = std::get(tup); + if (attr.isAppend()) { + trailingTys.push_back(argTy); + } else { + if (sret) { + fixups.emplace_back(FixupTy::Codes::CharPair, + newInTys.size(), index); + } else { + fixups.emplace_back(FixupTy::Codes::Trailing, + newInTys.size(), trailingTys.size()); + } + newInTys.push_back(argTy); + } + } + } + }) + .Case([&](fir::ComplexType cmplx) { + if (noComplexConversion) + newInTys.push_back(cmplx); + else + doComplexArg(func, cmplx, newInTys, fixups); + }) + .Case([&](mlir::ComplexType cmplx) { + if (noComplexConversion) + newInTys.push_back(cmplx); + else + doComplexArg(func, cmplx, newInTys, fixups); + }) + .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); + } + + if (!func.empty()) { + // If the function has a body, then apply the fixups to the arguments and + // return ops as required. These fixups are done in place. + auto loc = func.getLoc(); + const auto fixupSize = fixups.size(); + const auto oldArgTys = func.getType().getInputs(); + int offset = 0; + for (std::remove_const_t i = 0; i < fixupSize; ++i) { + const auto &fixup = fixups[i]; + switch (fixup.code) { + case FixupTy::Codes::ArgumentAsLoad: { + // Argument was pass-by-value, but is now pass-by-reference and + // possibly with a different element type. 
+        auto newArg =
+            func.front().insertArgument(fixup.index, newInTys[fixup.index]);
+        rewriter->setInsertionPointToStart(&func.front());
+        auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
+        auto cast = rewriter->create(loc, oldArgTy, newArg);
+        auto load = rewriter->create(loc, cast);
+        func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
+        func.front().eraseArgument(fixup.index + 1);
+      } break;
+      case FixupTy::Codes::ArgumentType: {
+        // Argument is pass-by-value, but its type has likely been modified to
+        // suit the target ABI convention.
+        auto newArg =
+            func.front().insertArgument(fixup.index, newInTys[fixup.index]);
+        rewriter->setInsertionPointToStart(&func.front());
+        auto mem =
+            rewriter->create(loc, newInTys[fixup.index]);
+        rewriter->create(loc, newArg, mem);
+        auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
+        auto cast = rewriter->create(loc, oldArgTy, mem);
+        mlir::Value load = rewriter->create(loc, cast);
+        func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
+        func.front().eraseArgument(fixup.index + 1);
+        LLVM_DEBUG(llvm::dbgs()
+                   << "old argument: " << oldArgTy.getEleTy()
+                   << ", repl: " << load << ", new argument: "
+                   << func.getArgument(fixup.index).getType() << '\n');
+      } break;
+      case FixupTy::Codes::CharPair: {
+        // The FIR boxchar argument has been split into a pair of distinct
+        // arguments that are in juxtaposition to each other.
+ auto newArg = + func.front().insertArgument(fixup.index, newInTys[fixup.index]); + if (fixup.second == 1) { + rewriter->setInsertionPointToStart(&func.front()); + auto boxTy = oldArgTys[fixup.index - offset - fixup.second]; + auto box = rewriter->create( + loc, boxTy, func.front().getArgument(fixup.index - 1), newArg); + func.getArgument(fixup.index + 1).replaceAllUsesWith(box); + func.front().eraseArgument(fixup.index + 1); + offset++; + } + } break; + case FixupTy::Codes::ReturnAsStore: { + // The value being returned is now being returned in memory (callee + // stack space) through a hidden reference argument. + auto newArg = + func.front().insertArgument(fixup.index, newInTys[fixup.index]); + offset++; + func.walk([&](mlir::ReturnOp ret) { + rewriter->setInsertionPoint(ret); + auto oldOper = ret.getOperand(0); + auto oldOperTy = ReferenceType::get(oldOper.getType()); + auto cast = rewriter->create(loc, oldOperTy, newArg); + rewriter->create(loc, oldOper, cast); + rewriter->create(loc); + ret.erase(); + }); + } break; + case FixupTy::Codes::ReturnType: { + // The function is still returning a value, but its type has likely + // changed to suit the target ABI convention. + func.walk([&](mlir::ReturnOp ret) { + rewriter->setInsertionPoint(ret); + auto oldOper = ret.getOperand(0); + auto oldOperTy = ReferenceType::get(oldOper.getType()); + auto mem = + rewriter->create(loc, newResTys[fixup.index]); + auto cast = rewriter->create(loc, oldOperTy, mem); + rewriter->create(loc, oldOper, cast); + mlir::Value load = rewriter->create(loc, mem); + rewriter->create(loc, load); + ret.erase(); + }); + } break; + case FixupTy::Codes::Split: { + // The FIR argument has been split into a pair of distinct arguments + // that are in juxtaposition to each other. (For COMPLEX value.) 
+ auto newArg = + func.front().insertArgument(fixup.index, newInTys[fixup.index]); + if (fixup.second == 1) { + rewriter->setInsertionPointToStart(&func.front()); + auto cplxTy = oldArgTys[fixup.index - offset - fixup.second]; + auto undef = rewriter->create(loc, cplxTy); + auto iTy = rewriter->getIntegerType(32); + mlir::Value zero = rewriter->create( + loc, iTy, mlir::IntegerAttr::get(iTy, 0)); + mlir::Value one = rewriter->create( + loc, iTy, mlir::IntegerAttr::get(iTy, 1)); + auto cplx1 = rewriter->create( + loc, cplxTy, undef, func.front().getArgument(fixup.index - 1), + zero); + auto cplx = rewriter->create(loc, cplxTy, cplx1, + newArg, one); + func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx); + func.front().eraseArgument(fixup.index + 1); + offset++; + } + } break; + case FixupTy::Codes::Trailing: { + // The FIR argument has been split into a pair of distinct arguments. + // The first part of the pair appears in the original argument + // position. The second part of the pair is appended after all the + // original arguments. (Boxchar arguments.) + auto newBufArg = + func.front().insertArgument(fixup.index, newInTys[fixup.index]); + auto newLenArg = func.front().addArgument(trailingTys[fixup.second]); + auto boxTy = oldArgTys[fixup.index - offset]; + rewriter->setInsertionPointToStart(&func.front()); + auto box = + rewriter->create(loc, boxTy, newBufArg, newLenArg); + func.getArgument(fixup.index + 1).replaceAllUsesWith(box); + func.front().eraseArgument(fixup.index + 1); + } break; + } + } + } + + // Set the new type and finalize the arguments, etc. 
+ newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end()); + auto newFuncTy = + mlir::FunctionType::get(func.getContext(), newInTys, newResTys); + LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n'); + func.setType(newFuncTy); + + for (auto &fixup : fixups) + if (fixup.finalizer) + (*fixup.finalizer)(func); + } + + inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) { + if (auto attr = func.getArgAttrOfType(index, "llvm.sret")) + return true; + return false; + } + + /// Convert a complex return value. This can involve converting the return + /// value to a "hidden" first argument or packing the complex into a wide + /// GPR. + template + void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys, + C &fixups) { + if (noComplexConversion) { + newResTys.push_back(cmplx); + return; + } + auto m = specifics->complexReturnType(cmplx.getElementType()); + assert(m.size() == 1); + auto &tup = m[0]; + auto attr = std::get(tup); + auto argTy = std::get(tup); + if (attr.isSRet()) { + unsigned argNo = newInTys.size(); + fixups.emplace_back( + FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) { + func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); + }); + newInTys.push_back(argTy); + return; + } + fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size()); + newResTys.push_back(argTy); + } + + /// Convert a complex argument value. This can involve storing the value to + /// a temporary memory location or factoring the value into two distinct + /// arguments. + template + void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) { + if (noComplexConversion) { + newInTys.push_back(cmplx); + return; + } + auto m = specifics->complexArgumentType(cmplx.getElementType()); + const auto fixupCode = + m.size() > 1 ? 
FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; + for (auto e : llvm::enumerate(m)) { + auto &tup = e.value(); + auto index = e.index(); + auto attr = std::get(tup); + auto argTy = std::get(tup); + auto argNo = newInTys.size(); + if (attr.isByVal()) { + if (auto align = attr.getAlignment()) + fixups.emplace_back( + FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) { + func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); + func.setArgAttr(argNo, "llvm.align", + rewriter->getIntegerAttr( + rewriter->getIntegerType(32), align)); + }); + else + fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), + [=](mlir::FuncOp func) { + func.setArgAttr(argNo, "llvm.byval", + rewriter->getUnitAttr()); + }); + } else { + if (auto align = attr.getAlignment()) + fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) { + func.setArgAttr( + argNo, "llvm.align", + rewriter->getIntegerAttr(rewriter->getIntegerType(32), align)); + }); + else + fixups.emplace_back(fixupCode, argNo, index); + } + newInTys.push_back(argTy); + } + } + +private: + // Replace `op` and remove it. 
+ void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { + op->replaceAllUsesWith(newValues); + op->dropAllReferences(); + op->erase(); + } + + inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) { + specifics = s; + rewriter = r; + } + + inline void clearMembers() { setMembers(nullptr, nullptr); } + + CodeGenSpecifics *specifics{}; + mlir::OpBuilder *rewriter; +}; // namespace +} // namespace + +std::unique_ptr> +fir::createFirTargetRewritePass(const TargetRewriteOptions &options) { + return std::make_unique(options); +} Index: flang/lib/Optimizer/CodeGen/TypeConverter.h =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/TypeConverter.h @@ -0,0 +1,383 @@ +//===-- TypeConverter.h -- type conversion ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef OPTIMIZER_CODEGEN_TYPECONVERTER_H +#define OPTIMIZER_CODEGEN_TYPECONVERTER_H + +namespace fir { + +/// FIR type converter +/// This converts FIR types to LLVM types (for now) +class LLVMTypeConverter : public mlir::LLVMTypeConverter { +public: + LLVMTypeConverter(mlir::ModuleOp module) + : mlir::LLVMTypeConverter(module.getContext()), + kindMapping(*getKindMapping(module)), uniquer(*getNameUniquer(module)), + specifics(CodeGenSpecifics::get(module.getContext(), + *getTargetTriple(module), + *getKindMapping(module))) { + LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n"); + + // Each conversion should return a value of type mlir::Type. 
+ addConversion([&](BoxType box) { return convertBoxType(box); }); + addConversion([&](BoxCharType boxchar) { + LLVM_DEBUG(llvm::dbgs() << "type convert: " << boxchar << '\n'); + return unwrap( + convertType(specifics->boxcharMemoryType(boxchar.getEleTy()))); + }); + addConversion( + [&](BoxProcType boxproc) { return convertBoxProcType(boxproc); }); + addConversion( + [&](fir::CharacterType charTy) { return convertCharType(charTy); }); + addConversion( + [&](mlir::ComplexType cmplx) { return convertComplexType(cmplx); }); + addConversion( + [&](fir::ComplexType cmplx) { return convertComplexType(cmplx); }); + addConversion( + [&](fir::RecordType derived) { return convertRecordType(derived); }); + addConversion([&](fir::FieldType field) { + return mlir::IntegerType::get(field.getContext(), 32); + }); + addConversion([&](HeapType heap) { return convertPointerLike(heap); }); + addConversion([&](fir::IntegerType intTy) { + return mlir::IntegerType::get( + &getContext(), kindMapping.getIntegerBitsize(intTy.getFKind())); + }); + addConversion([&](LenType field) { + return mlir::IntegerType::get(field.getContext(), 32); + }); + addConversion([&](fir::LogicalType boolTy) { + return mlir::IntegerType::get( + &getContext(), kindMapping.getLogicalBitsize(boolTy.getFKind())); + }); + addConversion( + [&](fir::PointerType pointer) { return convertPointerLike(pointer); }); + addConversion( + [&](fir::RealType real) { return convertRealType(real.getFKind()); }); + addConversion( + [&](fir::ReferenceType ref) { return convertPointerLike(ref); }); + addConversion( + [&](SequenceType sequence) { return convertSequenceType(sequence); }); + addConversion([&](TypeDescType tdesc) { + return convertTypeDescType(tdesc.getContext()); + }); + addConversion([&](fir::VectorType vecTy) { + return mlir::VectorType::get(llvm::ArrayRef(vecTy.getLen()), + unwrap(convertType(vecTy.getEleTy()))); + }); + addConversion([&](mlir::TupleType tuple) { + LLVM_DEBUG(llvm::dbgs() << "type convert: " << 
tuple << '\n'); + llvm::SmallVector inMembers; + tuple.getFlattenedTypes(inMembers); + llvm::SmallVector members; + for (auto mem : inMembers) + members.push_back(convertType(mem).cast()); + return mlir::LLVM::LLVMStructType::getLiteral(&getContext(), members, + /*isPacked=*/false); + }); + addConversion([&](mlir::NoneType none) { + return mlir::LLVM::LLVMStructType::getLiteral( + none.getContext(), llvm::None, /*isPacked=*/false); + }); + + // FIXME: https://reviews.llvm.org/D82831 introduced an automatic + // materialization of conversion around function calls that is not working + // well with fir lowering to llvm (incorrect llvm.mlir.cast are inserted). + // Workaround until better analysis: register a handler that does not insert + // any conversions. + addSourceMaterialization( + [&](mlir::OpBuilder &builder, mlir::Type resultType, + mlir::ValueRange inputs, + mlir::Location loc) -> llvm::Optional { + if (inputs.size() != 1) + return llvm::None; + return inputs[0]; + }); + // Similar FIXME workaround here (needed for compare.fir/select-type.fir + // tests). + addTargetMaterialization( + [&](mlir::OpBuilder &builder, mlir::Type resultType, + mlir::ValueRange inputs, + mlir::Location loc) -> llvm::Optional { + if (inputs.size() != 1) + return llvm::None; + return inputs[0]; + }); + } + + // i32 is used here because LLVM wants i32 constants when indexing into struct + // types. Indexing into other aggregate types is more flexible. + mlir::Type offsetType() { return mlir::IntegerType::get(&getContext(), 32); } + + // i64 can be used to index into aggregates like arrays + mlir::Type indexType() { return mlir::IntegerType::get(&getContext(), 64); } + + // TODO + bool requiresExtendedDesc() { return false; } + + // Magic value to indicate we do not know the rank of an entity, either + // because it is assumed rank or because we have not determined it yet. 
+ static constexpr int unknownRank() { return -1; } + // This corresponds to the descriptor as defined ISO_Fortran_binding.h and the + // addendum defined in descriptor.h. + mlir::Type convertBoxType(BoxType box, int rank = unknownRank()) { + // (buffer*, ele-size, rank, type-descriptor, attribute, [dims]) + SmallVector parts; + mlir::Type ele = box.getEleTy(); + // remove fir.heap/fir.ref/fir.ptr + if (auto removeIndirection = fir::dyn_cast_ptrEleTy(ele)) + ele = removeIndirection; + auto eleTy = unwrap(convertType(ele)); + // buffer* + if (ele.isa() && eleTy.isa()) + parts.push_back(eleTy); + else + parts.push_back(mlir::LLVM::LLVMPointerType::get(eleTy)); + parts.push_back(getDescFieldTypeModel<1>()(&getContext())); + parts.push_back(getDescFieldTypeModel<2>()(&getContext())); + parts.push_back(getDescFieldTypeModel<3>()(&getContext())); + parts.push_back(getDescFieldTypeModel<4>()(&getContext())); + parts.push_back(getDescFieldTypeModel<5>()(&getContext())); + parts.push_back(getDescFieldTypeModel<6>()(&getContext())); + if (rank == unknownRank()) { + if (auto seqTy = ele.dyn_cast()) + rank = seqTy.getDimension(); + else + rank = 0; + } + if (rank > 0) { + auto rowTy = getDescFieldTypeModel<7>()(&getContext()); + parts.push_back(mlir::LLVM::LLVMArrayType::get(rowTy, rank)); + } + // opt-type-ptr: i8* (see fir.tdesc) + if (requiresExtendedDesc()) { + parts.push_back(getExtendedDescFieldTypeModel<8>()(&getContext())); + parts.push_back(getExtendedDescFieldTypeModel<9>()(&getContext())); + auto rowTy = getExtendedDescFieldTypeModel<10>()(&getContext()); + unsigned numLenParams = 0; // FIXME + parts.push_back(mlir::LLVM::LLVMArrayType::get(rowTy, numLenParams)); + TODO("extended descriptor"); + } + return mlir::LLVM::LLVMPointerType::get( + mlir::LLVM::LLVMStructType::getLiteral(&getContext(), parts, + /*isPacked=*/false)); + } + + // fir.boxproc --> llvm<"{ any*, i8* }"> + mlir::Type convertBoxProcType(BoxProcType boxproc) { + auto funcTy = 
convertType(boxproc.getEleTy()); + auto ptrTy = mlir::LLVM::LLVMPointerType::get(unwrap(funcTy)); + auto i8PtrTy = mlir::LLVM::LLVMPointerType::get( + mlir::IntegerType::get(&getContext(), 8)); + llvm::SmallVector tuple = {ptrTy, i8PtrTy}; + return mlir::LLVM::LLVMStructType::getLiteral(&getContext(), tuple, + /*isPacked=*/false); + } + + unsigned characterBitsize(fir::CharacterType charTy) { + return kindMapping.getCharacterBitsize(charTy.getFKind()); + } + + // fir.char --> llvm<"ix*"> where ix is scaled by kind mapping + mlir::Type convertCharType(fir::CharacterType charTy) { + auto iTy = mlir::IntegerType::get(&getContext(), characterBitsize(charTy)); + if (charTy.getLen() == fir::CharacterType::unknownLen()) + return iTy; + return mlir::LLVM::LLVMArrayType::get(iTy, charTy.getLen()); + } + + // Convert a complex value's element type based on its Fortran kind. + mlir::Type convertComplexPartType(fir::KindTy kind) { + auto realID = kindMapping.getComplexTypeID(kind); + return fromRealTypeID(realID, kind); + } + + // Use the target specifics to figure out how to map complex to LLVM IR. The + // use of complex values in function signatures is handled before conversion + // to LLVM IR dialect here. + // + // fir.complex | std.complex --> llvm<"{t,t}"> + template + mlir::Type convertComplexType(C cmplx) { + LLVM_DEBUG(llvm::dbgs() << "type convert: " << cmplx << '\n'); + auto eleTy = cmplx.getElementType(); + return unwrap(convertType(specifics->complexMemoryType(eleTy))); + } + + // Get the default size of INTEGER. (The default size might have been set on + // the command line.) 
+ mlir::Type getDefaultInt() { + return mlir::IntegerType::get( + &getContext(), + kindMapping.getIntegerBitsize(kindMapping.defaultIntegerKind())); + } + + static bool hasDynamicSize(mlir::Type type) { + if (auto charTy = type.dyn_cast()) + return charTy.getLen() == fir::CharacterType::unknownLen(); + return false; + } + + template + mlir::Type convertPointerLike(A &ty) { + mlir::Type eleTy = ty.getEleTy(); + // A sequence type is a special case. A sequence of runtime size on its + // interior dimensions lowers to a memory reference. In that case, we + // degenerate the array and do not want a the type to become `T**` but + // merely `T*`. + if (auto seqTy = eleTy.dyn_cast()) { + if (!seqTy.hasConstantShape() || hasDynamicSize(seqTy.getEleTy())) { + if (seqTy.hasConstantInterior()) + return unwrap(convertType(seqTy)); + eleTy = seqTy.getEleTy(); + } + } + // fir.ref is a special case because fir.box type is already + // a pointer to a Fortran descriptor at the LLVM IR level. This implies + // that a fir.ref, that is the address of fir.box is actually + // the same as a fir.box at the LLVM level. + // The distinction is kept in fir to denote when a descriptor is expected + // to be mutable (fir.ref) and when it is not (fir.box). + if (eleTy.isa()) + return unwrap(convertType(eleTy)); + + return mlir::LLVM::LLVMPointerType::get(unwrap(convertType(eleTy))); + } + + // convert a front-end kind value to either a std or LLVM IR dialect type + // fir.real --> llvm.anyfloat where anyfloat is a kind mapping + mlir::Type convertRealType(fir::KindTy kind) { + return fromRealTypeID(kindMapping.getRealTypeID(kind), kind); + } + + // fir.type --> llvm<"%name = { ty... 
}"> + mlir::Type convertRecordType(fir::RecordType derived) { + auto name = derived.getName(); + // The cache is needed to keep a unique mapping from name -> StructType + auto iter = identStructCache.find(name); + if (iter != identStructCache.end()) + return iter->second; + auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name); + identStructCache[name] = st; + llvm::SmallVector members; + for (auto mem : derived.getTypeList()) + members.push_back(convertType(mem.second).cast()); + st.setBody(members, /*isPacked=*/false); + return st; + } + + // fir.array --> llvm<"[...[c x any]]"> + mlir::Type convertSequenceType(SequenceType seq) { + auto baseTy = unwrap(convertType(seq.getEleTy())); + if (hasDynamicSize(seq.getEleTy())) + return mlir::LLVM::LLVMPointerType::get(baseTy); + auto shape = seq.getShape(); + auto constRows = seq.getConstantRows(); + if (constRows) { + decltype(constRows) i = constRows; + for (auto e : shape) { + baseTy = mlir::LLVM::LLVMArrayType::get(baseTy, e); + if (--i == 0) + break; + } + if (seq.hasConstantShape()) + return baseTy; + } + return mlir::LLVM::LLVMPointerType::get(baseTy); + } + + // fir.tdesc --> llvm<"i8*"> + // FIXME: for now use a void*, however pointer identity is not sufficient for + // the f18 object v. 
class distinction + mlir::Type convertTypeDescType(mlir::MLIRContext *ctx) { + return mlir::LLVM::LLVMPointerType::get( + mlir::IntegerType::get(&getContext(), 8)); + } + + /// Convert llvm::Type::TypeID to mlir::Type + mlir::Type fromRealTypeID(llvm::Type::TypeID typeID, fir::KindTy kind) { + switch (typeID) { + case llvm::Type::TypeID::HalfTyID: + return mlir::FloatType::getF16(&getContext()); + case llvm::Type::TypeID::FloatTyID: + return mlir::FloatType::getF32(&getContext()); + case llvm::Type::TypeID::DoubleTyID: + return mlir::FloatType::getF64(&getContext()); + case llvm::Type::TypeID::X86_FP80TyID: + return mlir::FloatType::getF80(&getContext()); + case llvm::Type::TypeID::FP128TyID: + return mlir::FloatType::getF128(&getContext()); + default: + emitError(UnknownLoc::get(&getContext())) + << "unsupported type: !fir.real<" << kind << ">"; + return {}; + } + } + + /// HACK: cloned from LLVMTypeConverter since this is private there + mlir::Type unwrap(mlir::Type type) { + if (!type) + return nullptr; + auto *mlirContext = type.getContext(); + auto wrappedLLVMType = type.dyn_cast(); + if (!wrappedLLVMType) + emitError(UnknownLoc::get(mlirContext), + "conversion resulted in a non-LLVM type"); + return wrappedLLVMType; + } + + /// Returns false iff the sequence type has a shape and the shape is constant. + static bool unknownShape(SequenceType::Shape shape) { + // does the shape even exist? + auto size = shape.size(); + if (size == 0) + return true; + // if it exists, are any dimensions deferred? + for (decltype(size) i = 0, sz = size; i < sz; ++i) + if (shape[i] == SequenceType::getUnknownExtent()) + return true; + return false; + } + + /// Does this record type have dynamically inlined subobjects? Note: this + /// should not look through references as they are not inlined. 
+ static bool dynamicallySized(fir::RecordType seqTy) { + for (auto field : seqTy.getTypeList()) { + if (auto arr = field.second.dyn_cast()) { + if (unknownShape(arr.getShape())) + return true; + } else if (auto rec = field.second.dyn_cast()) { + if (dynamicallySized(rec)) + return true; + } + } + return false; + } + + static bool dynamicallySized(mlir::Type ty) { + if (auto arr = ty.dyn_cast()) + ty = arr.getEleTy(); + if (auto rec = ty.dyn_cast()) + return dynamicallySized(rec); + return false; + } + + NameUniquer &getUniquer() { return uniquer; } + + KindMapping &getKindMap() { return kindMapping; } + +private: + KindMapping kindMapping; + NameUniquer &uniquer; + std::unique_ptr specifics; + static StringMap identStructCache; +}; + +} // namespace fir + +#endif // OPTIMIZER_CODEGEN_TYPECONVERTER_H Index: flang/lib/Optimizer/Transforms/Inliner.cpp =================================================================== --- flang/lib/Optimizer/Transforms/Inliner.cpp +++ flang/lib/Optimizer/Transforms/Inliner.cpp @@ -18,7 +18,11 @@ llvm::cl::init(false)); /// Should we inline the callable `op` into region `reg`? -bool fir::canLegallyInline(mlir::Operation *op, mlir::Region *reg, - mlir::BlockAndValueMapping &map) { +bool fir::canLegallyInline(mlir::Operation *, mlir::Region *, bool, + mlir::BlockAndValueMapping &) { + return aggressivelyInline; +} + +bool fir::canLegallyInline(mlir::Operation *, mlir::Operation *, bool) { return aggressivelyInline; }