Review notes for this patch (text before the first "diff --git" header is ignored by git apply / patch).

Purpose: turn AbstractResultOpt from a func.func pass into a mlir::ModuleOp pass.
  - Passes.td: the pass anchor changes from "mlir::func::FuncOp" to "mlir::ModuleOp".
  - CLOptions.inc: the pass is scheduled with pm.addPass instead of pm.addNestedPass.
  - AbstractResult.cpp: the per-function signature rewrite moves out of runOnOperation into a
    new FuncOpConversion rewrite pattern. That pattern is inserted into the module-level pattern
    set next to the existing patterns. The final applyPartialConversion now runs on the module
    instead of on a single function.
  - mlir-pass-pipeline.f90 / basic-program.fir: the 'func.func' Pipeline label that preceded
    AbstractResultOpt is dropped from the expected pipeline dump.

NOTE(review): this copy appears to have lost its newlines and every <...> span. Examples:
  "mlir::OpRewritePattern" with no template argument, ".cast()" with no type,
  "rewriter.create(loc, ...)" with no op type, and "pm.addNestedPass(...)". Part of the
  addLegalDialect / patterns.insert context in the AbstractResult.cpp hunk also seems to be
  missing. As shown, the patch cannot be applied or compiled; regenerate it from the original
  commit before landing.

NOTE(review): FuncOpConversion::matchAndRewrite changes the function directly instead of
  through the PatternRewriter it is handed. It calls func.setType, inserts a block argument with
  front().insertArgument, and builds with a local mlir::OpBuilder. It then runs a nested
  mlir::applyPartialConversion from inside the outer module-level conversion.
  MLIR's pattern infrastructure expects IR mutations to go through the rewriter, e.g.
  rewriter.updateRootInPlace for in-place changes, so the driver can track and roll them back.
  Confirm that this nesting is supported. In particular, the nested conversion may erase
  func.return ops that the outer driver has already collected for legalization. Verify that this
  cannot leave the outer driver holding dangling operations.

NOTE(review): when the function has no abstract result, the nested pattern set is empty.
  matchAndRewrite then returns the (successful) result of a no-op conversion without touching
  the op. A pattern that reports success without replacing or updating its root is likely to
  trip the conversion driver's "expected pattern to replace the root operation" assertion in
  assert builds. Fix it in one of two ways:
    - return mlir::failure() in that case, or
    - mark func.func dynamically legal when !hasAbstractResult(type).

NOTE(review): conversely, mlir::func::FuncDialect may be among the dialects passed to
  addLegalDialect; its template arguments are not visible here. If it is, func.func ops are
  legal in the outer target. FuncOpConversion would then never be applied, and signatures would
  stay unconverted. Confirm with a test that converts a function returning an abstract result.

NOTE(review): the local "mlir::OpBuilder rewriter(context);" shadows the
  "mlir::PatternRewriter &rewriter" parameter of matchAndRewrite. Rename it, or use the
  parameter directly.

