diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -35,6 +35,8 @@
 std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
 std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
 std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
+std::unique_ptr<mlir::Pass> createSimplifyIntrinsicsPass();
+
 std::unique_ptr<mlir::Pass>
 createMemoryAllocationPass(bool dynOnHeap, std::size_t maxStackSize);
 std::unique_ptr<mlir::Pass> createAnnotateConstantOperandsPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -170,6 +170,14 @@
   ];
 }
 
+def SimplifyIntrinsics : Pass<"simplify-intrinsics", "mlir::ModuleOp"> {
+  let summary = "Intrinsics simplification";
+  let description = [{
+    Replace intrinsics functions and subroutines with inlineable code
+  }];
+  let constructor = "::fir::createSimplifyIntrinsicsPass()";
+}
+
 def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
   let summary = "Convert stack to heap allocations and vice versa.";
   let description = [{
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@
   MemRefDataFlowOpt.cpp
   RewriteLoop.cpp
   SimplifyRegionLite.cpp
+  SimplifyIntrinsics.cpp
 
   DEPENDS
   FIRBuilder
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
new file mode 100644
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -0,0 +1,217 @@
+//===- SimplifyIntrinsics.cpp -- replace intrinsics with simpler form -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/FIRContext.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace {
+
+/// Module pass that redirects eligible calls to `_FortranASum*` functions to a
+/// generated `<callee>_simplified` function holding a plain summation loop.
+class SimplifyIntrinsicsPass
+    : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
+public:
+  /// Return the function named `basename` + "_simplified", creating it in the
+  /// builder's module on first use. The builder's insertion point is preserved.
+  mlir::func::FuncOp getOrCreateFunction(const mlir::Location &loc,
+                                         fir::FirOpBuilder &builder,
+                                         const mlir::Type &type,
+                                         const mlir::StringRef &basename);
+  void runOnOperation() override;
+};
+
+mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
+    const mlir::Location &loc, fir::FirOpBuilder &builder,
+    const mlir::Type &type, const mlir::StringRef &basename) {
+  std::string name = mlir::Twine{basename, "_simplified"}.str();
+  mlir::ModuleOp module = builder.getModule();
+  // If we already have a function, just return it.
+  if (mlir::func::FuncOp existing =
+          fir::FirOpBuilder::getNamedFunction(module, name))
+    return existing;
+
+  // Need to build the function: it takes one box and returns a `type` value.
+  mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
+  mlir::FunctionType ftype =
+      mlir::FunctionType::get(builder.getContext(), {boxType}, {type});
+  mlir::func::FuncOp newFunc =
+      fir::FirOpBuilder::createFunction(loc, module, name, ftype);
+  auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
+  auto linkage =
+      mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
+  newFunc->setAttr("llvm.linkage", linkage);
+
+  // Save the position of the original call.
+  mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
+  builder.setInsertionPointToEnd(newFunc.addEntryBlock());
+
+  mlir::IndexType idxTy = builder.getIndexType();
+
+  // Accumulator of `type`, initialised to zero.
+  const bool isFloat = type.isa<mlir::FloatType>();
+  mlir::Value zero = isFloat ? builder.createRealConstant(loc, type, 0.0)
+                             : builder.createIntegerConstant(loc, type, 0);
+  mlir::Value sum = builder.create<fir::AllocaOp>(loc, type);
+  builder.create<fir::StoreOp>(loc, zero, sum);
+
+  mlir::Block::BlockArgListType args = newFunc.front().getArguments();
+  mlir::Value arg = args[0];
+
+  mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
+
+  // View the incoming box as a box of a one-dimensional array of `type` with
+  // unknown extent, then query its dimension 0; result 1 is used as length.
+  fir::SequenceType::Shape flatShape = {fir::SequenceType::getUnknownExtent()};
+  mlir::Type arrTy = fir::SequenceType::get(flatShape, type);
+  mlir::Type boxArrTy = fir::BoxType::get(arrTy);
+  mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
+  auto dims =
+      builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, zeroIdx);
+  mlir::Value len = dims.getResult(1);
+  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
+  mlir::Value step = one;
+
+  // We use C indexing here, so len-1 as loopcount
+  mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
+  auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step);
+
+  // Begin loop code
+  mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
+  builder.setInsertionPointToStart(loop.getBody());
+
+  mlir::Type eleRefTy = builder.getRefType(type);
+  mlir::Value index = loop.getInductionVar();
+  mlir::Value addr =
+      builder.create<fir::CoordinateOp>(loc, eleRefTy, array, index);
+  mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
+  mlir::Value sumVal = builder.create<fir::LoadOp>(loc, sum);
+
+  mlir::Value res;
+  if (isFloat)
+    res = builder.create<mlir::arith::AddFOp>(loc, elem, sumVal);
+  else
+    res = builder.create<mlir::arith::AddIOp>(loc, elem, sumVal);
+  builder.create<fir::StoreOp>(loc, res, sum);
+  // End of loop.
+  builder.restoreInsertionPoint(loopEndPt);
+
+  mlir::Value resultVal = builder.create<fir::LoadOp>(loc, sum);
+  builder.create<mlir::func::ReturnOp>(loc, resultVal);
+
+  // Now back to where we were adding code earlier...
+  builder.restoreInsertionPoint(insertPt);
+
+  return newFunc;
+}
+
+/// True when `val` is defined by an op whose first operand comes from a
+/// fir::AbsentOp. Values with no defining op, or whose defining op has no
+/// operands, are reported as not absent.
+inline bool isOperandAbsent(mlir::Value val) {
+  mlir::Operation *op = val.getDefiningOp();
+  if (!op || op->getNumOperands() == 0)
+    return false;
+  return mlir::isa_and_nonnull<fir::AbsentOp>(
+      op->getOperand(0).getDefiningOp());
+}
+
+/// True when the first operand of `val`'s defining op is produced by an op
+/// matching mlir::m_Zero(); false if any link in that chain is missing.
+inline bool isZero(mlir::Value val) {
+  mlir::Operation *op = val.getDefiningOp();
+  if (!op || op->getNumOperands() == 0)
+    return false;
+  if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
+    return mlir::matchPattern(defOp, mlir::m_Zero());
+  return false;
+}
+
+/// Follow the chain of first operands upward from `val`'s defining op and
+/// return the shape operand of the first fir::EmboxOp reached, or a null
+/// value when the chain ends without one.
+inline ::mlir::Value findShape(mlir::Value val) {
+  mlir::Operation *defOp = val.getDefiningOp();
+  // Stop on ops without operands: getOperand(0) would be out of range.
+  while (defOp && defOp->getNumOperands() > 0) {
+    defOp = defOp->getOperand(0).getDefiningOp();
+    // defOp is null once the chain reaches a block argument, so the cast
+    // must tolerate null.
+    if (auto box = mlir::dyn_cast_or_null<fir::EmboxOp>(defOp))
+      return box.getShape();
+  }
+  return {};
+}
+
+/// Rank of the shape found by findShape(val), or 0 when there is none.
+inline unsigned getDimCount(mlir::Value val) {
+  // Ask the value for its type directly; a shape that is a block argument has
+  // no defining op to query.
+  if (mlir::Value shapeVal = findShape(val))
+    return fir::getRankOfShapeType(shapeVal.getType());
+  return 0;
+}
+
+void SimplifyIntrinsicsPass::runOnOperation() {
+  mlir::ModuleOp module = getOperation();
+  fir::KindMapping kindMap = fir::getKindMapping(module);
+  module.walk([&](mlir::Operation *op) {
+    auto call = mlir::dyn_cast<fir::CallOp>(op);
+    if (!call)
+      return;
+    mlir::SymbolRefAttr callee = call.getCalleeAttr();
+    if (!callee)
+      return;
+    mlir::StringRef funcName = callee.getLeafReference().getValue();
+    if (!funcName.startswith("_FortranASum"))
+      return;
+
+    // The checks below read operands 0, 3 and 4; skip calls that are too short.
+    mlir::Operation::operand_range args = call.getArgs();
+    if (args.size() < 5)
+      return;
+    // Only rewrite whole-array sums: DIM (operand 3) must be zero, MASK
+    // (operand 4) absent, and the array being summed must have rank 1.
+    if (!isZero(args[3]) || !isOperandAbsent(args[4]) ||
+        getDimCount(args[0]) != 1)
+      return;
+
+    mlir::Location loc = call.getLoc();
+    fir::FirOpBuilder builder(op, kindMap);
+    mlir::Type type;
+    if (funcName.endswith("Integer4"))
+      type = mlir::IntegerType::get(builder.getContext(), 32);
+    else if (funcName.endswith("Real8"))
+      type = mlir::FloatType::getF64(builder.getContext());
+    else
+      return;
+
+    mlir::func::FuncOp newFunc =
+        getOrCreateFunction(loc, builder, type, funcName);
+    auto newCall = builder.create<fir::CallOp>(loc, newFunc,
+                                               mlir::ValueRange{args[0]});
+    call->replaceAllUsesWith(newCall.getResults());
+    call->dropAllReferences();
+    call->erase();
+  });
+}
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createSimplifyIntrinsicsPass() {
+  return std::make_unique<SimplifyIntrinsicsPass>();
+}