diff --git a/flang/include/flang/Lower/HlfirIntrinsics.h b/flang/include/flang/Lower/HlfirIntrinsics.h new file mode 100644 --- /dev/null +++ b/flang/include/flang/Lower/HlfirIntrinsics.h @@ -0,0 +1,90 @@ +//===-- HlfirIntrinsics.h -- lowering to HLFIR intrinsic ops ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// +/// +/// Implements lowering of transformational intrinsics to HLFIR intrinsic +/// operations +/// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_LOWER_HLFIRINTRINSICS_H +#define FORTRAN_LOWER_HLFIRINTRINSICS_H + +#include "flang/Optimizer/Builder/HLFIRTools.h" +#include "llvm/ADT/SmallVector.h" +#include +#include +#include + +namespace mlir { +class Location; +class Type; +class Value; +class ValueRange; +} // namespace mlir + +namespace fir { +class FirOpBuilder; +struct IntrinsicArgumentLoweringRules; +} // namespace fir + +namespace Fortran::lower { + +/// This structure holds the initial lowered value of an actual argument that +/// was lowered regardless of the interface, and it holds whether or not it +/// may be absent at runtime and the dummy is optional. 
+struct PreparedActualArgument { + + PreparedActualArgument(hlfir::Entity actual, + std::optional isPresent) + : actual{actual}, isPresent{isPresent} {} + void setElementalIndices(mlir::ValueRange &indices) { /* Keeps the address of indices, not a copy: that ValueRange object must stay alive for every later getActual() call, which reads whatever indices it holds at that time. */ + oneBasedElementalIndices = &indices; + } + hlfir::Entity getActual(mlir::Location loc, + fir::FirOpBuilder &builder) const { /* Returns the element of the actual at the one-based elemental indices if setElementalIndices() was called, otherwise the actual itself. */ + if (oneBasedElementalIndices) + return hlfir::getElementAt(loc, builder, actual, + *oneBasedElementalIndices); + return actual; + } + hlfir::Entity getOriginalActual() const { return actual; } /* Ignores any elemental indices. */ + void setOriginalActual(hlfir::Entity newActual) { actual = newActual; } + bool handleDynamicOptional() const { return isPresent.has_value(); } /* True when a runtime presence flag was recorded for this actual. */ + mlir::Value getIsPresent() const { + assert(handleDynamicOptional() && "not a dynamic optional"); + return *isPresent; + } + + void resetOptionalAspect() { isPresent = std::nullopt; } /* Drops the presence flag; handleDynamicOptional() then returns false. */ + +private: + hlfir::Entity actual; + mlir::ValueRange *oneBasedElementalIndices{nullptr}; /* Non-owning; null until setElementalIndices() is called. */ + // When the actual may be dynamically optional, "isPresent" + // holds a boolean value indicating the presence of the + // actual argument at runtime. + std::optional isPresent; +}; + +/// Vector of pre-lowered actual arguments. nullopt if the actual is +/// "statically" absent (if it was not syntactically provided).
+using PreparedActualArguments = + llvm::SmallVector>; +/** Lowers the intrinsic call `name` to a dedicated HLFIR operation when one is implemented for it. Returns std::nullopt otherwise, so that the caller can fall back to another lowering. stmtResultType is the result type of the call. */ +std::optional lowerHlfirIntrinsic( + fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name, + const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType); + +} // namespace Fortran::lower +#endif // FORTRAN_LOWER_HLFIRINTRINSICS_H diff --git a/flang/lib/Lower/CMakeLists.txt b/flang/lib/Lower/CMakeLists.txt --- a/flang/lib/Lower/CMakeLists.txt +++ b/flang/lib/Lower/CMakeLists.txt @@ -6,6 +6,7 @@ Bridge.cpp CallInterface.cpp Coarray.cpp + ComponentPath.cpp ConvertArrayConstructor.cpp ConvertCall.cpp ConvertConstant.cpp @@ -14,9 +15,9 @@ ConvertProcedureDesignator.cpp ConvertType.cpp ConvertVariable.cpp - ComponentPath.cpp CustomIntrinsicCall.cpp DumpEvaluateExpr.cpp + HlfirIntrinsics.cpp HostAssociations.cpp IO.cpp IterationSpace.cpp diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -14,6 +14,7 @@ #include "flang/Lower/ConvertExprToHLFIR.h" #include "flang/Lower/ConvertVariable.h" #include "flang/Lower/CustomIntrinsicCall.h" +#include "flang/Lower/HlfirIntrinsics.h" #include "flang/Lower/StatementContext.h" #include "flang/Lower/SymbolMap.h" #include "flang/Optimizer/Builder/BoxValue.h" @@ -615,50 +616,8 @@ std::optional resultType; mlir::Location loc; }; - -/// This structure holds the initial lowered value of an actual argument that -/// was lowered regardless of the interface, and it holds whether or not it -/// may be absent at runtime and the dummy is optional.
-struct PreparedActualArgument { - - PreparedActualArgument(hlfir::Entity actual, - std::optional isPresent) - : actual{actual}, isPresent{isPresent} {} - void setElementalIndices(mlir::ValueRange &indices) { - oneBasedElementalIndices = &indices; - } - hlfir::Entity getActual(mlir::Location loc, - fir::FirOpBuilder &builder) const { - if (oneBasedElementalIndices) - return hlfir::getElementAt(loc, builder, actual, - *oneBasedElementalIndices); - return actual; - } - hlfir::Entity getOriginalActual() const { return actual; } - void setOriginalActual(hlfir::Entity newActual) { actual = newActual; } - bool handleDynamicOptional() const { return isPresent.has_value(); } - mlir::Value getIsPresent() const { - assert(handleDynamicOptional() && "not a dynamic optional"); - return *isPresent; - } - - void resetOptionalAspect() { isPresent = std::nullopt; } - -private: - hlfir::Entity actual; - mlir::ValueRange *oneBasedElementalIndices{nullptr}; - // When the actual may be dynamically optional, "isPresent" - // holds a boolean value indicating the presence of the - // actual argument at runtime. - std::optional isPresent; -}; } // namespace -/// Vector of pre-lowered actual arguments. nullopt if the actual is -/// "statically" absent (if it was not syntactically provided). -using PreparedActualArguments = - llvm::SmallVector>; - // Helper to transform a fir::ExtendedValue to an hlfir::EntityWithAttributes. static hlfir::EntityWithAttributes extendedValueToHlfirEntity(mlir::Location loc, fir::FirOpBuilder &builder, @@ -860,7 +819,8 @@ /// The optional aspects must be handled by this function user. 
static PreparedDummyArgument preparePresentUserCallActualArgument( mlir::Location loc, fir::FirOpBuilder &builder, - const PreparedActualArgument &preparedActual, mlir::Type dummyType, + const Fortran::lower::PreparedActualArgument &preparedActual, + mlir::Type dummyType, const Fortran::lower::CallerInterface::PassedEntity &arg, const Fortran::lower::SomeExpr &expr, Fortran::lower::AbstractConverter &converter) { @@ -1023,7 +983,8 @@ /// of any optional aspect. static PreparedDummyArgument prepareUserCallActualArgument( mlir::Location loc, fir::FirOpBuilder &builder, - const PreparedActualArgument &preparedActual, mlir::Type dummyType, + const Fortran::lower::PreparedActualArgument &preparedActual, + mlir::Type dummyType, const Fortran::lower::CallerInterface::PassedEntity &arg, const Fortran::lower::SomeExpr &expr, Fortran::lower::AbstractConverter &converter) { @@ -1094,7 +1055,7 @@ /// the array argument elements value and will return the corresponding /// scalar result value. static std::optional -genUserCall(PreparedActualArguments &loweredActuals, +genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals, Fortran::lower::CallerInterface &caller, mlir::FunctionType callSiteType, CallContext &callContext) { using PassBy = Fortran::lower::CallerInterface::PassEntityBy; @@ -1221,7 +1182,7 @@ /// Lower calls to intrinsic procedures with actual arguments that have been /// pre-lowered but have not yet been prepared according to the interface. static std::optional -genIntrinsicRefCore(PreparedActualArguments &loweredActuals, +genIntrinsicRefCore(Fortran::lower::PreparedActualArguments &loweredActuals, const Fortran::evaluate::SpecificIntrinsic *intrinsic, const fir::IntrinsicArgumentLoweringRules *argLowering, CallContext &callContext) { @@ -1343,199 +1304,29 @@ /// Lower calls to intrinsic procedures with actual arguments that have been /// pre-lowered but have not yet been prepared according to the interface. 
-static std::optional -genHLFIRIntrinsicRefCore(PreparedActualArguments &loweredActuals, - const Fortran::evaluate::SpecificIntrinsic *intrinsic, - const fir::IntrinsicArgumentLoweringRules *argLowering, - CallContext &callContext) { +static std::optional genHLFIRIntrinsicRefCore( + Fortran::lower::PreparedActualArguments &loweredActuals, + const Fortran::evaluate::SpecificIntrinsic *intrinsic, + const fir::IntrinsicArgumentLoweringRules *argLowering, + CallContext &callContext) { if (!useHlfirIntrinsicOps) return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering, callContext); fir::FirOpBuilder &builder = callContext.getBuilder(); mlir::Location loc = callContext.loc; - - auto getOperandVector = [&](PreparedActualArguments &loweredActuals) { - llvm::SmallVector operands; - operands.reserve(loweredActuals.size()); - - for (size_t i = 0; i < loweredActuals.size(); ++i) { - std::optional arg = loweredActuals[i]; - if (!arg) { - operands.emplace_back(); - continue; - } - hlfir::Entity actual = arg->getOriginalActual(); - mlir::Value valArg; - - // if intrinsic handler has no lowering rules - if (!argLowering) { - valArg = hlfir::loadTrivialScalar(loc, builder, actual); - } else { - fir::ArgLoweringRule argRules = - fir::lowerIntrinsicArgumentAs(*argLowering, i); - if (!argRules.handleDynamicOptional && - argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired) - valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual); - else - valArg = actual.getBase(); - } - - operands.emplace_back(valArg); - } - return operands; - }; - - auto computeResultType = [&](mlir::Value argArray, - mlir::Type stmtResultType) -> mlir::Type { - hlfir::ExprType::Shape resultShape; - mlir::Type normalisedResult = - hlfir::getFortranElementOrSequenceType(stmtResultType); - mlir::Type elementType; - if (auto array = normalisedResult.dyn_cast()) { - resultShape = hlfir::ExprType::Shape{array.getShape()}; - elementType = array.getEleTy(); - return 
hlfir::ExprType::get(builder.getContext(), resultShape, - elementType, - /*polymorphic=*/false); - } - elementType = normalisedResult; - return elementType; - }; - - auto buildSumOperation = [](fir::FirOpBuilder &builder, mlir::Location loc, - mlir::Type resultTy, mlir::Value array, - mlir::Value dim, mlir::Value mask) { - return builder.create(loc, resultTy, array, dim, mask); - }; - - auto buildProductOperation = [](fir::FirOpBuilder &builder, - mlir::Location loc, mlir::Type resultTy, - mlir::Value array, mlir::Value dim, - mlir::Value mask) { - return builder.create(loc, resultTy, array, dim, mask); - }; - - auto buildAnyOperation = [](fir::FirOpBuilder &builder, mlir::Location loc, - mlir::Type resultTy, mlir::Value array, - mlir::Value dim, mlir::Value mask) { - return builder.create(loc, resultTy, array, dim); - }; - - auto buildAllOperation = [](fir::FirOpBuilder &builder, mlir::Location loc, - mlir::Type resultTy, mlir::Value array, - mlir::Value dim, mlir::Value mask) { - return builder.create(loc, resultTy, array, dim); - }; - - auto buildReductionIntrinsic = - [&](PreparedActualArguments &loweredActuals, mlir::Location loc, - fir::FirOpBuilder &builder, CallContext &callContext, - std::function - buildFunc, - bool hasMask) -> std::optional { - // shared logic for building the product and sum operations - llvm::SmallVector operands = getOperandVector(loweredActuals); - // dim, mask can be NULL if these arguments were not given - mlir::Value array = operands[0]; - mlir::Value dim = operands[1]; - if (dim) - dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); - - mlir::Value mask; - if (hasMask) - mask = operands[2]; - - mlir::Type resultTy = computeResultType(array, *callContext.resultType); - auto *intrinsicOp = buildFunc(builder, loc, resultTy, array, dim, mask); - return {hlfir::EntityWithAttributes{intrinsicOp->getResult(0)}}; - }; - const std::string intrinsicName = callContext.getProcedureName(); - if (intrinsicName == "sum") { - 
return buildReductionIntrinsic(loweredActuals, loc, builder, callContext, - buildSumOperation, true); - } - if (intrinsicName == "product") { - return buildReductionIntrinsic(loweredActuals, loc, builder, callContext, - buildProductOperation, true); - } - if (intrinsicName == "matmul") { - llvm::SmallVector operands = getOperandVector(loweredActuals); - mlir::Type resultTy = - computeResultType(operands[0], *callContext.resultType); - hlfir::MatmulOp matmulOp = builder.create( - loc, resultTy, operands[0], operands[1]); - - return {hlfir::EntityWithAttributes{matmulOp.getResult()}}; - } - if (intrinsicName == "transpose") { - llvm::SmallVector operands = getOperandVector(loweredActuals); - hlfir::ExprType::Shape resultShape; - mlir::Type normalisedResult = - hlfir::getFortranElementOrSequenceType(*callContext.resultType); - auto array = normalisedResult.cast(); - llvm::ArrayRef arrayShape = array.getShape(); - assert(arrayShape.size() == 2 && "arguments to transpose have a rank of 2"); - mlir::Type elementType = array.getEleTy(); - resultShape.push_back(arrayShape[0]); - resultShape.push_back(arrayShape[1]); - mlir::Type resultTy = hlfir::ExprType::get( - builder.getContext(), resultShape, elementType, /*polymorphic=*/false); - hlfir::TransposeOp transposeOp = - builder.create(loc, resultTy, operands[0]); - - return {hlfir::EntityWithAttributes{transposeOp.getResult()}}; - } - if (intrinsicName == "any") { - return buildReductionIntrinsic(loweredActuals, loc, builder, callContext, - buildAnyOperation, false); - } - if (intrinsicName == "all") { - return buildReductionIntrinsic(loweredActuals, loc, builder, callContext, - buildAllOperation, false); - } - if (intrinsicName == "dot_product") { - llvm::SmallVector operands = getOperandVector(loweredActuals); - mlir::Type resultTy = - computeResultType(operands[0], *callContext.resultType); - hlfir::DotProductOp dotProductOp = builder.create( - loc, resultTy, operands[0], operands[1]); - - return 
{hlfir::EntityWithAttributes{dotProductOp.getResult()}}; - } - if (intrinsicName == "count") { - llvm::SmallVector operands = getOperandVector(loweredActuals); - mlir::Value array = operands[0]; - mlir::Value dim = operands[1]; - if (dim) - dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); - mlir::Value kind = operands[2]; - mlir::Type resultTy = computeResultType(array, *callContext.resultType); - hlfir::CountOp countOp = - builder.create(loc, resultTy, array, dim, kind); - return {hlfir::EntityWithAttributes{countOp.getResult()}}; - } - if ((intrinsicName == "min" || intrinsicName == "max") && - hlfir::getFortranElementType(callContext.resultType.value()) - .isa()) { - llvm::SmallVector operands = getOperandVector(loweredActuals); - assert(operands.size() >= 2); - - hlfir::CharExtremumPredicate pred = (intrinsicName == "min") - ? hlfir::CharExtremumPredicate::min - : hlfir::CharExtremumPredicate::max; - hlfir::CharExtremumOp charExtremumOp = - builder.create(loc, pred, - mlir::ValueRange{operands}); - return {hlfir::EntityWithAttributes{charExtremumOp.getResult()}}; + // transformational intrinsic ops always have a result type + if (callContext.resultType) { + std::optional res = + Fortran::lower::lowerHlfirIntrinsic(builder, loc, intrinsicName, + loweredActuals, argLowering, + *callContext.resultType); + if (res) + return res; } - // TODO add hlfir operations for other transformational intrinsics here - // fallback to calling the intrinsic via fir.call return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering, callContext); @@ -1546,14 +1337,14 @@ class ElementalCallBuilder { public: std::optional - genElementalCall(PreparedActualArguments &loweredActuals, bool isImpure, - CallContext &callContext) { + genElementalCall(Fortran::lower::PreparedActualArguments &loweredActuals, + bool isImpure, CallContext &callContext) { mlir::Location loc = callContext.loc; fir::FirOpBuilder &builder = callContext.getBuilder(); unsigned numArgs = 
loweredActuals.size(); // Step 1: dereference pointers/allocatables and compute elemental shape. mlir::Value shape; - PreparedActualArgument *optionalWithShape; + Fortran::lower::PreparedActualArgument *optionalWithShape; // 10.1.4 p5. Impure elemental procedures must be called in element order. bool mustBeOrdered = isImpure; for (unsigned i = 0; i < numArgs; ++i) { @@ -1693,7 +1484,7 @@ mlir::FunctionType callSiteType) : caller{caller}, callSiteType{callSiteType} {} std::optional - genElementalKernel(PreparedActualArguments &loweredActuals, + genElementalKernel(Fortran::lower::PreparedActualArguments &loweredActuals, CallContext &callContext) { return genUserCall(loweredActuals, caller, callSiteType, callContext); } @@ -1714,9 +1505,9 @@ arg.passBy == PassBy::BaseAddressValueAttribute; } - mlir::Value - computeDynamicCharacterResultLength(PreparedActualArguments &loweredActuals, - CallContext &callContext) { + mlir::Value computeDynamicCharacterResultLength( + Fortran::lower::PreparedActualArguments &loweredActuals, + CallContext &callContext) { TODO(callContext.loc, "compute elemental function result length parameters in HLFIR"); } @@ -1735,7 +1526,7 @@ : intrinsic{intrinsic}, argLowering{argLowering}, isFunction{isFunction} { } std::optional - genElementalKernel(PreparedActualArguments &loweredActuals, + genElementalKernel(Fortran::lower::PreparedActualArguments &loweredActuals, CallContext &callContext) { return genHLFIRIntrinsicRefCore(loweredActuals, intrinsic, argLowering, callContext); @@ -1748,9 +1539,9 @@ return isFunction; } - mlir::Value - computeDynamicCharacterResultLength(PreparedActualArguments &loweredActuals, - CallContext &callContext) { + mlir::Value computeDynamicCharacterResultLength( + Fortran::lower::PreparedActualArguments &loweredActuals, + CallContext &callContext) { if (intrinsic) if (intrinsic->name == "adjustr" || intrinsic->name == "adjustl" || intrinsic->name == "merge") @@ -1816,7 +1607,7 @@ callContext.procRef, *intrinsic, 
converter)) TODO(loc, "special cases of intrinsic with optional arguments"); - PreparedActualArguments loweredActuals; + Fortran::lower::PreparedActualArguments loweredActuals; const fir::IntrinsicArgumentLoweringRules *argLowering = fir::getIntrinsicArgumentLowering(callContext.getProcedureName()); for (const auto &arg : llvm::enumerate(callContext.procRef.arguments())) { @@ -1845,7 +1636,7 @@ !fir::lowerIntrinsicArgumentAs(*argLowering, arg.index()) .handleDynamicOptional) && "TYPE(*) are not expected to appear as optional intrinsic arguments"); - loweredActuals.push_back(PreparedActualArgument{ + loweredActuals.push_back(Fortran::lower::PreparedActualArgument{ hlfir::Entity{*var}, /*isPresent=*/std::nullopt}); continue; } @@ -1861,7 +1652,8 @@ genIsPresentIfArgMaybeAbsent(loc, loweredActual, *expr, callContext, /*passAsAllocatableOrPointer=*/false); } - loweredActuals.push_back(PreparedActualArgument{loweredActual, isPresent}); + loweredActuals.push_back( + Fortran::lower::PreparedActualArgument{loweredActual, isPresent}); } if (callContext.isElementalProcWithArrayArgs()) { @@ -1898,7 +1690,7 @@ callContext.converter); mlir::FunctionType callSiteType = caller.genFunctionType(); - PreparedActualArguments loweredActuals; + Fortran::lower::PreparedActualArguments loweredActuals; // Lower the actual arguments for (const Fortran::lower::CallInterface< Fortran::lower::CallerInterface>::PassedEntity &arg : @@ -1933,7 +1725,7 @@ Fortran::lower::CallerInterface::PassEntityBy::MutableBox); loweredActuals.emplace_back( - PreparedActualArgument{loweredActual, isPresent}); + Fortran::lower::PreparedActualArgument{loweredActual, isPresent}); } else { // Optional dummy argument for which there is no actual argument. 
loweredActuals.emplace_back(std::nullopt); diff --git a/flang/lib/Lower/HlfirIntrinsics.cpp b/flang/lib/Lower/HlfirIntrinsics.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Lower/HlfirIntrinsics.cpp @@ -0,0 +1,296 @@ +//===-- HlfirIntrinsics.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "flang/Lower/HlfirIntrinsics.h" + +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/HLFIRTools.h" +#include "flang/Optimizer/Builder/IntrinsicCall.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/HLFIR/HLFIRDialect.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVector.h" +#include + +namespace { +/** Base class for lowering one transformational intrinsic call to an HLFIR operation: lower() runs the subclass hook lowerImpl() and wraps the value it returns in an hlfir::EntityWithAttributes. */ +class HlfirTransformationalIntrinsic { +public: + explicit HlfirTransformationalIntrinsic(fir::FirOpBuilder &builder, + mlir::Location loc) + : builder(builder), loc(loc) {} + + virtual ~HlfirTransformationalIntrinsic() = default; + + hlfir::EntityWithAttributes + lower(const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) { + mlir::Value res = lowerImpl(loweredActuals, argLowering, stmtResultType); + return {hlfir::EntityWithAttributes{res}}; + } + +protected: + fir::FirOpBuilder &builder; + mlir::Location loc; + + virtual mlir::Value + lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) = 0; + /** Returns one operand per actual argument, in order, with a null mlir::Value for each statically absent actual. Without lowering rules, trivial scalars are loaded; with rules, pointers and allocatables are dereferenced unless the argument is dynamically optional or lowered as Inquired, in which case the base is used as is. */
+ llvm::SmallVector getOperandVector( + const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering); + /** Maps stmtResultType to an hlfir::ExprType when it is an array, otherwise to its element type. NOTE(review): argArray is not used by the definition. */ + mlir::Type computeResultType(mlir::Value argArray, mlir::Type stmtResultType); + /** Creates an OP at this call location from the given builder arguments. */ + template + inline OP createOp(BUILD_ARGS... args) { + return builder.create(loc, args...); + } +}; +/** Reductions: operands are (array, dim, mask). A present dim is loaded as a scalar, and mask is passed to OP only when HAS_MASK is true. */ +template +class HlfirReductionIntrinsic : public HlfirTransformationalIntrinsic { +public: + using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; + +protected: + mlir::Value + lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) override; +}; +using HlfirSumLowering = HlfirReductionIntrinsic; +using HlfirProductLowering = HlfirReductionIntrinsic; +using HlfirAnyLowering = HlfirReductionIntrinsic; +using HlfirAllLowering = HlfirReductionIntrinsic; +/** Two-operand intrinsics (matmul, dot_product): OP is built from the first two operands. */ +template +class HlfirProductIntrinsic : public HlfirTransformationalIntrinsic { +public: + using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; + +protected: + mlir::Value + lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) override; +}; +using HlfirMatmulLowering = HlfirProductIntrinsic; +using HlfirDotProductLowering = HlfirProductIntrinsic; +/** transpose: the rank-2 result shape is copied from stmtResultType. */ +class HlfirTransposeLowering : public HlfirTransformationalIntrinsic { +public: + using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; + +protected: + mlir::Value + lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) override; +}; +/** count: operands are (array, dim, kind); a present dim is loaded as a scalar. */ +class HlfirCountLowering : public HlfirTransformationalIntrinsic { +public: + using HlfirTransformationalIntrinsic::HlfirTransformationalIntrinsic; + +protected: + mlir::Value + lowerImpl(const Fortran::lower::PreparedActualArguments
&loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) override; +}; +/** Character min/max: builds the extremum op selected by pred over all operands. */ +class HlfirCharExtremumLowering : public HlfirTransformationalIntrinsic { +public: + HlfirCharExtremumLowering(fir::FirOpBuilder &builder, mlir::Location loc, + hlfir::CharExtremumPredicate pred) + : HlfirTransformationalIntrinsic(builder, loc), pred{pred} {} + +protected: + mlir::Value + lowerImpl(const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) override; + +protected: + hlfir::CharExtremumPredicate pred; +}; + +} // namespace + +llvm::SmallVector HlfirTransformationalIntrinsic::getOperandVector( + const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering) { + llvm::SmallVector operands; + operands.reserve(loweredActuals.size()); + + for (size_t i = 0; i < loweredActuals.size(); ++i) { + std::optional arg = + loweredActuals[i]; /* NOTE(review): this copies the element; binding a const reference would avoid the copy. */ + if (!arg) { + operands.emplace_back(); /* Null Value keeps operand positions aligned with argument positions. */ + continue; + } + hlfir::Entity actual = arg->getOriginalActual(); + mlir::Value valArg; + + if (!argLowering) { + valArg = hlfir::loadTrivialScalar(loc, builder, actual); + } else { + fir::ArgLoweringRule argRules = + fir::lowerIntrinsicArgumentAs(*argLowering, i); + if (!argRules.handleDynamicOptional && + argRules.lowerAs != fir::LowerIntrinsicArgAs::Inquired) + valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual); + else + valArg = actual.getBase(); + } + + operands.emplace_back(valArg); + } + return operands; +} + +mlir::Type +HlfirTransformationalIntrinsic::computeResultType(mlir::Value argArray, + mlir::Type stmtResultType) { + mlir::Type normalisedResult = + hlfir::getFortranElementOrSequenceType(stmtResultType); + if (auto array = normalisedResult.dyn_cast()) { + hlfir::ExprType::Shape resultShape = + hlfir::ExprType::Shape{array.getShape()}; + mlir::Type elementType = array.getEleTy(); + return
hlfir::ExprType::get(builder.getContext(), resultShape, elementType, + /*polymorphic=*/false); + } + return normalisedResult; +} + +template +mlir::Value HlfirReductionIntrinsic::lowerImpl( + const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) { + auto operands = getOperandVector(loweredActuals, argLowering); + mlir::Value array = operands[0]; + mlir::Value dim = operands[1]; + // dim, mask can be NULL if these arguments are not given + if (dim) + dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); + + mlir::Type resultTy = computeResultType(array, stmtResultType); + + OP op; + if constexpr (HAS_MASK) + op = createOp(resultTy, array, dim, /*mask=*/operands[2]); + else + op = createOp(resultTy, array, dim); + return op; +} + +template +mlir::Value HlfirProductIntrinsic::lowerImpl( + const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) { + auto operands = getOperandVector(loweredActuals, argLowering); + mlir::Type resultType = computeResultType(operands[0], stmtResultType); + return createOp(resultType, operands[0], operands[1]); +} + +mlir::Value HlfirTransposeLowering::lowerImpl( + const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) { + auto operands = getOperandVector(loweredActuals, argLowering); + hlfir::ExprType::Shape resultShape; + mlir::Type normalisedResult = + hlfir::getFortranElementOrSequenceType(stmtResultType); + auto array = normalisedResult.cast(); + llvm::ArrayRef arrayShape = array.getShape(); + assert(arrayShape.size() == 2 && "arguments to transpose have a rank of 2"); + mlir::Type elementType = array.getEleTy(); + resultShape.push_back(arrayShape[0]); /* stmtResultType is the type of the call result, so its extents are copied in order. */ + resultShape.push_back(arrayShape[1]); + mlir::Type resultTy = hlfir::ExprType::get(
builder.getContext(), resultShape, elementType, /*polymorphic=*/false); + return createOp(resultTy, operands[0]); +} + +mlir::Value HlfirCountLowering::lowerImpl( + const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) { + auto operands = getOperandVector(loweredActuals, argLowering); + mlir::Value array = operands[0]; + mlir::Value dim = operands[1]; + if (dim) + dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim}); + mlir::Value kind = operands[2]; + mlir::Type resultType = computeResultType(array, stmtResultType); + return createOp(resultType, array, dim, kind); +} + +mlir::Value HlfirCharExtremumLowering::lowerImpl( + const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) { + auto operands = getOperandVector(loweredActuals, argLowering); + assert(operands.size() >= 2); + return createOp(pred, mlir::ValueRange{operands}); +} +/** Dispatches on the intrinsic name. Returns std::nullopt when no HLFIR operation is built here, so that the caller can fall back to another lowering; min/max are lowered here only for the character result type tested below. */ +std::optional Fortran::lower::lowerHlfirIntrinsic( + fir::FirOpBuilder &builder, mlir::Location loc, const std::string &name, + const Fortran::lower::PreparedActualArguments &loweredActuals, + const fir::IntrinsicArgumentLoweringRules *argLowering, + mlir::Type stmtResultType) { + if (name == "sum") + return HlfirSumLowering{builder, loc}.lower(loweredActuals, argLowering, + stmtResultType); + if (name == "product") + return HlfirProductLowering{builder, loc}.lower(loweredActuals, argLowering, + stmtResultType); + if (name == "any") + return HlfirAnyLowering{builder, loc}.lower(loweredActuals, argLowering, + stmtResultType); + if (name == "all") + return HlfirAllLowering{builder, loc}.lower(loweredActuals, argLowering, + stmtResultType); + if (name == "matmul") + return HlfirMatmulLowering{builder, loc}.lower(loweredActuals, argLowering, + stmtResultType); + if (name == "dot_product") + return HlfirDotProductLowering{builder, loc}.lower(
+ loweredActuals, argLowering, stmtResultType); + if (name == "transpose") + return HlfirTransposeLowering{builder, loc}.lower( + loweredActuals, argLowering, stmtResultType); + if (name == "count") + return HlfirCountLowering{builder, loc}.lower(loweredActuals, argLowering, + stmtResultType); + if (mlir::isa(stmtResultType)) { /* NOTE(review): the inline code this replaced in ConvertCall.cpp tested the element type via hlfir::getFortranElementType(resultType); confirm stmtResultType is never an array, boxed or reference type here. */ + if (name == "min") + return HlfirCharExtremumLowering{builder, loc, + hlfir::CharExtremumPredicate::min} + .lower(loweredActuals, argLowering, stmtResultType); + if (name == "max") + return HlfirCharExtremumLowering{builder, loc, + hlfir::CharExtremumPredicate::max} + .lower(loweredActuals, argLowering, stmtResultType); + } + return std::nullopt; +}