Index: flang/include/flang/Optimizer/CodeGen/CGOps.td =================================================================== --- /dev/null +++ flang/include/flang/Optimizer/CodeGen/CGOps.td @@ -0,0 +1,177 @@ +//===-- CGOps.td - FIR operation definitions ---------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Definition of the FIRCG dialect operations +/// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_DIALECT_FIRCG_OPS +#define FORTRAN_DIALECT_FIRCG_OPS + +include "mlir/IR/SymbolInterfaces.td" +include "flang/Optimizer/Dialect/FIRTypes.td" + +def fircg_Dialect : Dialect { + let name = "fircg"; + let cppNamespace = "::fir::cg"; +} + +// Base class for FIR CG operations. +// All operations automatically get a prefix of "fircg.". +class fircg_Op traits> + : Op; + +// Extended embox operation. +def fircg_XEmboxOp : fircg_Op<"ext_embox", [AttrSizedOperandSegments]> { + let summary = "for internal conversion only"; + + let description = [{ + Prior to lowering to LLVM IR dialect, a non-scalar non-trivial embox op will + be converted to an extended embox. This op will have the following sets of + arguments. + + - memref: The memory reference being emboxed. + - shape: A vector that is the runtime shape of the underlying array. + - shift: A vector that is the runtime origin of the first element. + The default is a vector of the value 1. + - slice: A vector of triples that describe an array slice. + - subcomponent: A vector of indices for subobject slicing. + - LEN type parameters: A vector of runtime LEN type parameters that + describe an correspond to the elemental derived type. + + The memref and shape arguments are mandatory. 
The rest are optional. + }]; + + let arguments = (ins + AnyReferenceLike:$memref, + Variadic:$shape, + Variadic:$shift, + Variadic:$slice, + Variadic:$subcomponent, + Variadic:$lenParams + ); + let results = (outs fir_BoxType); + + let assemblyFormat = [{ + $memref (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)? + (`path` $subcomponent^)? (`typeparams` $lenParams^)? attr-dict + `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + // The rank of the entity being emboxed + unsigned getRank() { return shape().size(); } + + // The rank of the result. A slice op can reduce the rank. + unsigned getOutRank(); + + // The shape operands are mandatory and always start at 1. + unsigned shapeOffset() { return 1; } + unsigned shiftOffset() { return shapeOffset() + shape().size(); } + unsigned sliceOffset() { return shiftOffset() + shift().size(); } + unsigned subcomponentOffset() { return sliceOffset() + slice().size(); } + unsigned lenParamOffset() { + return subcomponentOffset() + subcomponent().size(); + } + }]; +} + +// Extended rebox operation. +def fircg_XReboxOp : fircg_Op<"ext_rebox", [AttrSizedOperandSegments]> { + let summary = "for internal conversion only"; + + let description = [{ + Prior to lowering to LLVM IR dialect, a non-scalar non-trivial rebox op will + be converted to an extended rebox. This op will have the following sets of + arguments. + + - box: The box being reboxed. + - shape: A vector that is the new runtime shape for the array + - shift: A vector that is the new runtime origin of the first element. + The default is a vector of the value 1. + - slice: A vector of triples that describe an array slice. + - subcomponent: A vector of indices for subobject slicing. + + The box argument is mandatory, the other arguments are optional. 
+ The shape and slice/subcomponent arguments must not both be present. + }]; + + let arguments = (ins + fir_BoxType:$box, + Variadic:$shape, + Variadic:$shift, + Variadic:$slice, + Variadic:$subcomponent + ); + let results = (outs fir_BoxType); + + let assemblyFormat = [{ + $box (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)? + (`path` $subcomponent^) ? attr-dict + `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + // The rank of the entity being reboxed + unsigned getRank(); + // The rank of the result box + unsigned getOutRank(); + }]; +} + + +// Extended array coordinate operation. +def fircg_XArrayCoorOp : fircg_Op<"ext_array_coor", [AttrSizedOperandSegments]> { + let summary = "for internal conversion only"; + + let description = [{ + Prior to lowering to LLVM IR dialect, a fir.array_coor op will be converted + to an extended array_coor. This op will have the following sets of + arguments. + + - memref: The memory reference of the array's data. It can be a fir.box if + the underlying data is not contiguous. + - shape: A vector that is the runtime shape of the underlying array. + - shift: A vector that is the runtime origin of the first element. + The default is a vector of the value 1. + - slice: A vector of triples that describe an array slice. + - subcomponent: A vector of indices that describe subobject slicing. + - indices: A vector of runtime values that describe the coordinate of + the element of the array to be computed. + - LEN type parameters: A vector of runtime LEN type parameters that + describe and correspond to the elemental derived type. + + The memref and indices arguments are mandatory. + The shape argument is mandatory if the memref is not a box, and should be + omitted otherwise. The rest of the arguments are optional. 
+ }]; + + let arguments = (ins + AnyRefOrBox:$memref, + Variadic:$shape, + Variadic:$shift, + Variadic:$slice, + Variadic:$subcomponent, + Variadic:$indices, + Variadic:$lenParams + ); + let results = (outs fir_ReferenceType); + + let assemblyFormat = [{ + $memref (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)? + (`path` $subcomponent^)? `<`$indices`>` (`typeparams` $lenParams^)? + attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + unsigned getRank(); + }]; +} + +#endif Index: flang/include/flang/Optimizer/CodeGen/CGPasses.td =================================================================== --- flang/include/flang/Optimizer/CodeGen/CGPasses.td +++ flang/include/flang/Optimizer/CodeGen/CGPasses.td @@ -11,8 +11,8 @@ // //===----------------------------------------------------------------------===// -#ifndef FLANG_OPTIMIZER_CODEGEN_PASSES -#define FLANG_OPTIMIZER_CODEGEN_PASSES +#ifndef FORTRAN_OPTIMIZER_CODEGEN_FIR_PASSES +#define FORTRAN_OPTIMIZER_CODEGEN_FIR_PASSES include "mlir/Pass/PassBase.td" @@ -22,7 +22,10 @@ Fuse specific subgraphs into single Ops for code generation. 
}]; let constructor = "fir::createFirCodeGenRewritePass()"; - let dependentDialects = ["fir::FIROpsDialect"]; + let dependentDialects = [ + "fir::FIROpsDialect", "fir::FIRCodeGenDialect", "mlir::BuiltinDialect", + "mlir::LLVM::LLVMDialect", "mlir::omp::OpenMPDialect" + ]; } -#endif // FLANG_OPTIMIZER_CODEGEN_PASSES +#endif // FORTRAN_OPTIMIZER_CODEGEN_FIR_PASSES Index: flang/include/flang/Optimizer/CodeGen/CMakeLists.txt =================================================================== --- flang/include/flang/Optimizer/CodeGen/CMakeLists.txt +++ flang/include/flang/Optimizer/CodeGen/CMakeLists.txt @@ -1,3 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS CGOps.td) +mlir_tablegen(CGOps.h.inc -gen-op-decls) +mlir_tablegen(CGOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(CGOpsIncGen) set(LLVM_TARGET_DEFINITIONS CGPasses.td) mlir_tablegen(CGPasses.h.inc -gen-pass-decls -name OptCodeGen) Index: flang/include/flang/Optimizer/Dialect/FIRDialect.h =================================================================== --- flang/include/flang/Optimizer/Dialect/FIRDialect.h +++ flang/include/flang/Optimizer/Dialect/FIRDialect.h @@ -34,6 +34,16 @@ mlir::DialectAsmPrinter &p) const override; }; +/// The FIR codegen dialect is a dialect containing a small set of transient +/// operations used exclusively during code generation. +class FIRCodeGenDialect final : public mlir::Dialect { +public: + explicit FIRCodeGenDialect(mlir::MLIRContext *ctx); + virtual ~FIRCodeGenDialect(); + + static llvm::StringRef getDialectNamespace() { return "fircg"; } +}; + } // namespace fir #endif // FORTRAN_OPTIMIZER_DIALECT_FIRDIALECT_H Index: flang/include/flang/Optimizer/OptPasses.h =================================================================== --- /dev/null +++ flang/include/flang/Optimizer/OptPasses.h @@ -0,0 +1,20 @@ +//===-- Optimizer/Transforms/Passes.h ---------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_OPTIMIZER_OPTPASSES_H +#define FORTRAN_OPTIMIZER_OPTPASSES_H + +#include "flang/Optimizer/CodeGen/CodeGen.h" + +namespace fir { +inline void registerOptPasses() { + registerOptCodeGenPasses(); +} +} // namespace fir + +#endif // FORTRAN_OPTIMIZER_OPTPASSES_H Index: flang/include/flang/Optimizer/Support/InitFIR.h =================================================================== --- flang/include/flang/Optimizer/Support/InitFIR.h +++ flang/include/flang/Optimizer/Support/InitFIR.h @@ -26,10 +26,10 @@ // The definitive list of dialects used by flang. #define FLANG_DIALECT_LIST \ - mlir::AffineDialect, FIROpsDialect, mlir::LLVM::LLVMDialect, \ - mlir::acc::OpenACCDialect, mlir::omp::OpenMPDialect, \ - mlir::scf::SCFDialect, mlir::StandardOpsDialect, \ - mlir::vector::VectorDialect + mlir::AffineDialect, FIROpsDialect, FIRCodeGenDialect, \ + mlir::LLVM::LLVMDialect, mlir::acc::OpenACCDialect, \ + mlir::omp::OpenMPDialect, mlir::scf::SCFDialect, \ + mlir::StandardOpsDialect, mlir::vector::VectorDialect /// Register all the dialects used by flang. inline void registerDialects(mlir::DialectRegistry ®istry) { @@ -71,6 +71,9 @@ mlir::registerConvertAffineToStandardPass(); } +/// Register the interfaces needed to lower to LLVM IR. 
+void registerLLVMTranslation(mlir::MLIRContext &context); + } // namespace fir::support #endif // FORTRAN_OPTIMIZER_SUPPORT_INITFIR_H Index: flang/lib/Optimizer/CMakeLists.txt =================================================================== --- flang/lib/Optimizer/CMakeLists.txt +++ flang/lib/Optimizer/CMakeLists.txt @@ -10,11 +10,16 @@ Support/InternalNames.cpp Support/KindMapping.cpp + CodeGen/CGOps.cpp + CodeGen/PreCGRewrite.cpp + Transforms/Inliner.cpp DEPENDS FIROpsIncGen + FIROptCodeGenPassIncGen FIROptTransformsPassIncGen + CGOpsIncGen ${dialect_libs} LINK_LIBS Index: flang/lib/Optimizer/CodeGen/CGOps.h =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/CGOps.h @@ -0,0 +1,24 @@ +//===-- CGOps.h -------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#ifndef OPTIMIZER_CODEGEN_CGOPS_H +#define OPTIMIZER_CODEGEN_CGOPS_H + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "flang/Optimizer/Dialect/FIRType.h" + +using namespace mlir; + +#define GET_OP_CLASSES +#include "flang/Optimizer/CodeGen/CGOps.h.inc" + +#endif Index: flang/lib/Optimizer/CodeGen/CGOps.cpp =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/CGOps.cpp @@ -0,0 +1,64 @@ +//===-- CGOps.cpp -- FIR codegen operations -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "CGOps.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" + +/// FIR codegen dialect constructor. +fir::FIRCodeGenDialect::FIRCodeGenDialect(mlir::MLIRContext *ctx) + : mlir::Dialect("fircg", ctx, mlir::TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "flang/Optimizer/CodeGen/CGOps.cpp.inc" + >(); +} + +// anchor the class vtable to this compilation unit +fir::FIRCodeGenDialect::~FIRCodeGenDialect() { + // do nothing +} + +#define GET_OP_CLASSES +#include "flang/Optimizer/CodeGen/CGOps.cpp.inc" + +unsigned fir::cg::XEmboxOp::getOutRank() { + if (slice().empty()) + return getRank(); + auto outRank = fir::SliceOp::getOutputRank(slice()); + assert(outRank >= 1); + return outRank; +} + +unsigned fir::cg::XReboxOp::getOutRank() { + if (auto seqTy = + fir::dyn_cast_ptrOrBoxEleTy(getType()).dyn_cast()) + return seqTy.getDimension(); + return 0; +} + +unsigned fir::cg::XReboxOp::getRank() { + if (auto seqTy = fir::dyn_cast_ptrOrBoxEleTy(box().getType()) + .dyn_cast()) + return seqTy.getDimension(); + return 0; +} + +unsigned fir::cg::XArrayCoorOp::getRank() { + auto memrefTy = memref().getType(); + if (memrefTy.isa()) + if (auto seqty = + fir::dyn_cast_ptrOrBoxEleTy(memrefTy).dyn_cast()) + return seqty.getDimension(); + return shape().size(); +} Index: flang/lib/Optimizer/CodeGen/PassDetail.h =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/PassDetail.h @@ -0,0 +1,26 @@ +//===- PassDetail.h - Optimizer code gen Pass class details -----*- C++ -*-===// +// +// Part of the LLVM Project, under 
the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef OPTMIZER_CODEGEN_PASSDETAIL_H +#define OPTMIZER_CODEGEN_PASSDETAIL_H + +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +namespace fir { + +#define GEN_PASS_CLASSES +#include "flang/Optimizer/CodeGen/CGPasses.h.inc" + +} // namespace fir + +#endif // OPTMIZER_CODEGEN_PASSDETAIL_H Index: flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp =================================================================== --- /dev/null +++ flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp @@ -0,0 +1,303 @@ +//===-- PreCGRewrite.cpp --------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "CGOps.h" +#include "PassDetail.h" +#include "flang/Optimizer/CodeGen/CodeGen.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Support/FIRContext.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" + +//===----------------------------------------------------------------------===// +// Codegen rewrite: rewriting of subgraphs of ops +//===----------------------------------------------------------------------===// + +using namespace fir; + +#define DEBUG_TYPE "flang-codegen-rewrite" + +static void populateShape(llvm::SmallVectorImpl &vec, + ShapeOp shape) { + vec.append(shape.extents().begin(), shape.extents().end()); +} + +// Operands of fir.shape_shift split into two vectors. +static void populateShapeAndShift(llvm::SmallVectorImpl &shapeVec, + llvm::SmallVectorImpl &shiftVec, + ShapeShiftOp shift) { + auto endIter = shift.pairs().end(); + for (auto i = shift.pairs().begin(); i != endIter;) { + shiftVec.push_back(*i++); + shapeVec.push_back(*i++); + } +} + +static void populateShift(llvm::SmallVectorImpl &vec, + ShiftOp shift) { + vec.append(shift.origins().begin(), shift.origins().end()); +} + +namespace { + +/// Convert fir.embox to the extended form where necessary. 
+class EmboxConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(EmboxOp embox, + mlir::PatternRewriter &rewriter) const override { + auto shapeVal = embox.getShape(); + // If the embox does not include a shape, then do not convert it + if (shapeVal) + return rewriteDynamicShape(embox, rewriter, shapeVal); + if (auto boxTy = embox.getType().dyn_cast()) + if (auto seqTy = boxTy.getEleTy().dyn_cast()) + if (seqTy.hasConstantShape()) + return rewriteStaticShape(embox, rewriter, seqTy); + return mlir::failure(); + } + + mlir::LogicalResult rewriteStaticShape(EmboxOp embox, + mlir::PatternRewriter &rewriter, + SequenceType seqTy) const { + auto loc = embox.getLoc(); + llvm::SmallVector shapeOpers; + auto idxTy = rewriter.getIndexType(); + for (auto ext : seqTy.getShape()) { + auto iAttr = rewriter.getIndexAttr(ext); + auto extVal = rewriter.create(loc, idxTy, iAttr); + shapeOpers.push_back(extVal); + } + auto xbox = rewriter.create( + loc, embox.getType(), embox.memref(), shapeOpers, llvm::None, + llvm::None, llvm::None, embox.lenParams()); + LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); + rewriter.replaceOp(embox, xbox.getOperation()->getResults()); + return mlir::success(); + } + + mlir::LogicalResult rewriteDynamicShape(EmboxOp embox, + mlir::PatternRewriter &rewriter, + mlir::Value shapeVal) const { + auto loc = embox.getLoc(); + auto shapeOp = dyn_cast(shapeVal.getDefiningOp()); + llvm::SmallVector shapeOpers; + llvm::SmallVector shiftOpers; + if (shapeOp) { + populateShape(shapeOpers, shapeOp); + } else { + auto shiftOp = dyn_cast(shapeVal.getDefiningOp()); + assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift"); + populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); + } + llvm::SmallVector sliceOpers; + llvm::SmallVector subcompOpers; + if (auto s = embox.getSlice()) + if (auto sliceOp = dyn_cast_or_null(s.getDefiningOp())) { + 
sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); + subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); + } + auto xbox = rewriter.create( + loc, embox.getType(), embox.memref(), shapeOpers, shiftOpers, + sliceOpers, subcompOpers, embox.lenParams()); + LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); + rewriter.replaceOp(embox, xbox.getOperation()->getResults()); + return mlir::success(); + } +}; + +/// Convert all fir.rebox to the extended form. +class ReboxConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ReboxOp rebox, + mlir::PatternRewriter &rewriter) const override { + auto loc = rebox.getLoc(); + llvm::SmallVector shapeOpers; + llvm::SmallVector shiftOpers; + if (auto shapeVal = rebox.shape()) { + if (auto shapeOp = dyn_cast(shapeVal.getDefiningOp())) + populateShape(shapeOpers, shapeOp); + else if (auto shiftOp = dyn_cast(shapeVal.getDefiningOp())) + populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); + else if (auto shiftOp = dyn_cast(shapeVal.getDefiningOp())) + populateShift(shiftOpers, shiftOp); + else + return mlir::failure(); + } + llvm::SmallVector sliceOpers; + llvm::SmallVector subcompOpers; + if (auto s = rebox.slice()) + if (auto sliceOp = dyn_cast_or_null(s.getDefiningOp())) { + sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); + subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); + } + + auto xRebox = rewriter.create( + loc, rebox.getType(), rebox.box(), shapeOpers, shiftOpers, sliceOpers, + subcompOpers); + LLVM_DEBUG(llvm::dbgs() + << "rewriting " << rebox << " to " << xRebox << '\n'); + rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); + return mlir::success(); + } +}; + +/// Convert all fir.array_coor to the extended form. 
+class ArrayCoorConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(ArrayCoorOp arrCoor, + mlir::PatternRewriter &rewriter) const override { + auto loc = arrCoor.getLoc(); + llvm::SmallVector shapeOpers; + llvm::SmallVector shiftOpers; + if (auto shapeVal = arrCoor.shape()) { + if (auto shapeOp = dyn_cast(shapeVal.getDefiningOp())) + populateShape(shapeOpers, shapeOp); + else if (auto shiftOp = dyn_cast(shapeVal.getDefiningOp())) + populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); + else if (auto shiftOp = dyn_cast(shapeVal.getDefiningOp())) + populateShift(shiftOpers, shiftOp); + else + return mlir::failure(); + } + llvm::SmallVector sliceOpers; + llvm::SmallVector subcompOpers; + if (auto s = arrCoor.slice()) + if (auto sliceOp = dyn_cast_or_null(s.getDefiningOp())) { + sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end()); + subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end()); + } + auto xArrCoor = rewriter.create( + loc, arrCoor.getType(), arrCoor.memref(), shapeOpers, shiftOpers, + sliceOpers, subcompOpers, arrCoor.indices(), arrCoor.lenParams()); + LLVM_DEBUG(llvm::dbgs() + << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); + rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); + return mlir::success(); + } +}; + +/// Rewrite fir.embox, fir.rebox and fir.array_coor to fircg extended ops. 
+class CodeGenRewrite : public CodeGenRewriteBase { +public: + void runOn(mlir::Operation *op, mlir::Region ®ion) { + auto &context = getContext(); + mlir::OpBuilder rewriter(&context); + mlir::ConversionTarget target(context); + target.addLegalDialect(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addDynamicallyLegalOp([](EmboxOp embox) { + return !(embox.getShape() || + embox.getType().cast().getEleTy().isa()); + }); + mlir::OwningRewritePatternList patterns; + patterns.insert( + &context); + if (mlir::failed( + mlir::applyPartialConversion(op, target, std::move(patterns)))) { + mlir::emitError(mlir::UnknownLoc::get(&context), + "error in running the pre-codegen conversions"); + signalPassFailure(); + } + // Erase any residual. + simplifyRegion(region); + } + + void runOnOperation() override final { + // Call runOn on all top level regions that may contain emboxOp/arrayCoorOp. + auto mod = getOperation(); + for (auto func : mod.getOps()) + runOn(func, func.getBody()); + for (auto global : mod.getOps()) + runOn(global, global.getRegion()); + } + + // Clean up the region. + void simplifyRegion(mlir::Region ®ion) { + for (auto &block : region.getBlocks()) + for (auto &op : block.getOperations()) { + if (op.getNumRegions() != 0) + for (auto ® : op.getRegions()) + simplifyRegion(reg); + maybeEraseOp(&op); + } + + for (auto *op : opsToErase) + op->erase(); + opsToErase.clear(); + } + + void maybeEraseOp(mlir::Operation *op) { + if (!op) + return; + + // Erase any embox that was replaced. + if (auto embox = dyn_cast(op)) + if (embox.getShape()) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + + // Erase all fir.array_coor. + if (isa(op)) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + + // Erase all fir.rebox. + if (isa(op)) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + + // Erase all fir.shape, fir.shape_shift, fir.shift, and fir.slice ops. 
+ if (isa(op)) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + if (isa(op)) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + if (isa(op)) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + if (isa(op)) { + assert(op->use_empty()); + opsToErase.push_back(op); + } + } + +private: + std::vector opsToErase; +}; + +} // namespace + +/// Convert FIR's structured control flow ops to CFG ops. This conversion +/// enables the `createLowerToCFGPass` to transform these to CFG form. +std::unique_ptr fir::createFirCodeGenRewritePass() { + return std::make_unique(); +} Index: flang/test/Fir/cg-ops.fir =================================================================== --- /dev/null +++ flang/test/Fir/cg-ops.fir @@ -0,0 +1,17 @@ +// RUN: fir-opt --cg-rewrite %s | FileCheck %s + +// CHECK-LABEL: func @codegen( +// CHECK-SAME: %[[arg:.*]]: !fir +func @codegen(%addr : !fir.ref>) { + // CHECK: %[[zero:.*]] = constant 0 : index + %0 = constant 0 : index + %1 = fir.shape_shift %0, %0 : (index, index) -> !fir.shapeshift<1> + %2 = fir.slice %0, %0, %0 : (index, index, index) -> !fir.slice<1> + // CHECK: %[[box:.*]] = fircg.ext_embox %[[arg]](%[[zero]]) origin %[[zero]][%[[zero]], %[[zero]], %[[zero]]] : (!fir.ref>, index, index, index, index, index) -> !fir.box> + %3 = fir.embox %addr (%1) [%2] : (!fir.ref>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box> + // CHECK: fircg.ext_array_coor %[[arg]](%[[zero]]) origin %[[zero]][%[[zero]], %[[zero]], %[[zero]]]<%[[zero]]> : (!fir.ref>, index, index, index, index, index, index) -> !fir.ref + %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref>, !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref + // CHECK: fircg.ext_rebox %[[box]](%[[zero]]) origin %[[zero]] : (!fir.box>, index, index) -> !fir.box> + %5 = fir.rebox %3(%1) : (!fir.box>, !fir.shapeshift<1>) -> !fir.box> + return +} Index: flang/tools/fir-opt/fir-opt.cpp =================================================================== --- 
flang/tools/fir-opt/fir-opt.cpp +++ flang/tools/fir-opt/fir-opt.cpp @@ -12,14 +12,17 @@ //===----------------------------------------------------------------------===// #include "mlir/Support/MlirOptMain.h" +#include "flang/Optimizer/OptPasses.h" #include "flang/Optimizer/Support/InitFIR.h" using namespace mlir; int main(int argc, char **argv) { fir::support::registerFIRPasses(); + fir::registerOptPasses(); DialectRegistry registry; - fir::support::registerDialects(registry); + registerAllDialects(registry); + registry.insert(); return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n", - registry, /*preloadDialectsInContext*/ false)); + registry, /*preloadDialectsInContext=*/false)); }