diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -267,6 +267,13 @@ /// Returns null on error. mlir::Type applyPathToType(mlir::Type rootTy, mlir::ValueRange path); +/// Does this type, when it is the RESULT type of a function, require binding +/// the result value to storage in a fir.save_result operation in order to use +/// the result? +inline bool isAbstractResult(mlir::Type resultType) { + return resultType.isa(); +} + /// Does this function type has a result that requires binding the result value /// with a storage in a fir.save_result operation in order to use the result? bool hasAbstractResult(mlir::FunctionType ty); diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td" -def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::func::FuncOp"> { +def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::ModuleOp"> { let summary = "Convert fir.array, fir.box and fir.rec function result to " "function argument"; let description = [{ diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc --- a/flang/include/flang/Tools/CLOptions.inc +++ b/flang/include/flang/Tools/CLOptions.inc @@ -181,7 +181,7 @@ #if !defined(FLANG_EXCLUDE_CODEGEN) inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm) { fir::addBoxedProcedurePass(pm); - pm.addNestedPass(fir::createAbstractResultOptPass()); + pm.addPass(fir::createAbstractResultOptPass()); fir::addCodeGenRewritePass(pm); fir::addTargetRewritePass(pm); fir::addExternalNameConversionPass(pm); diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp
b/flang/lib/Optimizer/Dialect/FIRType.cpp --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -937,8 +937,7 @@ bool fir::hasAbstractResult(mlir::FunctionType ty) { if (ty.getNumResults() == 0) return false; - auto resultType = ty.getResult(0); - return resultType.isa(); + return isAbstractResult(ty.getResult(0)); } /// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -18,26 +18,21 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/TypeSwitch.h" +#include +#include +#include #define DEBUG_TYPE "flang-abstract-result-opt" namespace fir { namespace { -struct AbstractResultOptions { - // Always pass result as a fir.box argument. - bool boxResult = false; - // New function block argument for the result if the current FuncOp had - // an abstract result.
- mlir::Value newArg; -}; - static mlir::Type getResultArgumentType(mlir::Type resultType, - const AbstractResultOptions &options) { + bool shouldBoxResult) { return llvm::TypeSwitch(resultType) .Case( [&](mlir::Type type) -> mlir::Type { - if (options.boxResult) + if (shouldBoxResult) return fir::BoxType::get(type); return fir::ReferenceType::get(type); }) @@ -49,28 +44,26 @@ }); } -static mlir::FunctionType -getNewFunctionType(mlir::FunctionType funcTy, - const AbstractResultOptions &options) { +static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, + bool shouldBoxResult) { auto resultType = funcTy.getResult(0); - auto argTy = getResultArgumentType(resultType, options); + auto argTy = getResultArgumentType(resultType, shouldBoxResult); llvm::SmallVector newInputTypes = {argTy}; newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, /*resultTypes=*/{}); } -static bool mustEmboxResult(mlir::Type resultType, - const AbstractResultOptions &options) { +static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { return resultType.isa() && - options.boxResult; + shouldBoxResult; } class CallOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt) - : OpRewritePattern(context), options{opt} {} + CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) + : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} mlir::LogicalResult matchAndRewrite(fir::CallOp callOp, mlir::PatternRewriter &rewriter) const override { @@ -88,10 +81,10 @@ loc, "calls with abstract result must be used in fir.save_result"); return mlir::failure(); } - auto argType = getResultArgumentType(result.getType(), options); + auto argType = getResultArgumentType(result.getType(), shouldBoxResult); auto buffer = saveResult.getMemref(); mlir::Value
arg = buffer; - if (mustEmboxResult(result.getType(), options)) + if (mustEmboxResult(result.getType(), shouldBoxResult)) arg = rewriter.create( loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, saveResult.getTypeparams()); @@ -126,7 +119,7 @@ } private: - const AbstractResultOptions &options; + bool shouldBoxResult; }; class SaveResultOpConversion @@ -143,22 +136,40 @@ } }; -class ReturnOpConversion : public mlir::OpRewritePattern { +class AddrOfOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - ReturnOpConversion(mlir::MLIRContext *context, - const AbstractResultOptions &opt) - : OpRewritePattern(context), options{opt} {} + AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) + : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} mlir::LogicalResult - matchAndRewrite(mlir::func::ReturnOp ret, + matchAndRewrite(fir::AddrOfOp addrOf, mlir::PatternRewriter &rewriter) const override { + auto oldFuncTy = addrOf.getType().cast(); + auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); + auto newAddrOf = rewriter.create(addrOf.getLoc(), newFuncTy, + addrOf.getSymbol()); + // Rather than converting all ops a function pointer might transit through + // (e.g. calls, stores, loads, converts...), cast the new address back to + // the original function type. A conversion will be added when rewriting + // indirect calls of abstract types.
+ rewriter.replaceOpWithNewOp(addrOf, oldFuncTy, newAddrOf); + return mlir::success(); + } + +private: + bool shouldBoxResult; +}; + +class FuncOpConversion : public mlir::OpRewritePattern { + void rewriteReturn(mlir::func::ReturnOp ret, mlir::Value newArg, + mlir::PatternRewriter &rewriter) const { rewriter.setInsertionPoint(ret); auto returnedValue = ret.getOperand(0); bool replacedStorage = false; if (auto *op = returnedValue.getDefiningOp()) if (auto load = mlir::dyn_cast(op)) { auto resultStorage = load.getMemref(); - load.getMemref().replaceAllUsesWith(options.newArg); + load.getMemref().replaceAllUsesWith(newArg); replacedStorage = true; if (auto *alloc = resultStorage.getDefiningOp()) if (alloc->use_empty()) @@ -169,77 +180,67 @@ // with no length parameters. Simply store the result in the result storage. // at the return point. if (!replacedStorage) - rewriter.create(ret.getLoc(), returnedValue, - options.newArg); + rewriter.create(ret.getLoc(), returnedValue, newArg); rewriter.replaceOpWithNewOp(ret); - return mlir::success(); } -private: - const AbstractResultOptions &options; -}; - -class AddrOfOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - AddrOfOpConversion(mlir::MLIRContext *context, - const AbstractResultOptions &opt) - : OpRewritePattern(context), options{opt} {} + FuncOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) + : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} mlir::LogicalResult - matchAndRewrite(fir::AddrOfOp addrOf, + matchAndRewrite(mlir::func::FuncOp func, mlir::PatternRewriter &rewriter) const override { - auto oldFuncTy = addrOf.getType().cast(); - auto newFuncTy = getNewFunctionType(oldFuncTy, options); - auto newAddrOf = rewriter.create(addrOf.getLoc(), newFuncTy, - addrOf.getSymbol()); - // Rather than converting all op a function pointer might transit through - // (e.g calls, stores, loads, converts...), cast new type to the abstract - // type.
A conversion will be added when calling indirect calls of abstract - // types. - rewriter.replaceOpWithNewOp(addrOf, oldFuncTy, newAddrOf); - return mlir::success(); - } - -private: - const AbstractResultOptions &options; -}; - -class AbstractResultOpt : public fir::AbstractResultOptBase { -public: - void runOnOperation() override { - auto *context = &getContext(); - auto func = getOperation(); auto loc = func.getLoc(); - mlir::RewritePatternSet patterns(context); - mlir::ConversionTarget target = *context; - AbstractResultOptions options{passResultAsBox.getValue(), - /*newArg=*/{}}; // Convert function type itself if it has an abstract result auto funcTy = func.getFunctionType().cast(); if (hasAbstractResult(funcTy)) { - func.setType(getNewFunctionType(funcTy, options)); - unsigned zero = 0; - if (!func.empty()) { - // Insert new argument + rewriter.updateRootInPlace(func, [&] { + func.setType(getNewFunctionType(funcTy, shouldBoxResult)); + }); + + if (func.empty()) { + // If a function has no body, its conversion is done + return mlir::success(); + } + // Insert new argument + mlir::Value newArg; + rewriter.updateRootInPlace(func, [&] { - mlir::OpBuilder rewriter(context); + // Create new ops through the pattern rewriter (not a local OpBuilder) so + // that the dialect conversion driver is notified of them. auto resultType = funcTy.getResult(0); - auto argTy = getResultArgumentType(resultType, options); - options.newArg = func.front().insertArgument(zero, argTy, loc); - if (mustEmboxResult(resultType, options)) { + auto argTy = getResultArgumentType(resultType, shouldBoxResult); + newArg = func.front().insertArgument(0u, argTy, loc); + if (mustEmboxResult(resultType, shouldBoxResult)) { auto bufferType = fir::ReferenceType::get(resultType); rewriter.setInsertionPointToStart(&func.front()); - options.newArg = - rewriter.create(loc, bufferType, options.newArg); + newArg = rewriter.create(loc, bufferType, newArg); } - patterns.insert(context, options); - target.addDynamicallyLegalOp( - [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); - } + }); +
func.walk([&, this](mlir::Operation *op) { + if (auto ret = mlir::dyn_cast(op)) { + rewriteReturn(ret, newArg, rewriter); + } + }); } - if (func.empty()) - return; + return mlir::success(); + } + +private: + bool shouldBoxResult; +}; + +class AbstractResultOpt : public fir::AbstractResultOptBase { +public: + void runOnOperation() override { + auto *context = &getContext(); + auto module = getOperation(); + mlir::RewritePatternSet patterns(context); + mlir::ConversionTarget target = *context; + const bool shouldBoxResult = passResultAsBox.getValue(); // Convert the calls and, if needed, the ReturnOp in the function body. target.addLegalDialect( + [](mlir::func::FuncOp func) { + return !hasAbstractResult(func.getFunctionType()); + }); + target.addDynamicallyLegalOp( + [](mlir::func::ReturnOp ret) { + return ret.operands().empty() || + !isAbstractResult(ret.operands().front().getType()); + }); - patterns.insert(context, options); + patterns.insert(context, shouldBoxResult); + patterns.insert(context, shouldBoxResult); patterns.insert(context); - patterns.insert(context, options); - if (mlir::failed( - mlir::applyPartialConversion(func, target, std::move(patterns)))) { - mlir::emitError(func.getLoc(), "error in converting abstract results\n"); + patterns.insert(context, shouldBoxResult); + if (mlir::failed(mlir::applyPartialConversion(module, target, + std::move(patterns)))) { + mlir::emitError(module.getLoc(), + "error in converting abstract results\n"); signalPassFailure(); } } diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -25,8 +25,6 @@ ! CHECK: SimplifyRegionLite ! CHECK: CSE ! CHECK: BoxedProcedurePass - -! CHECK-LABEL: 'func.func' Pipeline ! CHECK: AbstractResultOpt ! CHECK: CodeGenRewrite !
CHECK: TargetRewrite diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -34,8 +34,6 @@ // PASSES: SimplifyRegionLite // PASSES: CSE // PASSES: BoxedProcedurePass - -// PASSES-LABEL: 'func.func' Pipeline // PASSES: AbstractResultOpt // PASSES: CodeGenRewrite // PASSES: TargetRewrite