diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h --- a/flang/include/flang/Optimizer/Transforms/Passes.h +++ b/flang/include/flang/Optimizer/Transforms/Passes.h @@ -26,6 +26,7 @@ // Passes defined in Passes.td //===----------------------------------------------------------------------===// +std::unique_ptr createAbstractResultOptPass(); std::unique_ptr createAffineDemotionPass(); std::unique_ptr createCharacterConversionPass(); std::unique_ptr createExternalNameConversionPass(); diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // // This file contains definitions for passes within the Optimizer/Transforms/ -// directory. +// directory. // //===----------------------------------------------------------------------===// @@ -16,6 +16,27 @@ include "mlir/Pass/PassBase.td" +def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::FuncOp"> { + let summary = "Convert fir.array, fir.box and fir.type function results to " + "function arguments"; + let description = [{ + Rewrite functions with a fir.array, fir.box or fir.type result, as well + as their call sites and fir.address_of uses, so that the result is passed + through a new leading argument. This pass is required before code + generation to the LLVM IR dialect, including the pre-cg rewrite pass.
+ }]; + let constructor = "::fir::createAbstractResultOptPass()"; + let dependentDialects = [ + "fir::FIROpsDialect", "mlir::StandardOpsDialect" + ]; + let options = [ + Option<"passResultAsBox", "abstract-result-as-box", + "bool", /*default=*/"false", + "Pass fir.array result as fir.box> argument instead" + " of fir.ref>."> + ]; +} + def AffineDialectPromotion : FunctionPass<"promote-to-affine"> { let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`."; let description = [{ diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -623,7 +623,9 @@ llvm::ArrayRef results, mlir::ValueRange operands) { result.addOperands(operands); - result.addAttribute(getCalleeAttrName(), callee); + // Indirect calls have no callee symbol; only set the attribute if present. + if (callee) + result.addAttribute(getCalleeAttrName(), callee); result.addTypes(results); } diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -0,0 +1,288 @@ +//===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/TypeSwitch.h" + +#define DEBUG_TYPE "flang-abstract-result-opt" + +namespace fir { +namespace { + +struct AbstractResultOptions { + // When set, pass array and derived-type results as fir.box, not fir.ref. + bool boxResult = false; + // New function block argument for the result if the current FuncOp had + // an abstract result. + mlir::Value newArg; +}; + +static bool mustConvertCallOrFunc(mlir::FunctionType type) { + if (type.getNumResults() == 0) + return false; + auto resultType = type.getResult(0); + return resultType.isa(); +} + +static mlir::Type getResultArgumentType(mlir::Type resultType, + const AbstractResultOptions &options) { + return llvm::TypeSwitch(resultType) + .Case( + [&](mlir::Type type) -> mlir::Type { + if (options.boxResult) + return fir::BoxType::get(type); + return fir::ReferenceType::get(type); + }) + .Case([](mlir::Type type) -> mlir::Type { + return fir::ReferenceType::get(type); + }) + .Default([](mlir::Type) -> mlir::Type { + llvm_unreachable("bad abstract result type"); + }); +} + +static mlir::FunctionType +getNewFunctionType(mlir::FunctionType funcTy, + const AbstractResultOptions &options) { + auto resultType = funcTy.getResult(0); + auto argTy = getResultArgumentType(resultType, options); + llvm::SmallVector newInputTypes = {argTy}; + newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); + return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
+ /*resultTypes=*/{}); +} + +static bool mustEmboxResult(mlir::Type resultType, + const AbstractResultOptions &options) { + return resultType.isa() && + options.boxResult; +} + +class CallOpConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt) + : OpRewritePattern(context), options{opt} {} + mlir::LogicalResult + matchAndRewrite(fir::CallOp callOp, + mlir::PatternRewriter &rewriter) const override { + auto loc = callOp.getLoc(); + auto result = callOp->getResult(0); + if (!result.hasOneUse()) { + mlir::emitError(loc, + "calls with abstract result must have exactly one user"); + return mlir::failure(); + } + auto saveResult = + mlir::dyn_cast(result.use_begin().getUser()); + if (!saveResult) { + mlir::emitError( + loc, "calls with abstract result must be used in fir.save_result"); + return mlir::failure(); + } + auto argType = getResultArgumentType(result.getType(), options); + auto buffer = saveResult.memref(); + mlir::Value arg = buffer; + if (mustEmboxResult(result.getType(), options)) + arg = rewriter.create( + loc, argType, buffer, saveResult.shape(), /*slice*/ mlir::Value{}, + saveResult.typeparams()); + + llvm::SmallVector newResultTypes; + if (callOp.callee()) { + llvm::SmallVector newOperands = {arg}; + newOperands.append(callOp.getOperands().begin(), + callOp.getOperands().end()); + rewriter.create(loc, callOp.callee().getValue(), + newResultTypes, newOperands); + } else { + // Indirect call: cast the callee (operand 0) to the new function type.
+ llvm::SmallVector newInputTypes = {argType}; + for (auto operand : callOp.getOperands().drop_front()) + newInputTypes.push_back(operand.getType()); + auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes, + newResultTypes); + + llvm::SmallVector newOperands; + newOperands.push_back( + rewriter.create(loc, funTy, callOp.getOperand(0))); + newOperands.push_back(arg); + newOperands.append(callOp.getOperands().begin() + 1, + callOp.getOperands().end()); + rewriter.create(loc, mlir::SymbolRefAttr{}, newResultTypes, + newOperands); + } + callOp->dropAllReferences(); + rewriter.eraseOp(callOp); + return mlir::success(); + } + +private: + const AbstractResultOptions &options; +}; + +class SaveResultOpConversion + : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + SaveResultOpConversion(mlir::MLIRContext *context) + : OpRewritePattern(context) {} + mlir::LogicalResult + matchAndRewrite(fir::SaveResultOp op, + mlir::PatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +class ReturnOpConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + ReturnOpConversion(mlir::MLIRContext *context, + const AbstractResultOptions &opt) + : OpRewritePattern(context), options{opt} {} + mlir::LogicalResult + matchAndRewrite(mlir::ReturnOp ret, + mlir::PatternRewriter &rewriter) const override { + rewriter.setInsertionPoint(ret); + auto returnedValue = ret.getOperand(0); + bool replacedStorage = false; + if (auto *op = returnedValue.getDefiningOp()) + if (auto load = mlir::dyn_cast(op)) { + auto resultStorage = load.memref(); + // A prior return may already have redirected it; never RAUW to self. + if (resultStorage != options.newArg) + resultStorage.replaceAllUsesWith(options.newArg); + replacedStorage = true; + if (auto *alloc = resultStorage.getDefiningOp()) + if (alloc->use_empty()) + rewriter.eraseOp(alloc); + } + // Otherwise the storage was likely promoted to a register (fir.box, or
+ // fir.record without length parameters): store into the result argument. + if (!replacedStorage) + rewriter.create(ret.getLoc(), returnedValue, + options.newArg); + rewriter.replaceOpWithNewOp(ret); + return mlir::success(); + } + +private: + const AbstractResultOptions &options; +}; + +class AddrOfOpConversion : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + AddrOfOpConversion(mlir::MLIRContext *context, + const AbstractResultOptions &opt) + : OpRewritePattern(context), options{opt} {} + mlir::LogicalResult + matchAndRewrite(fir::AddrOfOp addrOf, + mlir::PatternRewriter &rewriter) const override { + auto oldFuncTy = addrOf.getType().cast(); + auto newFuncTy = getNewFunctionType(oldFuncTy, options); + auto newAddrOf = rewriter.create(addrOf.getLoc(), newFuncTy, + addrOf.symbol()); + // Rather than converting every op a function pointer may transit through + // (e.g. calls, stores, loads, converts...), cast the new type back to the + // abstract type. A conversion is added when lowering indirect calls with
+ // an abstract result. + rewriter.replaceOpWithNewOp(addrOf, oldFuncTy, newAddrOf); + return mlir::success(); + } + +private: + const AbstractResultOptions &options; +}; + +class AbstractResultOpt : public fir::AbstractResultOptBase { +public: + void runOnOperation() override { + auto *context = &getContext(); + auto func = getOperation(); + auto loc = func.getLoc(); + mlir::OwningRewritePatternList patterns(context); + mlir::ConversionTarget target = *context; + AbstractResultOptions options{passResultAsBox.getValue(), + /*newArg=*/{}}; + + // Convert function type itself if it has an abstract result + auto funcTy = func.getType().cast(); + if (mustConvertCallOrFunc(funcTy)) { + func.setType(getNewFunctionType(funcTy, options)); + unsigned zero = 0; + if (!func.empty()) { + // Insert new argument + mlir::OpBuilder rewriter(context); + auto resultType = funcTy.getResult(0); + auto argTy = getResultArgumentType(resultType, options); + options.newArg = func.front().insertArgument(zero, argTy); + if (mustEmboxResult(resultType, options)) { + auto bufferType = fir::ReferenceType::get(resultType); + rewriter.setInsertionPointToStart(&func.front()); + options.newArg = + rewriter.create(loc, bufferType, options.newArg); + } + patterns.insert(context, options); + target.addDynamicallyLegalOp( + [](mlir::ReturnOp ret) { return ret.operands().empty(); }); + } + } + + if (func.empty()) + return; + + // Convert the calls and, if needed, the ReturnOp in the function body.
+ target.addLegalDialect(); + target.addIllegalOp(); + target.addDynamicallyLegalOp([](fir::CallOp call) { + return !mustConvertCallOrFunc(call.getFunctionType()); + }); + target.addDynamicallyLegalOp([](fir::AddrOfOp addrOf) { + if (auto funTy = addrOf.getType().dyn_cast()) + return !mustConvertCallOrFunc(funTy); + return true; + }); + target.addDynamicallyLegalOp([](fir::DispatchOp dispatch) { + if (dispatch->getNumResults() != 1) + return true; + auto resultType = dispatch->getResult(0).getType(); + if (resultType.isa()) { + mlir::emitError(dispatch.getLoc(), + "TODO: dispatchOp with abstract results"); + return false; + } + return true; + }); + + patterns.insert(context, options); + patterns.insert(context); + patterns.insert(context, options); + if (mlir::failed( + mlir::applyPartialConversion(func, target, std::move(patterns)))) { + mlir::emitError(func.getLoc(), "error in converting abstract results"); + signalPassFailure(); + } + } +}; +} // end anonymous namespace +} // namespace fir + +std::unique_ptr fir::createAbstractResultOptPass() { + return std::make_unique(); +} diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_flang_library(FIRTransforms + AbstractResult.cpp AffinePromotion.cpp AffineDemotion.cpp CharacterConversion.cpp diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir new file mode 100644 --- /dev/null +++ b/flang/test/Fir/abstract-results.fir @@ -0,0 +1,255 @@ +// Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to +// functions that take an additional argument for the result.
+ +// RUN: fir-opt %s --abstract-result-opt | FileCheck %s +// RUN: fir-opt %s --abstract-result-opt=abstract-result-as-box | FileCheck %s --check-prefix=CHECK-BOX + +// ----------------------- Test declaration rewrite ---------------------------- + +// CHECK-LABEL: func private @arrayfunc(!fir.ref>, i32) +// CHECK-BOX-LABEL: func private @arrayfunc(!fir.box>, i32) +func private @arrayfunc(i32) -> !fir.array + +// CHECK-LABEL: func private @derivedfunc(!fir.ref>, f32) +// CHECK-BOX-LABEL: func private @derivedfunc(!fir.box>, f32) +func private @derivedfunc(f32) -> !fir.type + +// CHECK-LABEL: func private @boxfunc(!fir.ref>>, i64) +// CHECK-BOX-LABEL: func private @boxfunc(!fir.ref>>, i64) +func private @boxfunc(i64) -> !fir.box> + + +// ------------------------ Test callee rewrite -------------------------------- + +// CHECK-LABEL: func private @arrayfunc_callee( +// CHECK-SAME: %[[buffer:.*]]: !fir.ref>, %[[n:.*]]: index) { +// CHECK-BOX-LABEL: func private @arrayfunc_callee( +// CHECK-BOX-SAME: %[[box:.*]]: !fir.box>, %[[n:.*]]: index) { +func private @arrayfunc_callee(%n : index) -> !fir.array { + %buffer = fir.alloca !fir.array, %n + // Write into the result storage (res(4) = 42.) before returning it.
+ %c4 = constant 4 : i64 + %coor = fir.coordinate_of %buffer, %c4 : (!fir.ref>, i64) -> !fir.ref + %cst = constant 4.200000e+01 : f32 + fir.store %cst to %coor : !fir.ref + %res = fir.load %buffer : !fir.ref> + return %res : !fir.array + + // CHECK-DAG: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref>, i64) -> !fir.ref + // CHECK-DAG: fir.store %{{.*}} to %[[coor]] : !fir.ref + // CHECK: return + + // CHECK-BOX: %[[buffer:.*]] = fir.box_addr %[[box]] : (!fir.box>) -> !fir.ref> + // CHECK-BOX-DAG: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref>, i64) -> !fir.ref + // CHECK-BOX-DAG: fir.store %{{.*}} to %[[coor]] : !fir.ref + // CHECK-BOX: return +} + + +// CHECK-LABEL: func @derivedfunc_callee( +// CHECK-SAME: %[[buffer:.*]]: !fir.ref>, %[[v:.*]]: f32) { +// CHECK-BOX-LABEL: func @derivedfunc_callee( +// CHECK-BOX-SAME: %[[box:.*]]: !fir.box>, %[[v:.*]]: f32) { +func @derivedfunc_callee(%v: f32) -> !fir.type { + %buffer = fir.alloca !fir.type + %0 = fir.field_index x, !fir.type + %1 = fir.coordinate_of %buffer, %0 : (!fir.ref>, !fir.field) -> !fir.ref + fir.store %v to %1 : !fir.ref + %res = fir.load %buffer : !fir.ref> + return %res : !fir.type + + // CHECK: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref>, !fir.field) -> !fir.ref + // CHECK: fir.store %[[v]] to %[[coor]] : !fir.ref + // CHECK: return + + // CHECK-BOX: %[[buffer:.*]] = fir.box_addr %[[box]] : (!fir.box>) -> !fir.ref> + // CHECK-BOX: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref>, !fir.field) -> !fir.ref + // CHECK-BOX: fir.store %[[v]] to %[[coor]] : !fir.ref + // CHECK-BOX: return +} + +// CHECK-LABEL: func @boxfunc_callee( +// CHECK-SAME: %[[buffer:.*]]: !fir.ref>>) { +// CHECK-BOX-LABEL: func @boxfunc_callee( +// CHECK-BOX-SAME: %[[buffer:.*]]: !fir.ref>>) { +func @boxfunc_callee() -> !fir.box> { + %alloc = fir.allocmem f64 + %res = fir.embox %alloc : (!fir.heap) -> !fir.box> + return %res : !fir.box> + // CHECK: 
%[[box:.*]] = fir.embox %{{.*}} : (!fir.heap) -> !fir.box> + // CHECK: fir.store %[[box]] to %[[buffer]] : !fir.ref>> + // CHECK: return + + // CHECK-BOX: %[[box:.*]] = fir.embox %{{.*}} : (!fir.heap) -> !fir.box> + // CHECK-BOX: fir.store %[[box]] to %[[buffer]] : !fir.ref>> + // CHECK-BOX: return +} + +// ------------------------ Test caller rewrite -------------------------------- + +// CHECK-LABEL: func @call_arrayfunc() { +// CHECK-BOX-LABEL: func @call_arrayfunc() { +func @call_arrayfunc() { + %c100 = constant 100 : index + %buffer = fir.alloca !fir.array, %c100 + %shape = fir.shape %c100 : (index) -> !fir.shape<1> + %res = fir.call @arrayfunc_callee(%c100) : (index) -> !fir.array + fir.save_result %res to %buffer(%shape) : !fir.array, !fir.ref>, !fir.shape<1> + return + + // CHECK: %[[c100:.*]] = constant 100 : index + // CHECK: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] + // CHECK: fir.call @arrayfunc_callee(%[[buffer]], %[[c100]]) : (!fir.ref>, index) -> () + // CHECK-NOT: fir.save_result + + // CHECK-BOX: %[[c100:.*]] = constant 100 : index + // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] + // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> + // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]](%[[shape]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> + // CHECK-BOX: fir.call @arrayfunc_callee(%[[box]], %[[c100]]) : (!fir.box>, index) -> () + // CHECK-BOX-NOT: fir.save_result +} + +// CHECK-LABEL: func @call_derivedfunc() { +// CHECK-BOX-LABEL: func @call_derivedfunc() { +func @call_derivedfunc() { + %buffer = fir.alloca !fir.type + %cst = constant 4.200000e+01 : f32 + %res = fir.call @derivedfunc_callee(%cst) : (f32) -> !fir.type + fir.save_result %res to %buffer : !fir.type, !fir.ref> + return + // CHECK: %[[buffer:.*]] = fir.alloca !fir.type + // CHECK: %[[cst:.*]] = constant {{.*}} : f32 + // CHECK: fir.call @derivedfunc_callee(%[[buffer]], %[[cst]]) : (!fir.ref>, f32) -> () + // CHECK-NOT: fir.save_result
+ + // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.type + // CHECK-BOX: %[[cst:.*]] = constant {{.*}} : f32 + // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]] : (!fir.ref>) -> !fir.box> + // CHECK-BOX: fir.call @derivedfunc_callee(%[[box]], %[[cst]]) : (!fir.box>, f32) -> () + // CHECK-BOX-NOT: fir.save_result +} + +func private @derived_lparams_func() -> !fir.type + +// CHECK-LABEL: func @call_derived_lparams_func( +// CHECK-SAME: %[[buffer:.*]]: !fir.ref> +// CHECK-BOX-LABEL: func @call_derived_lparams_func( +// CHECK-BOX-SAME: %[[buffer:.*]]: !fir.ref> +func @call_derived_lparams_func(%buffer: !fir.ref>) { + %l1 = constant 3 : i32 + %l2 = constant 5 : i32 + %res = fir.call @derived_lparams_func() : () -> !fir.type + fir.save_result %res to %buffer typeparams %l1, %l2 : !fir.type, !fir.ref>, i32, i32 + return + + // CHECK: %[[l1:.*]] = constant 3 : i32 + // CHECK: %[[l2:.*]] = constant 5 : i32 + // CHECK: fir.call @derived_lparams_func(%[[buffer]]) : (!fir.ref>) -> () + // CHECK-NOT: fir.save_result + + // CHECK-BOX: %[[l1:.*]] = constant 3 : i32 + // CHECK-BOX: %[[l2:.*]] = constant 5 : i32 + // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]] typeparams %[[l1]], %[[l2]] : (!fir.ref>, i32, i32) -> !fir.box> + // CHECK-BOX: fir.call @derived_lparams_func(%[[box]]) : (!fir.box>) -> () + // CHECK-BOX-NOT: fir.save_result +} + +// CHECK-LABEL: func @call_boxfunc() { +// CHECK-BOX-LABEL: func @call_boxfunc() { +func @call_boxfunc() { + %buffer = fir.alloca !fir.box> + %res = fir.call @boxfunc_callee() : () -> !fir.box> + fir.save_result %res to %buffer: !fir.box>, !fir.ref>> + return + + // CHECK: %[[buffer:.*]] = fir.alloca !fir.box> + // CHECK: fir.call @boxfunc_callee(%[[buffer]]) : (!fir.ref>>) -> () + // CHECK-NOT: fir.save_result + + // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.box> + // CHECK-BOX: fir.call @boxfunc_callee(%[[buffer]]) : (!fir.ref>>) -> () + // CHECK-BOX-NOT: fir.save_result +} + +func private @chararrayfunc(index, index) -> !fir.array> + +//
CHECK-LABEL: func @call_chararrayfunc() { +// CHECK-BOX-LABEL: func @call_chararrayfunc() { +func @call_chararrayfunc() { + %c100 = constant 100 : index + %c50 = constant 50 : index + %buffer = fir.alloca !fir.array>(%c100 : index), %c50 + %shape = fir.shape %c100 : (index) -> !fir.shape<1> + %res = fir.call @chararrayfunc(%c100, %c50) : (index, index) -> !fir.array> + fir.save_result %res to %buffer(%shape) typeparams %c50 : !fir.array>, !fir.ref>>, !fir.shape<1>, index + return + + // CHECK: %[[c100:.*]] = constant 100 : index + // CHECK: %[[c50:.*]] = constant 50 : index + // CHECK: %[[buffer:.*]] = fir.alloca !fir.array>(%[[c100]] : index), %[[c50]] + // CHECK: fir.call @chararrayfunc(%[[buffer]], %[[c100]], %[[c50]]) : (!fir.ref>>, index, index) -> () + // CHECK-NOT: fir.save_result + + // CHECK-BOX: %[[c100:.*]] = constant 100 : index + // CHECK-BOX: %[[c50:.*]] = constant 50 : index + // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array>(%[[c100]] : index), %[[c50]] + // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> + // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]](%[[shape]]) typeparams %[[c50]] : (!fir.ref>>, !fir.shape<1>, index) -> !fir.box>> + // CHECK-BOX: fir.call @chararrayfunc(%[[box]], %[[c100]], %[[c50]]) : (!fir.box>>, index, index) -> () + // CHECK-BOX-NOT: fir.save_result +} + +// ------------------------ Test fir.address_of rewrite ------------------------ + +func private @takesfuncarray((i32) -> !fir.array) + +// CHECK-LABEL: func @test_address_of() { +// CHECK-BOX-LABEL: func @test_address_of() { +func @test_address_of() { + %0 = fir.address_of(@arrayfunc) : (i32) -> !fir.array + fir.call @takesfuncarray(%0) : ((i32) -> !fir.array) -> () + return + + // CHECK: %[[addrOf:.*]] = fir.address_of(@arrayfunc) : (!fir.ref>, i32) -> () + // CHECK: %[[conv:.*]] = fir.convert %[[addrOf]] : ((!fir.ref>, i32) -> ()) -> ((i32) -> !fir.array) + // CHECK: fir.call @takesfuncarray(%[[conv]]) : ((i32) -> !fir.array) -> () + + //
CHECK-BOX: %[[addrOf:.*]] = fir.address_of(@arrayfunc) : (!fir.box>, i32) -> () + // CHECK-BOX: %[[conv:.*]] = fir.convert %[[addrOf]] : ((!fir.box>, i32) -> ()) -> ((i32) -> !fir.array) + // CHECK-BOX: fir.call @takesfuncarray(%[[conv]]) : ((i32) -> !fir.array) -> () + +} + +// ----------------------- Test indirect calls rewrite ------------------------ + +// CHECK-LABEL: func @test_indirect_calls( +// CHECK-SAME: %[[arg0:.*]]: () -> ()) { +// CHECK-BOX-LABEL: func @test_indirect_calls( +// CHECK-BOX-SAME: %[[arg0:.*]]: () -> ()) { +func @test_indirect_calls(%arg0: () -> ()) { + %c100 = constant 100 : index + %buffer = fir.alloca !fir.array, %c100 + %shape = fir.shape %c100 : (index) -> !fir.shape<1> + %0 = fir.convert %arg0 : (() -> ()) -> ((index) -> !fir.array) + %res = fir.call %0(%c100) : (index) -> !fir.array + fir.save_result %res to %buffer(%shape) : !fir.array, !fir.ref>, !fir.shape<1> + return + + // CHECK: %[[c100:.*]] = constant 100 : index + // CHECK: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] + // CHECK: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> + // CHECK: %[[original_conv:.*]] = fir.convert %[[arg0]] : (() -> ()) -> ((index) -> !fir.array) + // CHECK: %[[conv:.*]] = fir.convert %[[original_conv]] : ((index) -> !fir.array) -> ((!fir.ref>, index) -> ()) + // CHECK: fir.call %[[conv]](%[[buffer]], %[[c100]]) : (!fir.ref>, index) -> () + // CHECK-NOT: fir.save_result + + // CHECK-BOX: %[[c100:.*]] = constant 100 : index + // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array, %[[c100]] + // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1> + // CHECK-BOX: %[[original_conv:.*]] = fir.convert %[[arg0]] : (() -> ()) -> ((index) -> !fir.array) + // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]](%[[shape]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> + // CHECK-BOX: %[[conv:.*]] = fir.convert %[[original_conv]] : ((index) -> !fir.array) -> ((!fir.box>, index) -> ()) + // CHECK-BOX: fir.call %[[conv]](%[[box]], 
%[[c100]]) : (!fir.box>, index) -> () + // CHECK-BOX-NOT: fir.save_result +}