diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td" -def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::func::FuncOp"> { +def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::ModuleOp"> { let summary = "Convert fir.array, fir.box and fir.rec function result to " "function argument"; let description = [{ diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc --- a/flang/include/flang/Tools/CLOptions.inc +++ b/flang/include/flang/Tools/CLOptions.inc @@ -181,7 +181,7 @@ #if !defined(FLANG_EXCLUDE_CODEGEN) inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm) { fir::addBoxedProcedurePass(pm); - pm.addNestedPass(fir::createAbstractResultOptPass()); + pm.addPass(fir::createAbstractResultOptPass()); fir::addCodeGenRewritePass(pm); fir::addTargetRewritePass(pm); fir::addExternalNameConversionPass(pm); diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -24,20 +24,12 @@ namespace fir { namespace { -struct AbstractResultOptions { - // Always pass result as a fir.box argument. - bool boxResult = false; - // New function block argument for the result if the current FuncOp had - // an abstract result. - mlir::Value newArg; -}; - static mlir::Type getResultArgumentType(mlir::Type resultType, - const AbstractResultOptions &options) { + bool shouldBoxResult) { return llvm::TypeSwitch(resultType) .Case( [&](mlir::Type type) -> mlir::Type { - if (options.boxResult) + if (shouldBoxResult) return fir::BoxType::get(type); return fir::ReferenceType::get(type); }) @@ -49,28 +41,26 @@ }); } -static mlir::FunctionType -getNewFunctionType(mlir::FunctionType funcTy, - const AbstractResultOptions &options) { +static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, + bool shouldBoxResult) { auto resultType = funcTy.getResult(0); - auto argTy = getResultArgumentType(resultType, options); + auto argTy = getResultArgumentType(resultType, shouldBoxResult); llvm::SmallVector newInputTypes = {argTy}; newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, /*resultTypes=*/{}); } -static bool mustEmboxResult(mlir::Type resultType, - const AbstractResultOptions &options) { +static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { return resultType.isa() && - options.boxResult; + shouldBoxResult; } class CallOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt) - : OpRewritePattern(context), options{opt} {} + CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) + : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} mlir::LogicalResult matchAndRewrite(fir::CallOp callOp, mlir::PatternRewriter &rewriter) const override { @@ -88,10 +78,10 @@ loc, "calls with abstract result must be used in fir.save_result"); return mlir::failure(); } - auto argType = getResultArgumentType(result.getType(), options); + auto argType = getResultArgumentType(result.getType(), shouldBoxResult); auto buffer = saveResult.getMemref(); mlir::Value arg = buffer; - if (mustEmboxResult(result.getType(), options)) + if (mustEmboxResult(result.getType(), shouldBoxResult)) arg = rewriter.create( loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, saveResult.getTypeparams()); @@ -126,7 +116,7 @@ } private: - const AbstractResultOptions &options; + bool shouldBoxResult; }; class SaveResultOpConversion @@ -146,9 +136,8 @@ class ReturnOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - ReturnOpConversion(mlir::MLIRContext *context, - const AbstractResultOptions &opt) - : OpRewritePattern(context), options{opt} {} + ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) + : OpRewritePattern(context), newArg{newArg} {} mlir::LogicalResult matchAndRewrite(mlir::func::ReturnOp ret, mlir::PatternRewriter &rewriter) const override { @@ -158,7 +147,7 @@ if (auto *op = returnedValue.getDefiningOp()) if (auto load = mlir::dyn_cast(op)) { auto resultStorage = load.getMemref(); - load.getMemref().replaceAllUsesWith(options.newArg); + load.getMemref().replaceAllUsesWith(newArg); replacedStorage = true; if (auto *alloc = resultStorage.getDefiningOp()) if (alloc->use_empty()) @@ -169,27 +158,25 @@ // with no length parameters. Simply store the result in the result storage. // at the return point. if (!replacedStorage) - rewriter.create(ret.getLoc(), returnedValue, - options.newArg); + rewriter.create(ret.getLoc(), returnedValue, newArg); rewriter.replaceOpWithNewOp(ret); return mlir::success(); } private: - const AbstractResultOptions &options; + mlir::Value newArg; }; class AddrOfOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - AddrOfOpConversion(mlir::MLIRContext *context, - const AbstractResultOptions &opt) - : OpRewritePattern(context), options{opt} {} + AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) + : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} mlir::LogicalResult matchAndRewrite(fir::AddrOfOp addrOf, mlir::PatternRewriter &rewriter) const override { auto oldFuncTy = addrOf.getType().cast(); - auto newFuncTy = getNewFunctionType(oldFuncTy, options); + auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); auto newAddrOf = rewriter.create(addrOf.getLoc(), newFuncTy, addrOf.getSymbol()); // Rather than converting all op a function pointer might transit through @@ -201,45 +188,60 @@ } private: - const AbstractResultOptions &options; + bool shouldBoxResult; }; -class AbstractResultOpt : public fir::AbstractResultOptBase { +class FuncOpConversion : public mlir::OpRewritePattern { public: - void runOnOperation() override { - auto *context = &getContext(); - auto func = getOperation(); + using OpRewritePattern::OpRewritePattern; + FuncOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) + : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} + mlir::LogicalResult + matchAndRewrite(mlir::func::FuncOp func, + mlir::PatternRewriter &rewriter) const override { + auto *context = getContext(); auto loc = func.getLoc(); mlir::RewritePatternSet patterns(context); mlir::ConversionTarget target = *context; - AbstractResultOptions options{passResultAsBox.getValue(), - /*newArg=*/{}}; // Convert function type itself if it has an abstract result auto funcTy = func.getFunctionType().cast(); if (hasAbstractResult(funcTy)) { - func.setType(getNewFunctionType(funcTy, options)); - unsigned zero = 0; - if (!func.empty()) { - // Insert new argument - mlir::OpBuilder rewriter(context); - auto resultType = funcTy.getResult(0); - auto argTy = getResultArgumentType(resultType, options); - options.newArg = func.front().insertArgument(zero, argTy, loc); - if (mustEmboxResult(resultType, options)) { - auto bufferType = fir::ReferenceType::get(resultType); - rewriter.setInsertionPointToStart(&func.front()); - options.newArg = - rewriter.create(loc, bufferType, options.newArg); - } - patterns.insert(context, options); - target.addDynamicallyLegalOp( - [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); + func.setType(getNewFunctionType(funcTy, shouldBoxResult)); + if (func.empty()) { + // If a function has no body, its conversion is done + return mlir::success(); + } + // Insert new argument + mlir::OpBuilder rewriter(context); + auto resultType = funcTy.getResult(0); + auto argTy = getResultArgumentType(resultType, shouldBoxResult); + mlir::Value newArg = func.front().insertArgument(0u, argTy, loc); + if (mustEmboxResult(resultType, shouldBoxResult)) { + auto bufferType = fir::ReferenceType::get(resultType); + rewriter.setInsertionPointToStart(&func.front()); + newArg = rewriter.create(loc, bufferType, newArg); } + patterns.insert(context, newArg); + target.addDynamicallyLegalOp( + [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); } - if (func.empty()) - return; + return mlir::applyPartialConversion(func, target, std::move(patterns)); + } + +private: + bool shouldBoxResult; +}; + +class AbstractResultOpt : public fir::AbstractResultOptBase { +public: + void runOnOperation() override { + auto *context = &getContext(); + auto module = getOperation(); + mlir::RewritePatternSet patterns(context); + mlir::ConversionTarget target = *context; + const bool shouldBoxResult = passResultAsBox.getValue(); // Convert the calls and, if needed, the ReturnOp in the function body. target.addLegalDialect(context, options); + patterns.insert(context, shouldBoxResult); + patterns.insert(context, shouldBoxResult); patterns.insert(context); - patterns.insert(context, options); - if (mlir::failed( - mlir::applyPartialConversion(func, target, std::move(patterns)))) { - mlir::emitError(func.getLoc(), "error in converting abstract results\n"); + patterns.insert(context, shouldBoxResult); + if (mlir::failed(mlir::applyPartialConversion(module, target, + std::move(patterns)))) { + mlir::emitError(module.getLoc(), + "error in converting abstract results\n"); signalPassFailure(); } } diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -25,8 +25,6 @@ ! CHECK: SimplifyRegionLite ! CHECK: CSE ! CHECK: BoxedProcedurePass - -! CHECK-LABEL: 'func.func' Pipeline ! CHECK: AbstractResultOpt ! CHECK: CodeGenRewrite ! CHECK: TargetRewrite diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -34,8 +34,6 @@ // PASSES: SimplifyRegionLite // PASSES: CSE // PASSES: BoxedProcedurePass - -// PASSES-LABEL: 'func.func' Pipeline // PASSES: AbstractResultOpt // PASSES: CodeGenRewrite // PASSES: TargetRewrite