diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -24,20 +24,12 @@
 namespace fir {
 namespace {
 
-struct AbstractResultOptions {
-  // Always pass result as a fir.box argument.
-  bool boxResult = false;
-  // New function block argument for the result if the current FuncOp had
-  // an abstract result.
-  mlir::Value newArg;
-};
-
 static mlir::Type getResultArgumentType(mlir::Type resultType,
-                                        const AbstractResultOptions &options) {
+                                        bool shouldBoxResult) {
   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
       .Case<fir::SequenceType, fir::RecordType>(
           [&](mlir::Type type) -> mlir::Type {
-            if (options.boxResult)
+            if (shouldBoxResult)
               return fir::BoxType::get(type);
             return fir::ReferenceType::get(type);
           })
@@ -49,28 +41,26 @@
       });
 }
 
-static mlir::FunctionType
-getNewFunctionType(mlir::FunctionType funcTy,
-                   const AbstractResultOptions &options) {
+static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
+                                             bool shouldBoxResult) {
   auto resultType = funcTy.getResult(0);
-  auto argTy = getResultArgumentType(resultType, options);
+  auto argTy = getResultArgumentType(resultType, shouldBoxResult);
   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
                                  /*resultTypes=*/{});
 }
 
-static bool mustEmboxResult(mlir::Type resultType,
-                            const AbstractResultOptions &options) {
+static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
   return resultType.isa<fir::SequenceType, fir::RecordType>() &&
-         options.boxResult;
+         shouldBoxResult;
 }
 
 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
-  CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt)
-      : OpRewritePattern(context), options{opt} {}
+  CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
+      : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
   mlir::LogicalResult
   matchAndRewrite(fir::CallOp callOp,
                   mlir::PatternRewriter &rewriter) const override {
@@ -88,10 +78,10 @@
           loc, "calls with abstract result must be used in fir.save_result");
       return mlir::failure();
     }
-    auto argType = getResultArgumentType(result.getType(), options);
+    auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
     auto buffer = saveResult.getMemref();
     mlir::Value arg = buffer;
-    if (mustEmboxResult(result.getType(), options))
+    if (mustEmboxResult(result.getType(), shouldBoxResult))
       arg = rewriter.create<fir::EmboxOp>(
           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
           saveResult.getTypeparams());
@@ -126,7 +116,7 @@
   }
 
 private:
-  const AbstractResultOptions &options;
+  bool shouldBoxResult;
 };
 
 class SaveResultOpConversion
@@ -146,9 +136,8 @@
 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
-  ReturnOpConversion(mlir::MLIRContext *context,
-                     const AbstractResultOptions &opt)
-      : OpRewritePattern(context), options{opt} {}
+  ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
+      : OpRewritePattern(context), newArg{newArg} {}
   mlir::LogicalResult
   matchAndRewrite(mlir::func::ReturnOp ret,
                   mlir::PatternRewriter &rewriter) const override {
@@ -158,7 +147,7 @@
     if (auto *op = returnedValue.getDefiningOp())
       if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
         auto resultStorage = load.getMemref();
-        load.getMemref().replaceAllUsesWith(options.newArg);
+        load.getMemref().replaceAllUsesWith(newArg);
         replacedStorage = true;
         if (auto *alloc = resultStorage.getDefiningOp())
           if (alloc->use_empty())
@@ -169,27 +158,25 @@
     // with no length parameters. Simply store the result in the result storage.
     // at the return point.
     if (!replacedStorage)
-      rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue,
-                                    options.newArg);
+      rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, newArg);
     rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
     return mlir::success();
   }
 
 private:
-  const AbstractResultOptions &options;
+  mlir::Value newArg;
 };
 
 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
-  AddrOfOpConversion(mlir::MLIRContext *context,
-                     const AbstractResultOptions &opt)
-      : OpRewritePattern(context), options{opt} {}
+  AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
+      : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
   mlir::LogicalResult
   matchAndRewrite(fir::AddrOfOp addrOf,
                   mlir::PatternRewriter &rewriter) const override {
     auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
-    auto newFuncTy = getNewFunctionType(oldFuncTy, options);
+    auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
                                                     addrOf.getSymbol());
     // Rather than converting all op a function pointer might transit through
@@ -201,7 +188,7 @@
   }
 
 private:
-  const AbstractResultOptions &options;
+  bool shouldBoxResult;
 };
 
 class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
@@ -212,27 +199,25 @@
     auto loc = func.getLoc();
     mlir::RewritePatternSet patterns(context);
     mlir::ConversionTarget target = *context;
-    AbstractResultOptions options{passResultAsBox.getValue(),
-                                  /*newArg=*/{}};
+    const bool shouldBoxResult = passResultAsBox.getValue();
 
     // Convert function type itself if it has an abstract result
     auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
     if (hasAbstractResult(funcTy)) {
-      func.setType(getNewFunctionType(funcTy, options));
+      func.setType(getNewFunctionType(funcTy, shouldBoxResult));
       unsigned zero = 0;
       if (!func.empty()) {
         // Insert new argument
         mlir::OpBuilder rewriter(context);
         auto resultType = funcTy.getResult(0);
-        auto argTy = getResultArgumentType(resultType, options);
-        options.newArg = func.front().insertArgument(zero, argTy, loc);
-        if (mustEmboxResult(resultType, options)) {
+        auto argTy = getResultArgumentType(resultType, shouldBoxResult);
+        mlir::Value newArg = func.front().insertArgument(zero, argTy, loc);
+        if (mustEmboxResult(resultType, shouldBoxResult)) {
           auto bufferType = fir::ReferenceType::get(resultType);
           rewriter.setInsertionPointToStart(&func.front());
-          options.newArg =
-              rewriter.create<fir::BoxAddrOp>(loc, bufferType, options.newArg);
+          newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
         }
-        patterns.insert<ReturnOpConversion>(context, options);
+        patterns.insert<ReturnOpConversion>(context, newArg);
         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
             [](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
       }
@@ -264,9 +249,9 @@
       return true;
     });
 
-    patterns.insert<CallOpConversion>(context, options);
+    patterns.insert<CallOpConversion>(context, shouldBoxResult);
     patterns.insert<SaveResultOpConversion>(context);
-    patterns.insert<AddrOfOpConversion>(context, options);
+    patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
     if (mlir::failed(
             mlir::applyPartialConversion(func, target, std::move(patterns)))) {
       mlir::emitError(func.getLoc(), "error in converting abstract results\n");