diff --git a/flang/include/flang/Optimizer/CodeGen/CGPasses.td b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
--- a/flang/include/flang/Optimizer/CodeGen/CGPasses.td
+++ b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
@@ -16,6 +16,19 @@
 
 include "mlir/Pass/PassBase.td"
 
+def FIRToLLVMLowering : Pass<"fir-to-llvm-ir", "mlir::ModuleOp"> {
+  let summary = "Convert FIR dialect to LLVM-IR dialect";
+  let description = [{
+    Convert the FIR dialect to the LLVM-IR dialect of MLIR. This conversion
+    will also convert ops in the standard and FIRCG dialects.
+  }];
+  let constructor = "::fir::createFIRToLLVMPass()";
+  let dependentDialects = [
+    "fir::FIROpsDialect", "fir::FIRCodeGenDialect", "mlir::BuiltinDialect",
+    "mlir::LLVM::LLVMDialect", "mlir::omp::OpenMPDialect"
+  ];
+}
+
 def CodeGenRewrite : Pass<"cg-rewrite"> {
   let summary = "Rewrite some FIR ops into their code-gen forms.";
   let description = [{
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_flang_library(FIRCodeGen
   CGOps.cpp
+  CodeGen.cpp
   PreCGRewrite.cpp
 
   DEPENDS
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -0,0 +1,236 @@
+//===-- CodeGen.cpp -- bridge to lower to LLVM ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/CodeGen/CodeGen.h"
+#include "PassDetail.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/ArrayRef.h"
+
+#define DEBUG_TYPE "flang-codegen"
+
+// fir::LLVMTypeConverter for converting to LLVM IR dialect types.
+#include "TypeConverter.h"
+
+namespace {
+/// FIR conversion pattern template.
+///
+/// Common base of the FIR-to-LLVM patterns below. It gives each pattern
+/// access to the FIR-specific fir::LLVMTypeConverter it was built with.
+template <typename FromOp>
+class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
+public:
+  explicit FIROpConversion(fir::LLVMTypeConverter &lowering)
+      : mlir::ConvertOpToLLVMPattern<FromOp>(lowering) {}
+
+protected:
+  /// Convert \p ty with the FIR type converter.
+  mlir::Type convertType(mlir::Type ty) const {
+    return lowerTy().convertType(ty);
+  }
+
+  /// The pattern's type converter as a fir::LLVMTypeConverter. The downcast
+  /// is safe because the only constructor takes a fir::LLVMTypeConverter.
+  fir::LLVMTypeConverter &lowerTy() const {
+    return *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
+  }
+};
+} // namespace
+
+namespace {
+/// Lower `fir.address_of` to `llvm.mlir.addressof` of the same symbol.
+struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto ty = convertType(addr.getType());
+    rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
+        addr, ty, addr.symbol().getRootReference().getValue());
+    return mlir::success();
+  }
+};
+
+/// Lower `fir.has_value` to `llvm.return` of its converted operands; inside a
+/// `fir.global` initializer region this yields the initial value.
+struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<mlir::LLVM::ReturnOp>(op, adaptor.getOperands());
+    return mlir::success();
+  }
+};
+
+/// Lower `fir.global` to `llvm.mlir.global`.
+///
+/// The initializer region is moved into the new global. In it, every
+/// `fir.insert_on_range` that fills the whole array with one constant is
+/// replaced by a single dense `arith.constant`.
+struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::GlobalOp global, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto tyAttr = convertType(global.getType());
+    // A fir.box is expected to convert to a pointer to its descriptor; the
+    // global holds the descriptor itself, so take the pointee type.
+    if (global.getType().isa<fir::BoxType>())
+      tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
+    auto loc = global.getLoc();
+    mlir::Attribute initAttr{};
+    if (global.initVal())
+      initAttr = global.initVal().getValue();
+    auto linkage = convertLinkage(global.linkName());
+    auto isConst = global.constant().hasValue();
+    auto g = rewriter.create<mlir::LLVM::GlobalOp>(
+        loc, tyAttr, isConst, linkage, global.sym_name(), initAttr);
+    auto &gr = g.getInitializerRegion();
+    rewriter.inlineRegionBefore(global.region(), gr, gr.end());
+    if (!gr.empty()) {
+      // Replace insert_on_range with a constant dense attribute if the
+      // initialization is on the full range.
+      auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
+      for (auto insertOp : insertOnRangeOps) {
+        if (!isFullRange(insertOp.coor(), insertOp.getType()))
+          continue;
+        // The value must be an arith.constant, possibly behind a fir.convert.
+        // It can also be a block argument (no defining op), hence the
+        // null-tolerant casts; anything else is left untouched.
+        auto *op = insertOp.val().getDefiningOp();
+        auto constant = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(op);
+        if (!constant)
+          if (auto convertOp = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
+            constant = mlir::dyn_cast_or_null<mlir::arith::ConstantOp>(
+                convertOp.value().getDefiningOp());
+        if (!constant)
+          continue;
+        // NOTE(review): in the fir.convert case the dense attribute takes the
+        // constant's pre-conversion type; confirm it matches the array's.
+        auto seqTyAttr = convertType(insertOp.getType());
+        mlir::Type vecType = mlir::VectorType::get(
+            insertOp.getType().getShape(), constant.getType());
+        auto denseAttr = mlir::DenseElementsAttr::get(
+            vecType.cast<mlir::ShapedType>(), constant.value());
+        rewriter.setInsertionPointAfter(insertOp);
+        rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
+            insertOp, seqTyAttr, denseAttr);
+      }
+    }
+    rewriter.eraseOp(global);
+    return mlir::success();
+  }
+
+  /// Return true if \p indexes, read as (lower, upper) pairs, one per
+  /// dimension, covers every element of \p seqTy: each pair is
+  /// [0, extent - 1].
+  bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
+    auto extents = seqTy.getShape();
+    // Compare with twice the rank rather than halving the size, so an
+    // odd-length list is rejected instead of being read past its end.
+    if (indexes.size() != 2 * extents.size())
+      return false;
+    for (unsigned i = 0; i < indexes.size(); i += 2) {
+      if (indexes[i].cast<mlir::IntegerAttr>().getInt() != 0)
+        return false;
+      if (indexes[i + 1].cast<mlir::IntegerAttr>().getInt() !=
+          extents[i / 2] - 1)
+        return false;
+    }
+    return true;
+  }
+
+  /// Map the optional FIR linkage name to an LLVM linkage; an absent or
+  /// unrecognized name yields external linkage.
+  mlir::LLVM::Linkage
+  convertLinkage(llvm::Optional<llvm::StringRef> optLinkage) const {
+    if (optLinkage.hasValue()) {
+      auto name = optLinkage.getValue();
+      if (name == "internal")
+        return mlir::LLVM::Linkage::Internal;
+      if (name == "linkonce")
+        return mlir::LLVM::Linkage::Linkonce;
+      if (name == "common")
+        return mlir::LLVM::Linkage::Common;
+      if (name == "weak")
+        return mlir::LLVM::Linkage::Weak;
+    }
+    return mlir::LLVM::Linkage::External;
+  }
+};
+
+/// Lower `fir.undefined` to `llvm.mlir.undef` of the converted type.
+struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::UndefOp undef, OpAdaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(
+        undef, convertType(undef.getType()));
+    return mlir::success();
+  }
+};
+} // namespace
+
+namespace {
+/// Convert FIR dialect to LLVM dialect
+///
+/// This pass lowers all FIR dialect operations to LLVM IR dialect. Residual
+/// standard and arith dialect ops are lowered in the same conversion, using
+/// MLIR's upstream patterns.
+///
+/// This pass is not complete yet. We are upstreaming it in small patches.
+class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
+public:
+  mlir::ModuleOp getModule() { return getOperation(); }
+
+  void runOnOperation() override final {
+    auto *context = getModule().getContext();
+    fir::LLVMTypeConverter typeConverter{getModule()};
+    mlir::OwningRewritePatternList pattern(context);
+    pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
+                   UndefOpConversion>(typeConverter);
+    mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
+    mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
+                                                            pattern);
+    mlir::ConversionTarget target{*context};
+    target.addLegalDialect<mlir::LLVM::LLVMDialect>();
+
+    // required NOPs for applying a full conversion
+    target.addLegalOp<mlir::ModuleOp>();
+
+    // apply the patterns
+    if (mlir::failed(mlir::applyFullConversion(getModule(), target,
+                                               std::move(pattern)))) {
+      // Report at the module's location rather than an unknown one.
+      mlir::emitError(getModule().getLoc(),
+                      "error in converting to LLVM-IR dialect\n");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createFIRToLLVMPass() {
+  return std::make_unique<FIRToLLVMLowering>();
+}
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.h b/flang/lib/Optimizer/CodeGen/TypeConverter.h
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.h
@@ -0,0 +1,85 @@
+//===-- TypeConverter.h -- type conversion ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
+#define FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
+
+#include "llvm/Support/Debug.h"
+
+namespace fir {
+
+/// FIR type converter: extends MLIR's LLVMTypeConverter with conversions of
+/// FIR types to LLVM IR dialect types (for now fir.ref and fir.array only).
+class LLVMTypeConverter : public mlir::LLVMTypeConverter {
+public:
+  explicit LLVMTypeConverter(mlir::ModuleOp module)
+      : mlir::LLVMTypeConverter(module.getContext()) {
+    LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
+
+    // Each conversion should return a value of type mlir::Type.
+    addConversion(
+        [&](fir::ReferenceType ref) { return convertPointerLike(ref); });
+    addConversion(
+        [&](SequenceType sequence) { return convertSequenceType(sequence); });
+  }
+
+  template <typename A>
+  mlir::Type convertPointerLike(A &ty) {
+    mlir::Type eleTy = ty.getEleTy();
+    // A sequence type is a special case. A sequence of runtime size on its
+    // interior dimensions lowers to a memory reference. In that case, we
+    // degenerate the array and do not want the type to become `T**` but
+    // merely `T*`.
+    if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) {
+      if (!seqTy.hasConstantShape() ||
+          characterWithDynamicLen(seqTy.getEleTy())) {
+        if (seqTy.hasConstantInterior())
+          return convertType(seqTy);
+        eleTy = seqTy.getEleTy();
+      }
+    }
+    // fir.ref<fir.box> is a special case because fir.box type is already
+    // a pointer to a Fortran descriptor at the LLVM IR level. This implies
+    // that a fir.ref<fir.box>, that is the address of fir.box is actually
+    // the same as a fir.box at the LLVM level.
+    // The distinction is kept in fir to denote when a descriptor is expected
+    // to be mutable (fir.ref<fir.box>) and when it is not (fir.box).
+    if (eleTy.isa<fir::BoxType>())
+      return convertType(eleTy);
+
+    return mlir::LLVM::LLVMPointerType::get(convertType(eleTy));
+  }
+
+  // fir.array<c ... :any> --> llvm<"[...[c x any]]">
+  mlir::Type convertSequenceType(SequenceType seq) {
+    auto baseTy = convertType(seq.getEleTy());
+    if (characterWithDynamicLen(seq.getEleTy()))
+      return mlir::LLVM::LLVMPointerType::get(baseTy);
+    auto shape = seq.getShape();
+    auto constRows = seq.getConstantRows();
+    if (constRows) {
+      decltype(constRows) i = constRows;
+      for (auto e : shape) {
+        baseTy = mlir::LLVM::LLVMArrayType::get(baseTy, e);
+        if (--i == 0)
+          break;
+      }
+      if (seq.hasConstantShape())
+        return baseTy;
+    }
+    return mlir::LLVM::LLVMPointerType::get(baseTy);
+  }
+};
+
+} // namespace fir
+
+#endif // FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -0,0 +1,83 @@
+// RUN: fir-opt --split-input-file --fir-to-llvm-ir %s | FileCheck %s
+
+// Test simple global LLVM conversion
+
+fir.global @g_i0 : i32 {
+  %1 = arith.constant 0 : i32
+  fir.has_value %1 : i32
+}
+
+// CHECK: llvm.mlir.global external @g_i0() : i32 {
+// CHECK:   %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK:   llvm.return %[[C0]] : i32
+// CHECK: }
+
+// -----
+
+fir.global @g_ci5 constant : i32 {
+  %c = arith.constant 5 : i32
+  fir.has_value %c : i32
+}
+
+// CHECK: llvm.mlir.global external constant @g_ci5() : i32 {
+// CHECK:   %[[C5:.*]] = llvm.mlir.constant(5 : i32) : i32
+// CHECK:   llvm.return %[[C5]] : i32
+// CHECK: }
+
+// -----
+
+fir.global internal @i_i515 (515:i32) : i32
+// CHECK: llvm.mlir.global internal @i_i515(515 : i32) : i32
+
+// -----
+
+fir.global common @C_i511 (0:i32) : i32
+// CHECK: llvm.mlir.global common @C_i511(0 : i32) : i32
+
+// -----
+
+fir.global weak @w_i86 (86:i32) : i32
+// CHECK: llvm.mlir.global weak @w_i86(86 : i32) : i32
+
+// -----
+
+fir.global linkonce @w_i86 (86:i32) :
i32
+// CHECK: llvm.mlir.global linkonce @w_i86(86 : i32) : i32
+
+// -----
+
+// Test conversion of fir.address_of with fir.global
+
+func @f1() {
+  %0 = fir.address_of(@symbol) : !fir.ref<i64>
+  return
+}
+
+fir.global @symbol : i64 {
+  %0 = arith.constant 1 : i64
+  fir.has_value %0 : i64
+}
+
+// CHECK: %{{.*}} = llvm.mlir.addressof @[[SYMBOL:.*]] : !llvm.ptr<i64>
+
+// CHECK: llvm.mlir.global external @[[SYMBOL]]() : i64 {
+// CHECK:   %{{.*}} = llvm.mlir.constant(1 : i64) : i64
+// CHECK:   llvm.return %{{.*}} : i64
+// CHECK: }
+
+// -----
+
+// Test global with insert_on_range operation covering the full array
+// in initializer region.
+
+fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
+  %c1_i32 = arith.constant 1 : i32
+  %0 = fir.undefined !fir.array<32x32xi32>
+  %2 = fir.insert_on_range %0, %c1_i32, [0 : index, 31 : index, 0 : index, 31 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  fir.has_value %2 : !fir.array<32x32xi32>
+}
+
+// CHECK: llvm.mlir.global internal @_QEmultiarray() : !llvm.array<32 x array<32 x i32>> {
+// CHECK:   %[[CST:.*]] = llvm.mlir.constant(dense<1> : vector<32x32xi32>) : !llvm.array<32 x array<32 x i32>>
+// CHECK:   llvm.return %[[CST]] : !llvm.array<32 x array<32 x i32>>
+// CHECK: }