diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -267,6 +267,13 @@ /// Returns null on error. mlir::Type applyPathToType(mlir::Type rootTy, mlir::ValueRange path); +/// Does this type encountered in RESULT of a function requires binding the +/// result value with a storage in a fir.save_result operation in order to use +/// the result? +inline bool isAbstractResult(mlir::Type resultType) { + return resultType.isa(); +} + /// Does this function type has a result that requires binding the result value /// with a storage in a fir.save_result operation in order to use the result? bool hasAbstractResult(mlir::FunctionType ty); diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td" -def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::func::FuncOp"> { +def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::ModuleOp"> { let summary = "Convert fir.array, fir.box and fir.rec function result to " "function argument"; let description = [{ diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc --- a/flang/include/flang/Tools/CLOptions.inc +++ b/flang/include/flang/Tools/CLOptions.inc @@ -181,7 +181,7 @@ #if !defined(FLANG_EXCLUDE_CODEGEN) inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm) { fir::addBoxedProcedurePass(pm); - pm.addNestedPass(fir::createAbstractResultOptPass()); + pm.addPass(fir::createAbstractResultOptPass()); fir::addCodeGenRewritePass(pm); fir::addTargetRewritePass(pm); fir::addExternalNameConversionPass(pm); diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -937,8 +937,7 @@ bool fir::hasAbstractResult(mlir::FunctionType ty) { if (ty.getNumResults() == 0) return false; - auto resultType = ty.getResult(0); - return resultType.isa(); + return isAbstractResult(ty.getResult(0)); } /// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -18,6 +18,9 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/TypeSwitch.h" +#include +#include +#include #define DEBUG_TYPE "flang-abstract-result-opt" @@ -133,14 +136,33 @@ } }; -class ReturnOpConversion : public mlir::OpRewritePattern { +class AddrOfOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) - : OpRewritePattern(context), newArg{newArg} {} + AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) + : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} mlir::LogicalResult - matchAndRewrite(mlir::func::ReturnOp ret, + matchAndRewrite(fir::AddrOfOp addrOf, mlir::PatternRewriter &rewriter) const override { + auto oldFuncTy = addrOf.getType().cast(); + auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); + auto newAddrOf = rewriter.create(addrOf.getLoc(), newFuncTy, + addrOf.getSymbol()); + // Rather than converting all op a function pointer might transit through + // (e.g calls, stores, loads, converts...), cast new type to the abstract + // type. A conversion will be added when calling indirect calls of abstract + // types. + rewriter.replaceOpWithNewOp(addrOf, oldFuncTy, newAddrOf); + return mlir::success(); + } + +private: + bool shouldBoxResult; +}; + +class FuncOpConversion : public mlir::OpRewritePattern { + void rewriteReturn(mlir::func::ReturnOp ret, mlir::Value newArg, + mlir::PatternRewriter &rewriter) const { rewriter.setInsertionPoint(ret); auto returnedValue = ret.getOperand(0); bool replacedStorage = false; @@ -160,71 +182,67 @@ if (!replacedStorage) rewriter.create(ret.getLoc(), returnedValue, newArg); rewriter.replaceOpWithNewOp(ret); - return mlir::success(); } -private: - mlir::Value newArg; -}; - -class AddrOfOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) + FuncOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} mlir::LogicalResult - matchAndRewrite(fir::AddrOfOp addrOf, + matchAndRewrite(mlir::func::FuncOp func, mlir::PatternRewriter &rewriter) const override { - auto oldFuncTy = addrOf.getType().cast(); - auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); - auto newAddrOf = rewriter.create(addrOf.getLoc(), newFuncTy, - addrOf.getSymbol()); - // Rather than converting all op a function pointer might transit through - // (e.g calls, stores, loads, converts...), cast new type to the abstract - // type. A conversion will be added when calling indirect calls of abstract - // types. - rewriter.replaceOpWithNewOp(addrOf, oldFuncTy, newAddrOf); - return mlir::success(); - } - -private: - bool shouldBoxResult; -}; - -class AbstractResultOpt : public fir::AbstractResultOptBase { -public: - void runOnOperation() override { - auto *context = &getContext(); - auto func = getOperation(); + auto *context = getContext(); auto loc = func.getLoc(); mlir::RewritePatternSet patterns(context); mlir::ConversionTarget target = *context; - const bool shouldBoxResult = passResultAsBox.getValue(); // Convert function type itself if it has an abstract result auto funcTy = func.getFunctionType().cast(); if (hasAbstractResult(funcTy)) { - func.setType(getNewFunctionType(funcTy, shouldBoxResult)); - unsigned zero = 0; - if (!func.empty()) { - // Insert new argument + rewriter.updateRootInPlace(func, [&] { + func.setType(getNewFunctionType(funcTy, shouldBoxResult)); + }); + + if (func.empty()) { + // If a function has no body, its conversion is done + return mlir::success(); + } + // Insert new argument + mlir::Value newArg; + rewriter.updateRootInPlace(func, [&]() mutable { mlir::OpBuilder rewriter(context); auto resultType = funcTy.getResult(0); auto argTy = getResultArgumentType(resultType, shouldBoxResult); - mlir::Value newArg = func.front().insertArgument(zero, argTy, loc); + newArg = func.front().insertArgument(0u, argTy, loc); if (mustEmboxResult(resultType, shouldBoxResult)) { auto bufferType = fir::ReferenceType::get(resultType); rewriter.setInsertionPointToStart(&func.front()); newArg = rewriter.create(loc, bufferType, newArg); } - patterns.insert(context, newArg); - target.addDynamicallyLegalOp( - [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); - } + }); + + func.walk([&, this](mlir::Operation *op) { + if (auto ret = mlir::dyn_cast(op)) { + rewriteReturn(ret, newArg, rewriter); + } + }); } - if (func.empty()) - return; + return mlir::success(); + } + +private: + bool shouldBoxResult; +}; + +class AbstractResultOpt : public fir::AbstractResultOptBase { +public: + void runOnOperation() override { + auto *context = &getContext(); + auto module = getOperation(); + mlir::RewritePatternSet patterns(context); + mlir::ConversionTarget target = *context; + const bool shouldBoxResult = passResultAsBox.getValue(); // Convert the calls and, if needed, the ReturnOp in the function body. target.addLegalDialect( + [](mlir::func::FuncOp func) { + return !hasAbstractResult(func.getFunctionType()); + }); + target.addDynamicallyLegalOp( + [](mlir::func::ReturnOp ret) { + return ret.operands().empty() || + !isAbstractResult(ret.operands().front().getType()); + }); + patterns.insert(context, shouldBoxResult); patterns.insert(context, shouldBoxResult); patterns.insert(context); patterns.insert(context, shouldBoxResult); - if (mlir::failed( - mlir::applyPartialConversion(func, target, std::move(patterns)))) { - mlir::emitError(func.getLoc(), "error in converting abstract results\n"); + if (mlir::failed(mlir::applyPartialConversion(module, target, + std::move(patterns)))) { + mlir::emitError(module.getLoc(), + "error in converting abstract results\n"); signalPassFailure(); } } diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -25,8 +25,6 @@ ! CHECK: SimplifyRegionLite ! CHECK: CSE ! CHECK: BoxedProcedurePass - -! CHECK-LABEL: 'func.func' Pipeline ! CHECK: AbstractResultOpt ! CHECK: CodeGenRewrite ! CHECK: TargetRewrite diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -34,8 +34,6 @@ // PASSES: SimplifyRegionLite // PASSES: CSE // PASSES: BoxedProcedurePass - -// PASSES-LABEL: 'func.func' Pipeline // PASSES: AbstractResultOpt // PASSES: CodeGenRewrite // PASSES: TargetRewrite