diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -38,6 +38,7 @@
 std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
 std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
 std::unique_ptr<mlir::Pass> createSimplifyRegionLitePass();
+std::unique_ptr<mlir::Pass> createSimplifyIntrinsicsPass();
 std::unique_ptr<mlir::Pass>
 createMemoryAllocationPass(bool dynOnHeap, std::size_t maxStackSize);
 
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -239,6 +239,14 @@
   let constructor = "::fir::createSimplifyRegionLitePass()";
 }
 
+def SimplifyIntrinsics : Pass<"simplify-intrinsics", "mlir::ModuleOp"> {
+  let summary = "Intrinsics simplification";
+  let description = [{
+    Replace intrinsics functions and subroutines with inlineable code
+  }];
+  let constructor = "::fir::createSimplifyIntrinsicsPass()";
+}
+
 def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::FuncOp"> {
   let summary = "Convert stack to heap allocations and vice versa.";
   let description = [{
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -15,7 +15,8 @@
   MemToReg.cpp
   RewriteLoop.cpp
   SimplifyRegionLite.cpp
-
+  InlineIntrinsics.cpp
+
   DEPENDS
   FIRAnalysis
   FIRBuilder
diff --git a/flang/lib/Optimizer/Transforms/InlineIntrinsics.cpp b/flang/lib/Optimizer/Transforms/InlineIntrinsics.cpp
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/InlineIntrinsics.cpp
@@ -0,0 +1,200 @@
+//===- InlineIntrinsics.cpp -- intrinsics simplification ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Replaces calls to `_FortranASum*` runtime functions that have neither a DIM
+// nor a MASK argument with calls to generated functions that sum the array
+// elements in a simple loop, so that the code can be inlined later.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Support/FIRContext.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace {
+
+/// Module pass that replaces eligible `_FortranASum*` runtime calls with calls
+/// to generated `<callee>_inline` functions.
+class SimplifyIntrinsicsPass
+    : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
+public:
+  /// Return the function named \p basename + "_inline" in the module of
+  /// \p builder, creating it on first use. The function takes a single
+  /// `!fir.box<none>` argument and returns a value of type \p type.
+  mlir::FuncOp getOrCreateFunction(const mlir::Location &loc,
+                                   fir::FirOpBuilder &builder,
+                                   const mlir::Type &type,
+                                   const mlir::StringRef &basename);
+  void runOnOperation() override;
+};
+
+/// Redirect all uses of the results of \p op to \p newValues and erase \p op.
+void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
+  op->replaceAllUsesWith(newValues);
+  op->dropAllReferences();
+  op->erase();
+}
+
+mlir::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
+    const mlir::Location &loc, fir::FirOpBuilder &builder,
+    const mlir::Type &type, const mlir::StringRef &basename) {
+  // Build the name eagerly: a Twine should not be stored in a local variable.
+  const std::string name = (basename + "_inline").str();
+  auto module = builder.getModule();
+  // If we already have a function, just return it.
+  if (mlir::FuncOp existing =
+          fir::FirOpBuilder::getNamedFunction(module, name))
+    return existing;
+
+  // Need to build the function!
+  mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
+  auto ftype = mlir::FunctionType::get(builder.getContext(), {boxType}, {type});
+  mlir::FuncOp newFunc =
+      fir::FirOpBuilder::createFunction(loc, module, name, ftype);
+
+  // Save the position of the original call.
+  mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
+  builder.setInsertionPointToEnd(newFunc.addEntryBlock());
+
+  mlir::IndexType idxTy = builder.getIndexType();
+  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
+
+  // The running sum lives in a stack slot initialized to zero.
+  mlir::Value zero = type.isa<mlir::FloatType>()
+                         ? builder.createRealConstant(loc, type, 0.0)
+                         : builder.createIntegerConstant(loc, type, 0);
+  auto sum = builder.create<fir::AllocaOp>(loc, type);
+  builder.create<fir::StoreOp>(loc, zero, sum);
+
+  auto args = newFunc.front().getArguments();
+  mlir::Value arg = args[0];
+
+  auto zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
+
+  // View the incoming box as a one-dimensional array of unknown extent.
+  // NOTE(review): only dimension 0 is queried below, so this looks right for
+  // rank-1 arrays only - confirm how higher-rank arguments should be handled.
+  fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
+  auto arrTy = fir::SequenceType::get(flatShape, type);
+  auto boxArrTy = fir::BoxType::get(arrTy);
+  mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
+  auto dims =
+      builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
+  mlir::Value len = dims.getResult(1);
+  // fir.do_loop includes its upper bound, so a zero-based loop over `len`
+  // elements must stop at len - 1.
+  mlir::Value lastIdx = builder.create<mlir::arith::SubIOp>(loc, len, one);
+  mlir::Value step = one;
+
+  auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, lastIdx, step);
+
+  // Begin loop code
+  auto loopEndPt = builder.saveInsertionPoint();
+  builder.setInsertionPointToStart(loop.getBody());
+
+  mlir::Type eleRefTy = builder.getRefType(type);
+  mlir::Value index = loop.getInductionVar();
+  mlir::Value addr =
+      builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
+  mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
+  mlir::Value sumVal = builder.create<fir::LoadOp>(loc, sum);
+
+  mlir::Value res;
+  if (type.isa<mlir::FloatType>())
+    res = builder.create<mlir::arith::AddFOp>(loc, elem, sumVal);
+  else
+    res = builder.create<mlir::arith::AddIOp>(loc, elem, sumVal);
+  builder.create<fir::StoreOp>(loc, res, sum);
+  // End of loop.
+  builder.restoreInsertionPoint(loopEndPt);
+
+  mlir::Value resultVal = builder.create<fir::LoadOp>(loc, sum);
+  builder.create<mlir::ReturnOp>(loc, resultVal);
+
+  // Now back to where we were adding code earlier...
+  builder.restoreInsertionPoint(insertPt);
+
+  return newFunc;
+}
+
+/// Return true if the first operand of the operation defining \p val is
+/// produced by a fir.absent operation.
+inline bool isAbsent(mlir::Value val) {
+  // Block arguments have no defining operation.
+  mlir::Operation *op = val.getDefiningOp();
+  if (!op || op->getNumOperands() == 0)
+    return false;
+  return mlir::isa_and_nonnull<fir::AbsentOp>(
+      op->getOperand(0).getDefiningOp());
+}
+
+/// Return true if the first operand of the operation defining \p val is
+/// produced by an integer constant equal to zero.
+inline bool isZero(mlir::Value val) {
+  mlir::Operation *op = val.getDefiningOp();
+  if (!op || op->getNumOperands() == 0)
+    return false;
+  // The operand may itself be a block argument, so guard before matching.
+  if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
+    return mlir::matchPattern(defOp, mlir::m_Zero());
+  return false;
+}
+
+void SimplifyIntrinsicsPass::runOnOperation() {
+  auto module = getOperation();
+  // The kind mapping belongs to the module, so look it up only once.
+  fir::KindMapping kindMap = fir::getKindMapping(module);
+  module.walk([&](fir::CallOp call) {
+    auto callee = call.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>();
+    if (!callee)
+      return;
+    fir::FirOpBuilder builder(call.getOperation(), kindMap);
+    mlir::FuncOp func = builder.getNamedFunction(callee);
+    // The symbol may not resolve to a function in this module.
+    if (!func)
+      return;
+    auto funcName = func.sym_name();
+    if (!funcName.startswith("_FortranASum"))
+      return;
+    // args[3] and args[4] are inspected below. Requiring five operands also
+    // skips calls to the one-argument `*_inline` functions created by this
+    // pass, whose names start with the same prefix.
+    auto args = call.getArgs();
+    if (args.size() < 5)
+      return;
+    bool dimAndMaskAbsent = isZero(args[3]) && isAbsent(args[4]);
+    if (!dimAndMaskAbsent)
+      return;
+    mlir::Type type;
+    if (funcName.endswith("Integer4"))
+      type = mlir::IntegerType::get(builder.getContext(), 32);
+    else if (funcName.endswith("Real8"))
+      type = mlir::FloatType::getF64(builder.getContext());
+    else
+      return;
+    auto loc = call.getLoc();
+    mlir::FuncOp newFunc = getOrCreateFunction(loc, builder, type, funcName);
+    auto newCall =
+        builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
+    replaceOp(call, newCall.getResults());
+  });
+}
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
+  return std::make_unique<SimplifyIntrinsicsPass>();
+}