diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -47,9 +47,11 @@
 class SimplifyIntrinsicsPass
     : public fir::SimplifyIntrinsicsBase<SimplifyIntrinsicsPass> {
   using FunctionTypeGeneratorTy =
-      std::function<mlir::FunctionType(fir::FirOpBuilder &)>;
+      llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
   using FunctionBodyGeneratorTy =
-      std::function<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+      llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
+  using GenReductionBodyTy = llvm::function_ref<void(
+      fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp)>;

 public:
   /// Generate a new function implementing a simplified version
@@ -63,6 +65,16 @@
                                          FunctionBodyGeneratorTy bodyGenerator);
   void runOnOperation() override;
   void getDependentDialects(mlir::DialectRegistry &registry) const override;
+
+private:
+  /// Replace \p call, a call to a reduction runtime function, with a call
+  /// to a simplified function whose body is built by \p genBodyFunc.
+  /// Only rank-1 calls with absent DIM and MASK whose callee name ends in
+  /// Integer4 or Real8 are replaced; any other call is left unchanged.
+  /// \p kindMap is used to create the fir::FirOpBuilder.
+  /// \p genBodyFunc is the callback that builds the replacement function.
+  void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap,
+                         GenReductionBodyTy genBodyFunc);
 };
 } // namespace

@@ -76,10 +88,10 @@
                                  {elementType});
 }

-using BodyOpGeneratorTy =
-    std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location,
-                              const mlir::Type &, mlir::Value, mlir::Value)>;
-using InitValGeneratorTy = std::function<mlir::Value(
+using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
+    fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
+    mlir::Value)>;
+using InitValGeneratorTy = llvm::function_ref<mlir::Value(
     fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;

 /// Generate the reduction loop into \p funcOp.
@@ -427,6 +439,43 @@
   } while (true);
 }

+void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call,
+                                               const fir::KindMapping &kindMap,
+                                               GenReductionBodyTy genBodyFunc) {
+  mlir::SymbolRefAttr callee = call.getCalleeAttr();
+  mlir::StringRef funcName = callee.getLeafReference().getValue();
+  mlir::Operation::operand_range args = call.getArgs();
+  // args[1] and args[2] are source filename and line number, ignored.
+  mlir::Value dim = args[3];
+  mlir::Value mask = args[4];
+  // dim is zero when it is absent, which is an implementation
+  // detail in the runtime library.
+  bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
+  unsigned rank = getDimCount(args[0]);
+  if (dimAndMaskAbsent && rank == 1) {
+    mlir::Location loc = call.getLoc();
+    mlir::Type type;
+    fir::FirOpBuilder builder(call, kindMap);
+    if (funcName.endswith("Integer4")) {
+      type = mlir::IntegerType::get(builder.getContext(), 32);
+    } else if (funcName.endswith("Real8")) {
+      type = mlir::FloatType::getF64(builder.getContext());
+    } else {
+      return;
+    }
+    auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
+      return genNoneBoxType(builder, type);
+    };
+    mlir::func::FuncOp newFunc =
+        getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc);
+    auto newCall =
+        builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
+    call->replaceAllUsesWith(newCall.getResults());
+    call->dropAllReferences();
+    call->erase();
+  }
+}
+
 void SimplifyIntrinsicsPass::runOnOperation() {
   LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
   mlir::ModuleOp module = getOperation();
@@ -445,37 +494,7 @@
         //                int dim, const Descriptor *mask)
         //
         if (funcName.startswith("_FortranASum")) {
-          mlir::Operation::operand_range args = call.getArgs();
-          // args[1] and args[2] are source filename and line number, ignored.
-          const mlir::Value &dim = args[3];
-          const mlir::Value &mask = args[4];
-          // dim is zero when it is absent, which is an implementation
-          // detail in the runtime library.
-          bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
-          unsigned rank = getDimCount(args[0]);
-          if (dimAndMaskAbsent && rank == 1) {
-            mlir::Location loc = call.getLoc();
-            mlir::Type type;
-            fir::FirOpBuilder builder(op, kindMap);
-            if (funcName.endswith("Integer4")) {
-              type = mlir::IntegerType::get(builder.getContext(), 32);
-            } else if (funcName.endswith("Real8")) {
-              type = mlir::FloatType::getF64(builder.getContext());
-            } else {
-              return;
-            }
-            auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
-              return genNoneBoxType(builder, type);
-            };
-            mlir::func::FuncOp newFunc = getOrCreateFunction(
-                builder, funcName, typeGenerator, genFortranASumBody);
-            auto newCall = builder.create<fir::CallOp>(
-                loc, newFunc, mlir::ValueRange{args[0]});
-            call->replaceAllUsesWith(newCall.getResults());
-            call->dropAllReferences();
-            call->erase();
-          }
-
+          simplifyReduction(call, kindMap, genFortranASumBody);
           return;
         }
         if (funcName.startswith("_FortranADotProduct")) {
@@ -539,36 +558,8 @@
           return;
         }
         if (funcName.startswith("_FortranAMaxval")) {
-          mlir::Operation::operand_range args = call.getArgs();
-          const mlir::Value &dim = args[3];
-          const mlir::Value &mask = args[4];
-          // dim is zero when it is absent, which is an implementation
-          // detail in the runtime library.
-          bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask);
-          unsigned rank = getDimCount(args[0]);
-          if (dimAndMaskAbsent && rank == 1) {
-            mlir::Location loc = call.getLoc();
-            mlir::Type type;
-            fir::FirOpBuilder builder(op, kindMap);
-            if (funcName.endswith("Integer4")) {
-              type = mlir::IntegerType::get(builder.getContext(), 32);
-            } else if (funcName.endswith("Real8")) {
-              type = mlir::FloatType::getF64(builder.getContext());
-            } else {
-              return;
-            }
-            auto typeGenerator = [&type](fir::FirOpBuilder &builder) {
-              return genNoneBoxType(builder, type);
-            };
-            mlir::func::FuncOp newFunc = getOrCreateFunction(
-                builder, funcName, typeGenerator, genFortranAMaxvalBody);
-            auto newCall = builder.create<fir::CallOp>(
-                loc, newFunc, mlir::ValueRange{args[0]});
-            call->replaceAllUsesWith(newCall.getResults());
-            call->dropAllReferences();
-            call->erase();
-            return;
-          }
+          simplifyReduction(call, kindMap, genFortranAMaxvalBody);
+          return;
         }
       }
     }