diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -60,6 +60,7 @@
 std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
 std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
 std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
+std::unique_ptr<mlir::Pass> createConstExtruderPass();
 std::unique_ptr<mlir::Pass> createStackArraysPass();
 std::unique_ptr<mlir::Pass> createSimplifyIntrinsicsPass();
 std::unique_ptr<mlir::Pass> createAddDebugFoundationPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -242,6 +242,20 @@
   let constructor = "::fir::createMemoryAllocationPass()";
 }
 
+def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
+  let summary = "Convert scalar literals of function arguments to global constants.";
+  let description = [{
+    Replaces call arguments that are by-value temporaries (a `fir.alloca`
+    carrying the adapt-to-by-ref attribute, written by exactly one
+    `fir.store` of an `arith.constant`) with the address of an internal
+    `fir.global` constant holding the same value, so that later passes such
+    as dead argument elimination can see the literal. This is a module pass
+    because it creates new globals.
+  }];
+  let dependentDialects = [ "fir::FIROpsDialect" ];
+  let constructor = "::fir::createConstExtruderPass()";
+}
+
 def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
   let summary = "Move local array allocations from heap memory into stack memory";
   let description = [{
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -252,6 +252,7 @@
   pm.addPass(hlfir::createLowerHLFIRIntrinsicsPass());
   pm.addPass(hlfir::createBufferizeHLFIRPass());
   pm.addPass(hlfir::createConvertHLFIRtoFIRPass());
+  pm.addPass(fir::createConstExtruderPass());
 }
 
 #if !defined(FLANG_EXCLUDE_CODEGEN)
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   AffineDemotion.cpp
   AnnotateConstant.cpp
   CharacterConversion.cpp
+  ConstExtruder.cpp
   ControlFlowConverter.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -0,0 +1,215 @@
+//===- ConstExtruder.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Replaces call arguments that are by-value temporaries (a fir.alloca
+// carrying the adapt-to-by-ref attribute, written by exactly one fir.store
+// of an arith.constant) with the address of an internal fir.global constant
+// holding the same value, so that later transformations such as dead
+// argument elimination can see the literal.
+//
+// This is a module pass because it creates fir.global operations at module
+// scope: a nested func.func pass must not modify its parent module, since
+// sibling functions may be processed concurrently.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include <string>
+
+namespace fir {
+#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-const-extruder-opt"
+
+/// Returns the fir.store initializing \p arg when \p arg can be replaced by a
+/// global constant, or a null StoreOp otherwise.
+///
+/// \p arg qualifies when it is produced by a fir.alloca carrying the
+/// adapt-to-by-ref attribute, the alloca is written by exactly one fir.store,
+/// its address is not itself stored anywhere, and the stored value is
+/// defined by an arith.constant.
+static fir::StoreOp getExtrudableStore(mlir::Value arg) {
+  auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(arg.getDefiningOp());
+  if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName()))
+    return {};
+
+  fir::StoreOp init;
+  for (mlir::Operation *user : alloca->getUsers()) {
+    auto store = mlir::dyn_cast<fir::StoreOp>(user);
+    if (!store)
+      continue;
+    // Storing the address itself somewhere lets the temporary escape.
+    if (store.getMemref() != arg)
+      return {};
+    // With several stores the call does not see a single known constant.
+    if (init)
+      return {};
+    init = store;
+  }
+  if (!init)
+    return {};
+
+  // The stored value may be a block argument, which has no defining op.
+  if (!mlir::isa_and_nonnull<mlir::arith::ConstantOp>(
+          init.getValue().getDefiningOp()))
+    return {};
+  return init;
+}
+
+namespace {
+
+/// Rewrites a fir.call so that every argument accepted by getExtrudableStore
+/// is replaced with the address of a new internal fir.global constant. The
+/// temporary alloca and its initializing store are erased when they have no
+/// users other than each other and this call.
+class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
+public:
+  /// \p nextLitId numbers the generated globals. It is owned by the pass run
+  /// so that generated names are deterministic for a given module.
+  CallOpRewriter(mlir::MLIRContext *ctx, unsigned &nextLitId)
+      : OpRewritePattern(ctx), nextLitId(nextLitId) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::CallOp callOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto module = callOp->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, module);
+    mlir::Location loc = callOp.getLoc();
+
+    llvm::SmallVector<mlir::Value> newOperands;
+    llvm::SmallVector<mlir::Operation *> toErase;
+    bool changed = false;
+    for (mlir::Value arg : callOp.getArgs()) {
+      fir::StoreOp store = getExtrudableStore(arg);
+      if (!store) {
+        newOperands.push_back(arg);
+        continue;
+      }
+
+      auto alloca = arg.getDefiningOp<fir::AllocaOp>();
+      mlir::Type varTy = alloca.getInType();
+      assert(!fir::hasDynamicSize(varTy) &&
+             "only expect statically sized scalars to be by value");
+      mlir::Operation *constantDef = store.getValue().getDefiningOp();
+      LLVM_DEBUG(llvm::dbgs() << " found store " << store << "\n"
+                              << " found define " << *constantDef << "\n");
+
+      fir::GlobalOp global = builder.createGlobalConstant(
+          loc, varTy, makeUniqueGlobalName(builder),
+          [&](fir::FirOpBuilder &initBuilder) {
+            mlir::Operation *clone = constantDef->clone();
+            initBuilder.insert(clone);
+            mlir::Value val =
+                initBuilder.createConvert(loc, varTy, clone->getResult(0));
+            initBuilder.create<fir::HasValueOp>(loc, val);
+          },
+          builder.createInternalLinkage());
+      mlir::Value addr = builder.create<fir::AddrOfOp>(
+          loc, global.resultType(), global.getSymbol());
+      newOperands.push_back(addr);
+      changed = true;
+
+      // Erase the temporary only if its store and this call are its sole
+      // users; any other user would otherwise read a dangling value.
+      bool onlyUsedHere =
+          llvm::all_of(alloca->getUsers(), [&](mlir::Operation *user) {
+            return user == store.getOperation() ||
+                   user == callOp.getOperation();
+          });
+      if (onlyUsedHere &&
+          !llvm::is_contained(toErase, alloca.getOperation())) {
+        toErase.push_back(store);
+        toErase.push_back(alloca);
+      }
+    }
+
+    if (!changed)
+      return mlir::failure();
+
+    llvm::SmallVector<mlir::Type> newResultTypes(
+        callOp.getResultTypes().begin(), callOp.getResultTypes().end());
+    fir::CallOp newOp = rewriter.create<fir::CallOp>(
+        loc, newResultTypes,
+        callOp.getCallee().has_value() ? callOp.getCallee().value()
+                                       : mlir::SymbolRefAttr{},
+        newOperands, callOp.getFastmathAttr());
+    // Keep every other attribute carried by the original call.
+    newOp->setAttrs(callOp->getAttrs());
+    LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+                            << newOp << '\n');
+    rewriter.replaceOp(callOp, newOp);
+
+    for (mlir::Operation *op : toErase)
+      rewriter.eraseOp(op);
+    return mlir::success();
+  }
+
+private:
+  /// Returns a "_extruded_.<N>" name not yet used by a global in the module.
+  std::string makeUniqueGlobalName(fir::FirOpBuilder &builder) const {
+    std::string name;
+    do {
+      name = "_extruded_." + std::to_string(nextLitId++);
+    } while (builder.getNamedGlobal(name));
+    return name;
+  }
+
+  unsigned &nextLitId;
+};
+
+/// Replaces constant by-value call arguments with global constants across
+/// the whole module.
+class ConstExtruderOpt
+    : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+public:
+  void runOnOperation() override {
+    mlir::ModuleOp module = getOperation();
+    mlir::MLIRContext *context = &getContext();
+
+    // A call is legal once no argument can be extruded. The predicate and
+    // the pattern share getExtrudableStore so they always agree: a call
+    // deemed illegal that the pattern could not rewrite would fail the pass.
+    mlir::ConversionTarget target(*context);
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect>();
+    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp op) {
+      return llvm::none_of(op.getArgs(), [](mlir::Value arg) {
+        return static_cast<bool>(getExtrudableStore(arg));
+      });
+    });
+
+    // Numbering restarts for each run so generated names are deterministic.
+    unsigned nextLitId = 1;
+    mlir::RewritePatternSet patterns(context);
+    patterns.insert<CallOpRewriter>(context, nextLitId);
+    if (mlir::failed(mlir::applyPartialConversion(module, target,
+                                                  std::move(patterns)))) {
+      mlir::emitError(module.getLoc(),
+                      "error in constant extrusion optimization\n");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
+  return std::make_unique<ConstExtruderOpt>();
+}
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -31,6 +31,7 @@
 ! ALL-NEXT: LowerHLFIRIntrinsics
 ! ALL-NEXT: BufferizeHLFIR
 ! ALL-NEXT: ConvertHLFIRtoFIR
