diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -60,6 +60,7 @@
 std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
 std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
 std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
+std::unique_ptr<mlir::Pass> createConstExtruderPass();
 std::unique_ptr<mlir::Pass> createStackArraysPass();
 std::unique_ptr<mlir::Pass> createAliasTagsPass();
 std::unique_ptr<mlir::Pass> createSimplifyIntrinsicsPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -242,6 +242,15 @@
   let constructor = "::fir::createMemoryAllocationPass()";
 }
 
+def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::func::FuncOp"> {
+  let summary = "Convert scalar literals of function arguments to global constants.";
+  let description = [{
+    Convert scalar literals of function arguments to global constants.
+  }];
+  let dependentDialects = [ "fir::FIROpsDialect" ];
+  let constructor = "::fir::createConstExtruderPass()";
+}
+
 def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
   let summary = "Move local array allocations from heap memory into stack memory";
   let description = [{
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -216,6 +216,8 @@
   else
     fir::addMemoryAllocationOpt(pm);
 
+  pm.addPass(fir::createConstExtruderPass());
+
   // The default inliner pass adds the canonicalizer pass with the default
   // configuration. Create the inliner pass with tco config.
   llvm::StringMap<mlir::OpPassManager> pipelines;
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@
   AffineDemotion.cpp
   AnnotateConstant.cpp
   CharacterConversion.cpp
+  ConstExtruder.cpp
   ControlFlowConverter.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -0,0 +1,197 @@
+//===- ConstExtruder.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include <atomic>
+#include <string>
+
+namespace fir {
+#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-const-extruder-opt"
+
+namespace {
+/// Monotonic counter used to derive unique names for the extruded globals.
+// NOTE(review): this is a function pass that adds fir.global operations to
+// the parent module. MLIR may run function passes on several functions
+// concurrently, so please confirm that inserting into the module from here
+// is acceptable, or turn this into a module pass.
+std::atomic<int> uniqueLitId{1};
+
+/// Return the fir.store that initialises the by-reference temporary \p arg of
+/// \p call with a literal, or a null op when \p arg cannot be replaced by the
+/// address of a global constant.
+///
+/// \p arg qualifies only when all of the following hold:
+///  - it is produced by a statically sized fir.alloca carrying the
+///    "adapt.valuebyref" attribute;
+///  - the alloca has exactly two uses: one fir.store writing into it and one
+///    operand of \p call, so that erasing it leaves no dangling use;
+///  - the stored value is defined by an arith.constant;
+///  - the store is in the same block as \p call and precedes it.
+fir::StoreOp getExtrudableStore(mlir::Value arg, fir::CallOp call) {
+  auto alloca = arg.getDefiningOp<fir::AllocaOp>();
+  if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName()) ||
+      fir::hasDynamicSize(alloca.getInType()))
+    return nullptr;
+
+  fir::StoreOp initStore;
+  unsigned numCallUses = 0;
+  for (mlir::Operation *user : alloca->getUsers()) {
+    if (user == call.getOperation()) {
+      ++numCallUses;
+      continue;
+    }
+    auto store = mlir::dyn_cast<fir::StoreOp>(user);
+    // Bail out on any other kind of user, on a second store, and on a store
+    // that writes the address itself somewhere else.
+    if (!store || initStore || store.getMemref() != arg)
+      return nullptr;
+    initStore = store;
+  }
+  if (!initStore || numCallUses != 1)
+    return nullptr;
+
+  // getDefiningOp<T>() yields null for block arguments and for other ops.
+  if (!initStore.getValue().getDefiningOp<mlir::arith::ConstantOp>())
+    return nullptr;
+
+  // The literal must have been stored before the call executes.
+  if (initStore->getBlock() != call->getBlock() ||
+      !initStore->isBeforeInBlock(call.getOperation()))
+    return nullptr;
+  return initStore;
+}
+
+/// Rewrite a fir.call so that every by-reference temporary initialised with a
+/// literal is replaced by the address of an internal global constant.
+class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::CallOp callOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto module = callOp->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, module);
+    mlir::Location loc = callOp.getLoc();
+
+    llvm::SmallVector<mlir::Value> newOperands;
+    llvm::SmallVector<mlir::Operation *> toErase;
+    for (mlir::Value arg : callOp.getArgs()) {
+      fir::StoreOp store = getExtrudableStore(arg, callOp);
+      if (!store) {
+        // Not a literal passed through a temporary: keep the operand.
+        newOperands.push_back(arg);
+        continue;
+      }
+      LLVM_DEBUG(llvm::dbgs() << " found store " << store << "\n");
+
+      auto alloca = arg.getDefiningOp<fir::AllocaOp>();
+      mlir::Type varTy = alloca.getInType();
+      mlir::Operation *constantDef = store.getValue().getDefiningOp();
+      LLVM_DEBUG(llvm::dbgs() << " found define " << *constantDef << "\n");
+
+      // Pick a symbol name that is not already taken in the module.
+      std::string globalName;
+      do {
+        globalName = "_extruded_." + std::to_string(uniqueLitId++);
+      } while (builder.getNamedGlobal(globalName));
+
+      fir::GlobalOp global = builder.createGlobalConstant(
+          loc, varTy, globalName,
+          [&](fir::FirOpBuilder &initBuilder) {
+            // The global body needs its own copy of the constant.
+            mlir::Operation *cln = constantDef->clone();
+            initBuilder.insert(cln);
+            mlir::Value val =
+                initBuilder.createConvert(loc, varTy, cln->getResult(0));
+            initBuilder.create<fir::HasValueOp>(loc, val);
+          },
+          builder.createInternalLinkage());
+      mlir::Value addr = builder.create<fir::AddrOfOp>(
+          loc, global.resultType(), global.getSymbol());
+      newOperands.push_back(addr);
+
+      toErase.push_back(store);
+      toErase.push_back(alloca);
+    }
+
+    // Nothing was extruded, so no IR has been touched.
+    if (toErase.empty())
+      return mlir::failure();
+
+    llvm::SmallVector<mlir::Type> resultTypes(
+        callOp.getResultTypes().begin(), callOp.getResultTypes().end());
+    auto newOp = builder.create<fir::CallOp>(
+        loc, resultTypes, callOp.getCalleeAttr(), newOperands,
+        callOp.getFastmathAttr());
+    LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+                            << newOp << '\n');
+
+    rewriter.replaceOp(callOp, newOp->getResults());
+    for (mlir::Operation *op : toErase)
+      rewriter.eraseOp(op);
+    return mlir::success();
+  }
+};
+
+/// Function pass that turns scalar literals passed by reference to fir.call
+/// into internal global constants, to allow transformations such as dead
+/// argument elimination.
+class ConstExtruderOpt
+    : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+public:
+  void runOnOperation() override {
+    mlir::func::FuncOp func = getOperation();
+    // Declarations have no body, hence no call to rewrite.
+    if (func.empty())
+      return;
+
+    mlir::MLIRContext *context = &getContext();
+    mlir::ConversionTarget target(*context);
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect>();
+    // A call stays illegal as long as one of its operands can be extruded.
+    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp op) {
+      for (mlir::Value arg : op.getArgs())
+        if (getExtrudableStore(arg, op))
+          return false;
+      return true;
+    });
+
+    mlir::RewritePatternSet patterns(context);
+    patterns.insert<CallOpRewriter>(context);
+    if (mlir::failed(
+            mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+      mlir::emitError(func.getLoc(),
+                      "error in constant extrusion optimization");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
+  return std::make_unique<ConstExtruderOpt>();
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -31,6 +31,7 @@
 
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT: MemoryAllocationOpt
+! CHECK-NEXT: ConstExtruderOpt
 
 ! CHECK-NEXT: Inliner
 ! CHECK-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -51,6 +51,7 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT: MemoryAllocationOpt
+! ALL-NEXT: ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -42,6 +42,7 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT: MemoryAllocationOpt
+! ALL-NEXT: ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -48,6 +48,7 @@
 
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT: MemoryAllocationOpt
+// PASSES-NEXT: ConstExtruderOpt
 
 // PASSES-NEXT: Inliner
 // PASSES-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -16,9 +16,7 @@
 
 // CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
 // CHECK-SAME: %[[VAL_0:.*]])
-// CHECK: %[[VAL_1:.*]] = alloca i32, i64 1, align 4
-// CHECK: store i32 4, ptr %[[VAL_1]], align 4
-// CHECK: call void %[[VAL_0]](ptr %[[VAL_1]])
+// CHECK: call void %[[VAL_0]](ptr @{{.*}})
 
 func.func @_QPtest_proc_dummy() {
   %c0_i32 = arith.constant 0 : i32
diff --git a/flang/test/Lower/character-local-variables.f90 b/flang/test/Lower/character-local-variables.f90
--- a/flang/test/Lower/character-local-variables.f90
+++ b/flang/test/Lower/character-local-variables.f90
@@ -116,8 +116,7 @@
 subroutine assumed_length_param(n)
   character(*), parameter :: c(1)=(/"abcd"/)
   integer :: n
-  ! CHECK: %[[c4:.*]] = arith.constant 4 : i64
-  ! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
+  ! CHECK: %[[tmp:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i64>
   ! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
   call take_int(len(c(n), kind=8))
 end
diff --git a/flang/test/Lower/dummy-arguments.f90 b/flang/test/Lower/dummy-arguments.f90
--- a/flang/test/Lower/dummy-arguments.f90
+++ b/flang/test/Lower/dummy-arguments.f90
@@ -2,9 +2,7 @@
 
 ! CHECK-LABEL: _QQmain
 program test1
-  ! CHECK-DAG: %[[TMP:.*]] = fir.alloca
-  ! CHECK-DAG: %[[TEN:.*]] = arith.constant
-  ! CHECK: fir.store %[[TEN]] to %[[TMP]]
+  ! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
   ! CHECK-NEXT: fir.call @_QFPfoo
   call foo(10)
 contains
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -448,11 +448,10 @@
 
 ! CHECK-LABEL: func @_QPtest_proc_dummy_other(
 ! CHECK-SAME: %[[VAL_0:.*]]: !fir.boxproc<() -> ()>) {
-! CHECK: %[[VAL_1:.*]] = arith.constant 4 : i32
-! CHECK: %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref}
-! CHECK: fir.store %[[VAL_1]] to %[[VAL_2]] : !fir.ref<i32>
 ! CHECK: %[[VAL_3:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i32>) -> ())
-! CHECK: fir.call %[[VAL_3]](%[[VAL_2]]) {{.*}}: (!fir.ref<i32>) -> ()
+! CHECK: %[[VAL_1:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
+! CHECK: fir.call %[[VAL_3]](%[[VAL_1]]) {{.*}}: (!fir.ref<i32>) -> ()
+
 ! CHECK: return
 ! CHECK: }
 