NOTE(review): in this copy of the patch the added `#include ` line in AbstractResult.cpp names no header (nothing visible in the hunks needs a new header), and the two `pm.addNestedPass(fir::createAbstractResultOptPass());` lines in CLOptions.inc are textually identical. Template argument lists (on addNestedPass, dyn_cast, cast, patterns.insert, addDynamicallyLegalOp, ...) appear to have been stripped throughout; presumably the two pipeline lines anchor on mlir::func::FuncOp and fir::GlobalOp (FIROps.h is newly included for that) -- confirm against the original patch before applying.
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td" -def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::func::FuncOp"> { +def AbstractResultOpt : Pass<"abstract-result-opt"> { let summary = "Convert fir.array, fir.box and fir.rec function result to " "function argument"; let description = [{ diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc --- a/flang/include/flang/Tools/CLOptions.inc +++ b/flang/include/flang/Tools/CLOptions.inc @@ -15,6 +15,7 @@ #include "mlir/Transforms/Passes.h" #include "flang/Optimizer/CodeGen/CodeGen.h" #include "flang/Optimizer/Transforms/Passes.h" +#include "flang/Optimizer/Dialect/FIROps.h" #include "llvm/Support/CommandLine.h" #define DisableOption(DOName, DOOption, DODescription) \ @@ -182,6 +183,7 @@ inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm) { fir::addBoxedProcedurePass(pm); pm.addNestedPass(fir::createAbstractResultOptPass()); + pm.addNestedPass(fir::createAbstractResultOptPass()); fir::addCodeGenRewritePass(pm); fir::addTargetRewritePass(pm); fir::addExternalNameConversionPass(pm); diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -18,6 +18,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/TypeSwitch.h" +#include #define DEBUG_TYPE "flang-abstract-result-opt" @@ -205,43 +206,66 @@ }; class AbstractResultOpt : public fir::AbstractResultOptBase { + /// If `op` is a FuncOp with an abstract result, rewrites its type (and, for a definition, inserts the result argument and registers the return conversion); returns false only for a FuncOp without a body, true otherwise.
+ bool runOnFuncOp(mlir::MLIRContext *context, mlir::Operation *op, + mlir::RewritePatternSet &patterns, + mlir::ConversionTarget &target, + AbstractResultOptions &options) const { + if (auto func = mlir::dyn_cast(op)) { + const auto loc = func.getLoc(); + // Convert function type itself if it has an abstract result + auto funcTy = func.getFunctionType().cast(); + if (hasAbstractResult(funcTy)) { + func.setType(getNewFunctionType(funcTy, options)); + unsigned zero = 0; + if (!func.empty()) { + // Insert new argument + mlir::OpBuilder rewriter(context); + auto resultType = funcTy.getResult(0); + auto argTy = getResultArgumentType(resultType, options); + options.newArg = func.front().insertArgument(zero, argTy, loc); + if (mustEmboxResult(resultType, options)) { + auto bufferType = fir::ReferenceType::get(resultType); + rewriter.setInsertionPointToStart(&func.front()); + options.newArg = rewriter.create(loc, bufferType, + options.newArg); + } + patterns.insert(context, options); + target.addDynamicallyLegalOp( + [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); + } + } + return !func.empty(); + } + // Only GlobalOp reaches here: runOnOperation returns early for ops rejected by shouldRunConversionsOn + return true; + } + + bool shouldRunConversionsOn(mlir::Operation *op) const { + return mlir::dyn_cast(op) || + mlir::dyn_cast(op); + } + public: void runOnOperation() override { auto *context = &getContext(); - auto func = getOperation(); - auto loc = func.getLoc(); + auto *op = getOperation(); + // The pass is op-agnostic; anchors other than FuncOp and GlobalOp are skipped below. mlir::RewritePatternSet patterns(context); mlir::ConversionTarget target = *context; - AbstractResultOptions options{passResultAsBox.getValue(), - /*newArg=*/{}}; + AbstractResultOptions options{passResultAsBox.getValue(), /*newArg=*/{}}; - // Convert function type itself if it has an abstract result - auto funcTy = func.getFunctionType().cast(); - if (hasAbstractResult(funcTy)) { - func.setType(getNewFunctionType(funcTy, options)); - unsigned zero = 0; - if 
(!func.empty()) { - // Insert new argument - mlir::OpBuilder rewriter(context); - auto resultType = funcTy.getResult(0); - auto argTy = getResultArgumentType(resultType, options); - options.newArg = func.front().insertArgument(zero, argTy, loc); - if (mustEmboxResult(resultType, options)) { - auto bufferType = fir::ReferenceType::get(resultType); - rewriter.setInsertionPointToStart(&func.front()); - options.newArg = - rewriter.create(loc, bufferType, options.newArg); - } - patterns.insert(context, options); - target.addDynamicallyLegalOp( - [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); - } - } + if (!shouldRunConversionsOn(op)) { + return; + } - if (func.empty()) + if (const bool shouldContinue = + runOnFuncOp(context, op, patterns, target, options); + !shouldContinue) { return; + } - // Convert the calls and, if needed, the ReturnOp in the function body. + // Convert the calls and, if needed, the ReturnOp in the function body. target.addLegalDialect(); target.addIllegalOp(); @@ -267,9 +291,11 @@ patterns.insert(context, options); patterns.insert(context); patterns.insert(context, options); - if (mlir::failed( - mlir::applyPartialConversion(func, target, std::move(patterns)))) { - mlir::emitError(func.getLoc(), "error in converting abstract results\n"); + + if (const auto conversionResult = + mlir::applyPartialConversion(op, target, std::move(patterns)); + mlir::failed(conversionResult)) { + mlir::emitError(op->getLoc(), "error in converting abstract results\n"); signalPassFailure(); } }