+! ALL-NEXT: ConstExtruderOpt
 ! ALL-NEXT: CSE
 ! Ideally, we need an output with only the pass names, but
 ! there is currently no way to get that, so in order to
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -20,6 +20,7 @@
 ! ALL-NEXT: LowerHLFIRIntrinsics
 ! ALL-NEXT: BufferizeHLFIR
 ! ALL-NEXT: ConvertHLFIRtoFIR
+! ALL-NEXT: ConstExtruderOpt
 ! ALL-NEXT: CSE
 ! Ideally, we need an output with only the pass names, but
 ! there is currently no way to get that, so in order to
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -30,6 +30,7 @@
 // PASSES-NEXT: LowerHLFIRIntrinsics
 // PASSES-NEXT: BufferizeHLFIR
 // PASSES-NEXT: ConvertHLFIRtoFIR
+// PASSES-NEXT: ConstExtruderOpt
 // PASSES-NEXT: CSE
 // PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -16,9 +16,7 @@
 
 // CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
 // CHECK-SAME: %[[VAL_0:.*]])
-// CHECK: %[[VAL_1:.*]] = alloca i32, i64 1, align 4
-// CHECK: store i32 4, ptr %[[VAL_1]], align 4
-// CHECK: call void %[[VAL_0]](ptr %[[VAL_1]])
+// CHECK: call void %[[VAL_0]](ptr @{{.*}})
 
 func.func @_QPtest_proc_dummy() {
   %c0_i32 = arith.constant 0 : i32