Index: flang/include/flang/Optimizer/CodeGen/CGOps.td =================================================================== --- /dev/null +++ flang/include/flang/Optimizer/CodeGen/CGOps.td @@ -0,0 +1,177 @@ +//===-- CGOps.td - FIR operation definitions ---------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Definition of the FIRCG dialect operations +/// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_DIALECT_FIRCG_OPS +#define FORTRAN_DIALECT_FIRCG_OPS + +include "mlir/IR/SymbolInterfaces.td" +include "flang/Optimizer/Dialect/FIRTypes.td" + +def fircg_Dialect : Dialect { + let name = "fircg"; + let cppNamespace = "::fir::cg"; +} + +// Base class for FIR CG operations. +// All operations automatically get a prefix of "fircg.". +class fircg_Op traits> + : Op; + +// Extended embox operation. +def fircg_XEmboxOp : fircg_Op<"ext_embox", [AttrSizedOperandSegments]> { + let summary = "for internal conversion only"; + + let description = [{ + Prior to lowering to LLVM IR dialect, a non-scalar non-trivial embox op will + be converted to an extended embox. This op will have the following sets of + arguments. + + - memref: The memory reference being emboxed. + - shape: A vector that is the runtime shape of the underlying array. + - shift: A vector that is the runtime origin of the first element. + The default is a vector of the value 1. + - slice: A vector of triples that describe an array slice. + - subcomponent: A vector of indices for subobject slicing. + - LEN type parameters: A vector of runtime LEN type parameters that + describe and correspond to the elemental derived type. + + The memref and shape arguments are mandatory.
The rest are optional. + }]; + + let arguments = (ins + AnyReferenceLike:$memref, + Variadic:$shape, + Variadic:$shift, + Variadic:$slice, + Variadic:$subcomponent, + Variadic:$lenParams + ); + let results = (outs fir_BoxType); + + let assemblyFormat = [{ + $memref (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)? + (`path` $subcomponent^)? (`typeparams` $lenParams^)? attr-dict + `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + // The rank of the entity being emboxed + unsigned getRank() { return shape().size(); } + + // The rank of the result. A slice op can reduce the rank. + unsigned getOutRank(); + + // The shape operands are mandatory and always start at 1. + unsigned shapeOffset() { return 1; } + unsigned shiftOffset() { return shapeOffset() + shape().size(); } + unsigned sliceOffset() { return shiftOffset() + shift().size(); } + unsigned subcomponentOffset() { return sliceOffset() + slice().size(); } + unsigned lenParamOffset() { + return subcomponentOffset() + subcomponent().size(); + } + }]; +} + +// Extended rebox operation. +def fircg_XReboxOp : fircg_Op<"ext_rebox", [AttrSizedOperandSegments]> { + let summary = "for internal conversion only"; + + let description = [{ + Prior to lowering to LLVM IR dialect, a non-scalar non-trivial rebox op will + be converted to an extended rebox. This op will have the following sets of + arguments. + + - box: The box being reboxed. + - shape: A vector that is the new runtime shape for the array. + - shift: A vector that is the new runtime origin of the first element. + The default is a vector of the value 1. + - slice: A vector of triples that describe an array slice. + - subcomponent: A vector of indices for subobject slicing. + + The box argument is mandatory, the other arguments are optional.
+ There must not be both a shape and slice/subcomponent arguments. + }]; + + let arguments = (ins + fir_BoxType:$box, + Variadic:$shape, + Variadic:$shift, + Variadic:$slice, + Variadic:$subcomponent + ); + let results = (outs fir_BoxType); + + let assemblyFormat = [{ + $box (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)? + (`path` $subcomponent^)? attr-dict + `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + // The rank of the entity being reboxed + unsigned getRank(); + // The rank of the result box + unsigned getOutRank(); + }]; +} + + +// Extended array coordinate operation. +def fircg_XArrayCoorOp : fircg_Op<"ext_array_coor", [AttrSizedOperandSegments]> { + let summary = "for internal conversion only"; + + let description = [{ + Prior to lowering to LLVM IR dialect, a non-scalar non-trivial array_coor op + will be converted to an extended array_coor. This op will have the following + sets of arguments. + + - memref: The memory reference of the array's data. It can be a fir.box if + the underlying data is not contiguous. + - shape: A vector that is the runtime shape of the underlying array. + - shift: A vector that is the runtime origin of the first element. + The default is a vector of the value 1. + - slice: A vector of triples that describe an array slice. + - subcomponent: A vector of indices that describe subobject slicing. + - indices: A vector of runtime values that describe the coordinate of + the element of the array to be computed. + - LEN type parameters: A vector of runtime LEN type parameters that + describe and correspond to the elemental derived type. + + The memref and indices arguments are mandatory. + The shape argument is mandatory if the memref is not a box, and should be + omitted otherwise. The rest of the arguments are optional.
+ }]; + + let arguments = (ins + AnyRefOrBox:$memref, + Variadic:$shape, + Variadic:$shift, + Variadic:$slice, + Variadic:$subcomponent, + Variadic:$indices, + Variadic:$lenParams + ); + let results = (outs fir_ReferenceType); + + let assemblyFormat = [{ + $memref (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)? + (`path` $subcomponent^)? `<`$indices`>` (`typeparams` $lenParams^)? + attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + unsigned getRank(); + }]; +} + +#endif // FORTRAN_DIALECT_FIRCG_OPS Index: flang/include/flang/Optimizer/CodeGen/CGPasses.td =================================================================== --- flang/include/flang/Optimizer/CodeGen/CGPasses.td +++ flang/include/flang/Optimizer/CodeGen/CGPasses.td @@ -11,8 +11,8 @@ // //===----------------------------------------------------------------------===// -#ifndef FLANG_OPTIMIZER_CODEGEN_PASSES -#define FLANG_OPTIMIZER_CODEGEN_PASSES +#ifndef FORTRAN_OPTIMIZER_CODEGEN_FIR_PASSES +#define FORTRAN_OPTIMIZER_CODEGEN_FIR_PASSES include "mlir/Pass/PassBase.td" @@ -22,7 +22,13 @@ Fuse specific subgraphs into single Ops for code generation.
}]; let constructor = "fir::createFirCodeGenRewritePass()"; - let dependentDialects = ["fir::FIROpsDialect"]; + let dependentDialects = [ + "fir::FIROpsDialect", "fir::FIRCodeGenDialect", "mlir::BuiltinDialect", + "mlir::LLVM::LLVMDialect", "mlir::omp::OpenMPDialect" + ]; + let statistics = [ + Statistic<"numDCE", "num-dce'd", "Number of operations eliminated"> + ]; } -#endif // FLANG_OPTIMIZER_CODEGEN_PASSES +#endif // FORTRAN_OPTIMIZER_CODEGEN_FIR_PASSES Index: flang/include/flang/Optimizer/CodeGen/CMakeLists.txt =================================================================== --- flang/include/flang/Optimizer/CodeGen/CMakeLists.txt +++ flang/include/flang/Optimizer/CodeGen/CMakeLists.txt @@ -1,3 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS CGOps.td) +mlir_tablegen(CGOps.h.inc -gen-op-decls) +mlir_tablegen(CGOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(CGOpsIncGen) set(LLVM_TARGET_DEFINITIONS CGPasses.td) mlir_tablegen(CGPasses.h.inc -gen-pass-decls -name OptCodeGen) Index: flang/include/flang/Optimizer/Dialect/FIRDialect.h =================================================================== --- flang/include/flang/Optimizer/Dialect/FIRDialect.h +++ flang/include/flang/Optimizer/Dialect/FIRDialect.h @@ -40,6 +40,16 @@ void registerTypes(); }; +/// The FIR codegen dialect is a dialect containing a small set of transient +/// operations used exclusively during code generation.
+class FIRCodeGenDialect final : public mlir::Dialect { +public: + explicit FIRCodeGenDialect(mlir::MLIRContext *ctx); + virtual ~FIRCodeGenDialect(); + + static llvm::StringRef getDialectNamespace() { return "fircg"; } +}; + } // namespace fir #endif // FORTRAN_OPTIMIZER_DIALECT_FIRDIALECT_H Index: flang/include/flang/Optimizer/OptPasses.h =================================================================== --- /dev/null +++ flang/include/flang/Optimizer/OptPasses.h @@ -0,0 +1,22 @@ +//===-- Optimizer/OptPasses.h -----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_OPTIMIZER_OPTPASSES_H +#define FORTRAN_OPTIMIZER_OPTPASSES_H + +#include "flang/Optimizer/CodeGen/CodeGen.h" + +namespace fir { +/// Register the passes in the flang/Optimizer directory. +/// TODO: Consider merging the registration of all passes into one function. +inline void registerOptimizerPasses() { + registerOptCodeGenPasses(); +} +} // namespace fir + +#endif // FORTRAN_OPTIMIZER_OPTPASSES_H Index: flang/include/flang/Optimizer/Support/InitFIR.h =================================================================== --- flang/include/flang/Optimizer/Support/InitFIR.h +++ flang/include/flang/Optimizer/Support/InitFIR.h @@ -26,10 +26,10 @@ // The definitive list of dialects used by flang.
#define FLANG_DIALECT_LIST \ - mlir::AffineDialect, FIROpsDialect, mlir::LLVM::LLVMDialect, \ - mlir::acc::OpenACCDialect, mlir::omp::OpenMPDialect, \ - mlir::scf::SCFDialect, mlir::StandardOpsDialect, \ - mlir::vector::VectorDialect + mlir::AffineDialect, FIROpsDialect, FIRCodeGenDialect, \ + mlir::LLVM::LLVMDialect, mlir::acc::OpenACCDialect, \ + mlir::omp::OpenMPDialect, mlir::scf::SCFDialect, \ + mlir::StandardOpsDialect, mlir::vector::VectorDialect /// Register all the dialects used by flang. inline void registerDialects(mlir::DialectRegistry &registry) { @@ -45,7 +45,7 @@ /// Register the standard passes we use. This comes from registerAllPasses(), /// but is a smaller set since we aren't using many of the passes found there. -inline void registerFIRPasses() { +inline void registerMLIRPassesForFortranTools() { mlir::registerCanonicalizerPass(); mlir::registerCSEPass(); mlir::registerAffineLoopFusionPass(); @@ -71,6 +71,9 @@ mlir::registerConvertAffineToStandardPass(); } +/// Register the interfaces needed to lower to LLVM IR. +void registerLLVMTranslation(mlir::MLIRContext &context); + } // namespace fir::support #endif // FORTRAN_OPTIMIZER_SUPPORT_INITFIR_H Index: flang/lib/Optimizer/CMakeLists.txt =================================================================== --- flang/lib/Optimizer/CMakeLists.txt +++ flang/lib/Optimizer/CMakeLists.txt @@ -10,11 +10,16 @@ Support/InternalNames.cpp Support/KindMapping.cpp + CodeGen/CGOps.cpp + CodeGen/PreCGRewrite.cpp + Transforms/Inliner.cpp DEPENDS FIROpsIncGen + FIROptCodeGenPassIncGen + FIROptTransformsPassIncGen + CGOpsIncGen ${dialect_libs} LINK_LIBS Index: flang/lib/Optimizer/CodeGen/CGOps.h =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/CGOps.h @@ -0,0 +1,24 @@ +//===-- CGOps.h -------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#ifndef OPTIMIZER_CODEGEN_CGOPS_H +#define OPTIMIZER_CODEGEN_CGOPS_H + +#include "flang/Optimizer/Dialect/FIRType.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +using namespace mlir; + +#define GET_OP_CLASSES +#include "flang/Optimizer/CodeGen/CGOps.h.inc" + +#endif // OPTIMIZER_CODEGEN_CGOPS_H Index: flang/lib/Optimizer/CodeGen/CGOps.cpp =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/CGOps.cpp @@ -0,0 +1,64 @@ +//===-- CGOps.cpp -- FIR codegen operations -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "CGOps.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" + +/// FIR codegen dialect constructor.
+fir::FIRCodeGenDialect::FIRCodeGenDialect(mlir::MLIRContext *ctx) + : mlir::Dialect("fircg", ctx, mlir::TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "flang/Optimizer/CodeGen/CGOps.cpp.inc" + >(); +} + +// anchor the class vtable to this compilation unit +fir::FIRCodeGenDialect::~FIRCodeGenDialect() { + // do nothing +} + +#define GET_OP_CLASSES +#include "flang/Optimizer/CodeGen/CGOps.cpp.inc" + +unsigned fir::cg::XEmboxOp::getOutRank() { + if (slice().empty()) + return getRank(); + auto outRank = fir::SliceOp::getOutputRank(slice()); + assert(outRank >= 1); + return outRank; +} + +unsigned fir::cg::XReboxOp::getOutRank() { + if (auto seqTy = + fir::dyn_cast_ptrOrBoxEleTy(getType()).dyn_cast()) + return seqTy.getDimension(); + return 0; +} + +unsigned fir::cg::XReboxOp::getRank() { + if (auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(box().getType()) + .dyn_cast()) + return seqTy.getDimension(); + return 0; +} + +unsigned fir::cg::XArrayCoorOp::getRank() { + auto memrefTy = memref().getType(); + if (memrefTy.isa()) + if (auto seqTy = + fir::dyn_cast_ptrOrBoxEleTy(memrefTy).dyn_cast()) + return seqTy.getDimension(); + return shape().size(); +} Index: flang/lib/Optimizer/CodeGen/PassDetail.h =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/PassDetail.h @@ -0,0 +1,26 @@ +//===- PassDetail.h - Optimizer code gen Pass class details -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef OPTIMIZER_CODEGEN_PASSDETAIL_H +#define OPTIMIZER_CODEGEN_PASSDETAIL_H + +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +namespace fir { + +#define GEN_PASS_CLASSES +#include "flang/Optimizer/CodeGen/CGPasses.h.inc" + +} // namespace fir + +#endif // OPTIMIZER_CODEGEN_PASSDETAIL_H Index: flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp @@ -0,0 +1,315 @@ +//===-- PreCGRewrite.cpp --------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "CGOps.h" +#include "PassDetail.h" +#include "flang/Optimizer/CodeGen/CodeGen.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Support/FIRContext.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" + +//===----------------------------------------------------------------------===// +// Codegen rewrite: rewriting of subgraphs of ops +//===----------------------------------------------------------------------===// + +using namespace fir; + +#define DEBUG_TYPE "flang-codegen-rewrite" + +static void populateShape(llvm::SmallVectorImpl &vec, + ShapeOp shape) { + vec.append(shape.extents().begin(), shape.extents().end()); +} + +// Operands of fir.shape_shift split into two vectors. +static void populateShapeAndShift(llvm::SmallVectorImpl &shapeVec, + llvm::SmallVectorImpl &shiftVec, + ShapeShiftOp shift) { + auto endIter = shift.pairs().end(); + for (auto i = shift.pairs().begin(); i != endIter;) { + shiftVec.push_back(*i++); + shapeVec.push_back(*i++); + } +} + +static void populateShift(llvm::SmallVectorImpl &vec, + ShiftOp shift) { + vec.append(shift.origins().begin(), shift.origins().end()); +} + +namespace { + +/// Convert fir.embox to the extended form where necessary. +/// +/// The embox operation can take arguments that specify multidimensional array +/// properties at runtime. These properties may be shared between distinct +/// objects that have the same properties. Before we lower these small DAGs to +/// LLVM-IR, we gather all the information into a single extended operation.
 For +/// example, +/// ``` +/// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> +/// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> +/// %3 = fir.embox %0 (%1) [%2] : (!fir.ref>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box> +/// ``` +/// can be rewritten as +/// ``` +/// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : (!fir.ref>, index, index, index, index, index) -> !fir.box> +/// ``` +class EmboxConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(EmboxOp embox, + mlir::PatternRewriter &rewriter) const override { + auto shapeVal = embox.getShape(); + // Use the shape operand if any; else materialize a constant array shape. + if (shapeVal) + return rewriteDynamicShape(embox, rewriter, shapeVal); + if (auto boxTy = embox.getType().dyn_cast()) + if (auto seqTy = boxTy.getEleTy().dyn_cast()) + if (seqTy.hasConstantShape()) + return rewriteStaticShape(embox, rewriter, seqTy); + return mlir::failure(); + } + + mlir::LogicalResult rewriteStaticShape(EmboxOp embox, + mlir::PatternRewriter &rewriter, + SequenceType seqTy) const { + auto loc = embox.getLoc(); + llvm::SmallVector shapeOpers; + auto idxTy = rewriter.getIndexType(); + for (auto ext : seqTy.getShape()) { + auto iAttr = rewriter.getIndexAttr(ext); + auto extVal = rewriter.create(loc, idxTy, iAttr); + shapeOpers.push_back(extVal); + } + auto xbox = rewriter.create( + loc, embox.getType(), embox.memref(), shapeOpers, llvm::None, + llvm::None, llvm::None, embox.lenParams()); + LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); + rewriter.replaceOp(embox, xbox.getOperation()->getResults()); + return mlir::success(); + } + + mlir::LogicalResult rewriteDynamicShape(EmboxOp embox, + mlir::PatternRewriter &rewriter, + mlir::Value shapeVal) const { + auto loc = embox.getLoc(); + auto shapeOp = shapeVal.getDefiningOp<ShapeOp>(); + llvm::SmallVector shapeOpers; +
 llvm::SmallVector shiftOpers; + if (shapeOp) { + populateShape(shapeOpers, shapeOp); + } else if (auto shiftOp = shapeVal.getDefiningOp<ShapeShiftOp>()) { + populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); + } else { + return mlir::failure(); // neither fir.shape nor fir.shape_shift + } + llvm::SmallVector sliceOpers; + llvm::SmallVector subcompOpers; + if (auto s = embox.getSlice()) + if (auto sliceOp = dyn_cast_or_null(s.getDefiningOp())) { + sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); + subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); + } + auto xbox = rewriter.create( + loc, embox.getType(), embox.memref(), shapeOpers, shiftOpers, + sliceOpers, subcompOpers, embox.lenParams()); + LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); + rewriter.replaceOp(embox, xbox.getOperation()->getResults()); + return mlir::success(); + } +}; + +/// Convert fir.rebox to the extended form where necessary. +/// +/// For example, +/// ``` +/// %5 = fir.rebox %3(%1) : (!fir.box>, !fir.shapeshift<1>) -> !fir.box> +/// ``` +/// converted to +/// ``` +/// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box>, index, index) -> !fir.box> +/// ``` +class ReboxConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ReboxOp rebox, + mlir::PatternRewriter &rewriter) const override { + auto loc = rebox.getLoc(); + llvm::SmallVector shapeOpers; + llvm::SmallVector shiftOpers; + if (auto shapeVal = rebox.shape()) { + if (auto shapeOp = shapeVal.getDefiningOp<ShapeOp>()) + populateShape(shapeOpers, shapeOp); + else if (auto shiftOp = shapeVal.getDefiningOp<ShapeShiftOp>()) + populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); + else if (auto shiftOp = shapeVal.getDefiningOp<ShiftOp>()) + populateShift(shiftOpers, shiftOp); + else + return mlir::failure(); + } + llvm::SmallVector sliceOpers; + llvm::SmallVector subcompOpers; + if (auto s =
rebox.slice()) + if (auto sliceOp = dyn_cast_or_null(s.getDefiningOp())) { + sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); + subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); + } + + auto xRebox = rewriter.create( + loc, rebox.getType(), rebox.box(), shapeOpers, shiftOpers, sliceOpers, + subcompOpers); + LLVM_DEBUG(llvm::dbgs() + << "rewriting " << rebox << " to " << xRebox << '\n'); + rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); + return mlir::success(); + } +}; + +/// Convert all fir.array_coor to the extended form. +/// +/// For example, +/// ``` +/// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref>, !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref +/// ``` +/// converted to +/// ``` +/// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6]<%39> : (!fir.ref>, index, index, index, index, index, index) -> !fir.ref +/// ``` +class ArrayCoorConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ArrayCoorOp arrCoor, + mlir::PatternRewriter &rewriter) const override { + auto loc = arrCoor.getLoc(); + llvm::SmallVector shapeOpers; + llvm::SmallVector shiftOpers; + if (auto shapeVal = arrCoor.shape()) { + if (auto shapeOp = shapeVal.getDefiningOp<ShapeOp>()) + populateShape(shapeOpers, shapeOp); + else if (auto shiftOp = shapeVal.getDefiningOp<ShapeShiftOp>()) + populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); + else if (auto shiftOp = shapeVal.getDefiningOp<ShiftOp>()) + populateShift(shiftOpers, shiftOp); + else + return mlir::failure(); + } + llvm::SmallVector sliceOpers; + llvm::SmallVector subcompOpers; + if (auto s = arrCoor.slice()) + if (auto sliceOp = dyn_cast_or_null(s.getDefiningOp())) { + sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); + subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); + } + auto xArrCoor = rewriter.create( + loc, arrCoor.getType(),
arrCoor.memref(), shapeOpers, shiftOpers, + sliceOpers, subcompOpers, arrCoor.indices(), arrCoor.lenParams()); + LLVM_DEBUG(llvm::dbgs() + << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); + rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); + return mlir::success(); + } +}; + +class CodeGenRewrite : public CodeGenRewriteBase { +public: + void runOn(mlir::Operation *op, mlir::Region &region) { + auto &context = getContext(); + mlir::ConversionTarget target(context); + target.addLegalDialect(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addDynamicallyLegalOp([](EmboxOp embox) { + return !(embox.getShape() || + embox.getType().cast().getEleTy().isa()); + }); + mlir::OwningRewritePatternList patterns; + patterns.insert( + &context); + if (mlir::failed( + mlir::applyPartialConversion(op, target, std::move(patterns)))) { + mlir::emitError(op->getLoc(), + "error in running the pre-codegen conversions"); + signalPassFailure(); + } + // Erase any residual. + simplifyRegion(region); + } + + void runOnOperation() override final { + // Call runOn on all top level regions that may contain emboxOp/arrayCoorOp. + auto mod = getOperation(); + for (auto func : mod.getOps()) + runOn(func, func.getBody()); + for (auto global : mod.getOps()) + runOn(global, global.getRegion()); + } + + // Recursively queue `region`'s dead ops via maybeEraseOp(), then doDCE(). + void simplifyRegion(mlir::Region &region) { + for (auto &block : region.getBlocks()) + for (auto &op : block.getOperations()) { + for (auto &reg : op.getRegions()) + simplifyRegion(reg); + maybeEraseOp(&op); + } + doDCE(); + } + + /// Run a simple DCE cleanup to remove any dead code after the rewrites.
+ void doDCE() { + std::vector workList; + workList.swap(opsToErase); + while (!workList.empty()) { + for (auto *op : workList) { + std::vector opOperands(op->operand_begin(), + op->operand_end()); + LLVM_DEBUG(llvm::dbgs() << "DCE on " << *op << '\n'); + ++numDCE; + op->erase(); + for (auto opnd : opOperands) + maybeEraseOp(opnd.getDefiningOp()); + } + workList.clear(); + workList.swap(opsToErase); + } + } + + void maybeEraseOp(mlir::Operation *op) { + if (!op) + return; + if (op->hasTrait()) + return; + // Queue at most once: op may define several operands of one erased user. + if (mlir::isOpTriviallyDead(op) && !llvm::is_contained(opsToErase, op)) + opsToErase.push_back(op); + } + +private: + std::vector opsToErase; +}; + +} // namespace + +std::unique_ptr fir::createFirCodeGenRewritePass() { + return std::make_unique(); +} Index: flang/test/Fir/cg-ops.fir =================================================================== --- /dev/null +++ flang/test/Fir/cg-ops.fir @@ -0,0 +1,30 @@ +// RUN: fir-opt --cg-rewrite %s | FileCheck %s + +// CHECK-LABEL: func @codegen( +// CHECK-SAME: %[[arg:.*]]: !fir +func @codegen(%addr : !fir.ref>) { + // CHECK: %[[zero:.*]] = constant 0 : index + %0 = constant 0 : index + %1 = fir.shape_shift %0, %0 : (index, index) -> !fir.shapeshift<1> + %2 = fir.slice %0, %0, %0 : (index, index, index) -> !fir.slice<1> + // CHECK: %[[box:.*]] = fircg.ext_embox %[[arg]](%[[zero]]) origin %[[zero]][%[[zero]], %[[zero]], %[[zero]]] : (!fir.ref>, index, index, index, index, index) -> !fir.box> + %3 = fir.embox %addr (%1) [%2] : (!fir.ref>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box> + // CHECK: fircg.ext_array_coor %[[arg]](%[[zero]]) origin %[[zero]][%[[zero]], %[[zero]], %[[zero]]]<%[[zero]]> : (!fir.ref>, index, index, index, index, index, index) -> !fir.ref + %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref>, !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref + // CHECK: fircg.ext_rebox %[[box]](%[[zero]]) origin %[[zero]] : (!fir.box>, index, index) -> !fir.box> + %5 = fir.rebox %3(%1) : (!fir.box>, !fir.shapeshift<1>) -> !fir.box> + return +} +
+// CHECK-LABEL: fir.global @box_global +fir.global @box_global : !fir.box> { + // CHECK: %[[arr:.*]] = fir.zero_bits !fir.ref + %arr = fir.zero_bits !fir.ref> + // CHECK: %[[zero:.*]] = constant 0 : index + %0 = constant 0 : index + %1 = fir.shape_shift %0, %0 : (index, index) -> !fir.shapeshift<1> + %2 = fir.slice %0, %0, %0 : (index, index, index) -> !fir.slice<1> + // CHECK: fircg.ext_embox %[[arr]](%[[zero]]) origin %[[zero]][%[[zero]], %[[zero]], %[[zero]]] : (!fir.ref>, index, index, index, index, index) -> !fir.box> + %3 = fir.embox %arr (%1) [%2] : (!fir.ref>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box> + fir.has_value %3 : !fir.box> +} Index: flang/tools/fir-opt/fir-opt.cpp =================================================================== --- flang/tools/fir-opt/fir-opt.cpp +++ flang/tools/fir-opt/fir-opt.cpp @@ -12,14 +12,16 @@ //===----------------------------------------------------------------------===// #include "mlir/Support/MlirOptMain.h" +#include "flang/Optimizer/OptPasses.h" #include "flang/Optimizer/Support/InitFIR.h" using namespace mlir; int main(int argc, char **argv) { - fir::support::registerFIRPasses(); + fir::support::registerMLIRPassesForFortranTools(); + fir::registerOptimizerPasses(); // flang/Optimizer passes (e.g. --cg-rewrite) DialectRegistry registry; fir::support::registerDialects(registry); return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n", - registry, /*preloadDialectsInContext*/ false)); + registry, /*preloadDialectsInContext=*/false)); } Index: flang/tools/tco/tco.cpp =================================================================== --- flang/tools/tco/tco.cpp +++ flang/tools/tco/tco.cpp @@ -106,7 +106,7 @@ } int main(int argc, char **argv) { - fir::support::registerFIRPasses(); + fir::support::registerMLIRPassesForFortranTools(); [[maybe_unused]] InitLLVM y(argc, argv); mlir::registerPassManagerCLOptions(); mlir::PassPipelineCLParser passPipe("", "Compiler passes to run");