[flang] Split AbstractResultOpt into a reusable CRTP template plus a
func.func-specific pass (AbstractResultOnFuncOpt).

Review notes on this patch (free text before the first "diff --git"
header; `git apply` ignores it):

* Passes.td: the `AbstractResultOpt` def becomes the TableGen class
  `AbstractResultOptBase`, whose pass argument is built as
  "abstract-result-on-" # optExt # "-opt" on the given `operation`.
  `AbstractResultOnFuncOpt` instantiates it with "func" /
  "mlir::func::FuncOp" and now owns the `constructor` field.
* Passes.h, CLOptions.inc: createAbstractResultOptPass() is renamed to
  createAbstractResultOnFuncOptPass(); the default FIR codegen pipeline
  calls the new name.
* AbstractResult.cpp: AbstractResultOptTemplate::runOnOperation() builds the
  RewritePatternSet and ConversionTarget, reads passResultAsBox, hands them
  to the derived class's runOnSpecificOperation(), then runs
  applyPartialConversion() on the pass's operation, emitting an error and
  signalling pass failure if it fails.
  AbstractResultOnFuncOpt::runOnSpecificOperation() keeps the old
  signature rewrite: when the function type has an abstract result it is
  retyped, and if the function has a body a new argument is inserted at
  index 0 and func.return becomes legal only once it has no operands.
* Tests: RUN lines switch to --abstract-result-on-func-opt; the pipeline
  dumps expect "AbstractResultOnFuncOpt".