NOTE(review): a less fragile shape keeps the signature/return rewrite out of the pattern driver.
  For example, runOnOperation could walk the module's func::FuncOp ops and apply the
  per-function rewrite and conversion there, as the removed code did. The module-level
  conversion would then handle only call sites, fir.save_result and address-of ops.

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td" -def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::func::FuncOp"> { +def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::ModuleOp"> { let summary = "Convert fir.array, fir.box and fir.rec function result to " "function argument"; let description = [{ diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc --- a/flang/include/flang/Tools/CLOptions.inc +++ b/flang/include/flang/Tools/CLOptions.inc @@ -181,7 +181,7 @@ #if !defined(FLANG_EXCLUDE_CODEGEN) inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm) { fir::addBoxedProcedurePass(pm); - pm.addNestedPass(fir::createAbstractResultOptPass()); + pm.addPass(fir::createAbstractResultOptPass()); fir::addCodeGenRewritePass(pm); fir::addTargetRewritePass(pm); fir::addExternalNameConversionPass(pm); diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -191,40 +191,57 @@ bool shouldBoxResult; }; -class AbstractResultOpt : public fir::AbstractResultOptBase { +class FuncOpConversion : public mlir::OpRewritePattern { public: - void runOnOperation() override { - auto *context = &getContext(); - auto func = getOperation(); + using OpRewritePattern::OpRewritePattern; + FuncOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) + : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} + mlir::LogicalResult + matchAndRewrite(mlir::func::FuncOp func, + mlir::PatternRewriter &rewriter) const override { + auto *context = getContext(); auto loc = func.getLoc(); mlir::RewritePatternSet 
patterns(context); mlir::ConversionTarget target = *context; - const bool shouldBoxResult = passResultAsBox.getValue(); // Convert function type itself if it has an abstract result auto funcTy = func.getFunctionType().cast(); if (hasAbstractResult(funcTy)) { func.setType(getNewFunctionType(funcTy, shouldBoxResult)); - unsigned zero = 0; - if (!func.empty()) { - // Insert new argument - mlir::OpBuilder rewriter(context); - auto resultType = funcTy.getResult(0); - auto argTy = getResultArgumentType(resultType, shouldBoxResult); - mlir::Value newArg = func.front().insertArgument(zero, argTy, loc); - if (mustEmboxResult(resultType, shouldBoxResult)) { - auto bufferType = fir::ReferenceType::get(resultType); - rewriter.setInsertionPointToStart(&func.front()); - newArg = rewriter.create(loc, bufferType, newArg); - } - patterns.insert(context, newArg); - target.addDynamicallyLegalOp( - [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); + if (func.empty()) { + // If a function has no body, its conversion is done + return mlir::success(); + } + // Insert new argument + mlir::OpBuilder rewriter(context); + auto resultType = funcTy.getResult(0); + auto argTy = getResultArgumentType(resultType, shouldBoxResult); + mlir::Value newArg = func.front().insertArgument(0u, argTy, loc); + if (mustEmboxResult(resultType, shouldBoxResult)) { + auto bufferType = fir::ReferenceType::get(resultType); + rewriter.setInsertionPointToStart(&func.front()); + newArg = rewriter.create(loc, bufferType, newArg); } + patterns.insert(context, newArg); + target.addDynamicallyLegalOp( + [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); } - if (func.empty()) - return; + return mlir::applyPartialConversion(func, target, std::move(patterns)); + } + +private: + bool shouldBoxResult; +}; + +class AbstractResultOpt : public fir::AbstractResultOptBase { +public: + void runOnOperation() override { + auto *context = &getContext(); + auto module = getOperation(); + 
mlir::RewritePatternSet patterns(context); + mlir::ConversionTarget target = *context; + const bool shouldBoxResult = passResultAsBox.getValue(); // Convert the calls and, if needed, the ReturnOp in the function body. target.addLegalDialect(context, shouldBoxResult); + patterns.insert(context, shouldBoxResult); patterns.insert(context); patterns.insert(context, shouldBoxResult); - if (mlir::failed( - mlir::applyPartialConversion(func, target, std::move(patterns)))) { - mlir::emitError(func.getLoc(), "error in converting abstract results\n"); + if (mlir::failed(mlir::applyPartialConversion(module, target, + std::move(patterns)))) { + mlir::emitError(module.getLoc(), + "error in converting abstract results\n"); signalPassFailure(); } } diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -25,8 +25,6 @@ ! CHECK: SimplifyRegionLite ! CHECK: CSE ! CHECK: BoxedProcedurePass - -! CHECK-LABEL: 'func.func' Pipeline ! CHECK: AbstractResultOpt ! CHECK: CodeGenRewrite ! CHECK: TargetRewrite diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -34,8 +34,6 @@ // PASSES: SimplifyRegionLite // PASSES: CSE // PASSES: BoxedProcedurePass - -// PASSES-LABEL: 'func.func' Pipeline // PASSES: AbstractResultOpt // PASSES: CodeGenRewrite // PASSES: TargetRewrite