NOTE(review): the old `if (func.empty()) return;` early exit is not carried
over, so body-less function declarations now also go through
applyPartialConversion(). Presumably harmless (there is no body to
convert) -- confirm this is intended.
NOTE(review): this copy of the patch appears damaged. Diff lines are merged
onto a few physical lines, and several template-argument lists seem to be
missing (e.g. `std::unique_ptr createAbstractResultOnFuncOptPass();`,
`template class PassBase>`, `static_cast(*this)`, `.cast()`). It cannot be
applied or compiled as-is; regenerate it from the source branch rather than
repairing it by hand.

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h --- a/flang/include/flang/Optimizer/Transforms/Passes.h +++ b/flang/include/flang/Optimizer/Transforms/Passes.h @@ -26,7 +26,7 @@ // Passes defined in Passes.td //===----------------------------------------------------------------------===// -std::unique_ptr createAbstractResultOptPass(); +std::unique_ptr createAbstractResultOnFuncOptPass(); std::unique_ptr createAffineDemotionPass(); std::unique_ptr createArrayValueCopyPass(); std::unique_ptr createFirToCfgPass(); diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -16,14 +16,14 @@ include "mlir/Pass/PassBase.td" -def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::func::FuncOp"> { +class AbstractResultOptBase + : Pass<"abstract-result-on-" # optExt # "-opt", operation> { let summary = "Convert fir.array, fir.box and fir.rec function result to " "function argument"; let description = [{ This pass is required before code gen to the LLVM IR dialect, including the pre-cg rewrite pass.
}]; - let constructor = "::fir::createAbstractResultOptPass()"; let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect" ]; @@ -35,6 +35,10 @@ ]; } +def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> { + let constructor = "::fir::createAbstractResultOnFuncOptPass()"; +} + def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> { let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`."; let description = [{ diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc --- a/flang/include/flang/Tools/CLOptions.inc +++ b/flang/include/flang/Tools/CLOptions.inc @@ -191,7 +191,8 @@ #if !defined(FLANG_EXCLUDE_CODEGEN) inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm) { fir::addBoxedProcedurePass(pm); - pm.addNestedPass(fir::createAbstractResultOptPass()); + pm.addNestedPass( + fir::createAbstractResultOnFuncOptPass()); fir::addCodeGenRewritePass(pm); fir::addTargetRewritePass(pm); fir::addExternalNameConversionPass(pm); diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -191,40 +191,26 @@ bool shouldBoxResult; }; -class AbstractResultOpt : public fir::AbstractResultOptBase { +/// @brief Base CRTP class for AbstractResult pass family. +/// Contains common logic for abstract result conversion in a reusable fashion. +/// @tparam Pass target class that implements operation-specific logic. +/// @tparam PassBase base class template for the pass generated by TableGen. +/// The `Pass` class must define runOnSpecificOperation(OpTy, bool, +/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function. +/// This function should implement operation-specific functionality.
+template class PassBase> +class AbstractResultOptTemplate : public PassBase { public: void runOnOperation() override { - auto *context = &getContext(); - auto func = getOperation(); - auto loc = func.getLoc(); + auto *context = &this->getContext(); + auto op = this->getOperation(); + mlir::RewritePatternSet patterns(context); mlir::ConversionTarget target = *context; - const bool shouldBoxResult = passResultAsBox.getValue(); - - // Convert function type itself if it has an abstract result - auto funcTy = func.getFunctionType().cast(); - if (hasAbstractResult(funcTy)) { - func.setType(getNewFunctionType(funcTy, shouldBoxResult)); - unsigned zero = 0; - if (!func.empty()) { - // Insert new argument - mlir::OpBuilder rewriter(context); - auto resultType = funcTy.getResult(0); - auto argTy = getResultArgumentType(resultType, shouldBoxResult); - mlir::Value newArg = func.front().insertArgument(zero, argTy, loc); - if (mustEmboxResult(resultType, shouldBoxResult)) { - auto bufferType = fir::ReferenceType::get(resultType); - rewriter.setInsertionPointToStart(&func.front()); - newArg = rewriter.create(loc, bufferType, newArg); - } - patterns.insert(context, newArg); - target.addDynamicallyLegalOp( - [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); - } - } + const bool shouldBoxResult = this->passResultAsBox.getValue(); - if (func.empty()) - return; + auto &self = static_cast(*this); + self.runOnSpecificOperation(op, shouldBoxResult, patterns, target); // Convert the calls and, if needed, the ReturnOp in the function body.
target.addLegalDialect(context); patterns.insert(context, shouldBoxResult); if (mlir::failed( - mlir::applyPartialConversion(func, target, std::move(patterns)))) { - mlir::emitError(func.getLoc(), "error in converting abstract results\n"); - signalPassFailure(); + mlir::applyPartialConversion(op, target, std::move(patterns)))) { + mlir::emitError(op.getLoc(), "error in converting abstract results\n"); + this->signalPassFailure(); + } + } +}; + +class AbstractResultOnFuncOpt + : public AbstractResultOptTemplate { +public: + void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult, + mlir::RewritePatternSet &patterns, + mlir::ConversionTarget &target) { + auto loc = func.getLoc(); + auto *context = &getContext(); + // Convert function type itself if it has an abstract result. + auto funcTy = func.getFunctionType().cast(); + if (hasAbstractResult(funcTy)) { + func.setType(getNewFunctionType(funcTy, shouldBoxResult)); + if (!func.empty()) { + // Insert new argument. + mlir::OpBuilder rewriter(context); + auto resultType = funcTy.getResult(0); + auto argTy = getResultArgumentType(resultType, shouldBoxResult); + mlir::Value newArg = func.front().insertArgument(0u, argTy, loc); + if (mustEmboxResult(resultType, shouldBoxResult)) { + auto bufferType = fir::ReferenceType::get(resultType); + rewriter.setInsertionPointToStart(&func.front()); + newArg = rewriter.create(loc, bufferType, newArg); + } + patterns.insert(context, newArg); + target.addDynamicallyLegalOp( + [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); + } } } }; } // end anonymous namespace } // namespace fir -std::unique_ptr fir::createAbstractResultOptPass() { - return std::make_unique(); +std::unique_ptr fir::createAbstractResultOnFuncOptPass() { + return std::make_unique(); } diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -52,7
+52,7 @@ ! ALL-NEXT: BoxedProcedurePass ! ALL-NEXT: 'func.func' Pipeline -! ALL-NEXT: AbstractResultOpt +! ALL-NEXT: AbstractResultOnFuncOpt ! ALL-NEXT: CodeGenRewrite ! ALL-NEXT: (S) 0 num-dce'd - Number of operations eliminated diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir --- a/flang/test/Fir/abstract-results.fir +++ b/flang/test/Fir/abstract-results.fir @@ -1,8 +1,8 @@ // Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to // functions that take an additional argument for the result. -// RUN: fir-opt %s --abstract-result-opt | FileCheck %s -// RUN: fir-opt %s --abstract-result-opt=abstract-result-as-box | FileCheck %s --check-prefix=CHECK-BOX +// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s +// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=CHECK-BOX // ----------------------- Test declaration rewrite ---------------------------- diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -52,7 +52,7 @@ // PASSES-NEXT: BoxedProcedurePass // PASSES-NEXT: 'func.func' Pipeline -// PASSES-NEXT: AbstractResultOpt +// PASSES-NEXT: AbstractResultOnFuncOpt // PASSES-NEXT: CodeGenRewrite // PASSES-NEXT: (S) 0 num-dce'd - Number of operations eliminated