diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -15,6 +15,7 @@
 #include "flang/Optimizer/Builder/Character.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
+#include "flang/Optimizer/Builder/IntrinsicCall.h"
 #include "flang/Optimizer/Builder/MutableBox.h"
 #include "flang/Optimizer/Builder/Runtime/Assign.h"
 #include "flang/Optimizer/Builder/Todo.h"
@@ -498,6 +499,185 @@
   }
 };
 
+/// Base class for patterns converting transformational intrinsic operations
+/// into runtime calls.
+///
+/// \tparam OP the HLFIR operation rewritten by the derived pattern (for
+///         instance hlfir::SumOp).
+template <class OP>
+class HlfirIntrinsicConversion : public mlir::OpConversionPattern<OP> {
+  using mlir::OpConversionPattern<OP>::OpConversionPattern;
+
+protected:
+  /// One actual argument of the intrinsic together with the type it is
+  /// converted to for the runtime call.
+  struct IntrinsicArgument {
+    mlir::Value val; // allowed to be null if the argument is absent
+    mlir::Type desiredType;
+  };
+
+  /// Lower the arguments to the intrinsic: adding necessary boxing and
+  /// conversion to match the signature of the intrinsic in the runtime library.
+  ///
+  /// \param op the operation being rewritten; only its location is used.
+  /// \param args the intrinsic arguments in positional order: args[i] is
+  ///        lowered with rule i of \p argLowering. Null values are absent.
+  /// \param rewriter used to build the FirOpBuilder creating new operations.
+  /// \param argLowering the argument lowering rules of the intrinsic; must
+  ///        not be null.
+  /// \return one fir::ExtendedValue per entry of \p args; absent arguments
+  ///         are represented by fir::getAbsentIntrinsicArgument().
+  llvm::SmallVector<fir::ExtendedValue, 3>
+  lowerArguments(mlir::Operation *op, llvm::ArrayRef<IntrinsicArgument> args,
+                 mlir::ConversionPatternRewriter &rewriter,
+                 const fir::IntrinsicArgumentLoweringRules *argLowering) const {
+    assert(args.size() == 3 && "Transformational intrinsics have 3 args");
+    assert(argLowering && "intrinsic argument lowering rules are required");
+    mlir::Location loc = op->getLoc();
+    fir::KindMapping kindMapping{rewriter.getContext()};
+    fir::FirOpBuilder builder{rewriter, kindMapping};
+
+    llvm::SmallVector<fir::ExtendedValue, 3> ret;
+
+    for (size_t i = 0; i < args.size(); ++i) {
+      mlir::Value arg = args[i].val;
+      mlir::Type desiredType = args[i].desiredType;
+      if (!arg) {
+        ret.emplace_back(fir::getAbsentIntrinsicArgument());
+        continue;
+      }
+      hlfir::Entity entity{arg};
+
+      fir::ArgLoweringRule argRules =
+          fir::lowerIntrinsicArgumentAs(*argLowering, i);
+      if (argRules.handleDynamicOptional)
+        TODO(loc, "handleDynamicOptional");
+      switch (argRules.lowerAs) {
+      case fir::LowerIntrinsicArgAs::Value: {
+        if (desiredType != arg.getType()) {
+          arg = builder.createConvert(loc, desiredType, arg);
+          entity = hlfir::Entity{arg};
+        }
+        auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity);
+        if (cleanup)
+          TODO(loc, "extended value cleanup");
+        ret.emplace_back(exv);
+      } break;
+      case fir::LowerIntrinsicArgAs::Addr: {
+        auto [exv, cleanup] =
+            hlfir::convertToAddress(loc, builder, entity, desiredType);
+        if (cleanup)
+          TODO(loc, "extended value cleanup");
+        ret.emplace_back(exv);
+      } break;
+      case fir::LowerIntrinsicArgAs::Box: {
+        auto [box, cleanup] =
+            hlfir::convertToBox(loc, builder, entity, desiredType);
+        if (cleanup)
+          TODO(loc, "extended value cleanup");
+        ret.emplace_back(box);
+      } break;
+      case fir::LowerIntrinsicArgAs::Inquired: {
+        if (desiredType != arg.getType()) {
+          arg = builder.createConvert(loc, desiredType, arg);
+          entity = hlfir::Entity{arg};
+        }
+        // Place hlfir.expr in memory, and unbox fir.boxchar. Other entities
+        // are translated to fir::ExtendedValue without transformation
+        // (notably, pointers/allocatable are not dereferenced).
+        // TODO: once lowering to FIR retires, UBOUND and LBOUND can be
+        // simplified since the fir.box lowered here are now guaranteed to
+        // contain the local lower bounds thanks to the hlfir.declare (the extra
+        // rebox can be removed).
+        auto [exv, cleanup] =
+            hlfir::translateToExtendedValue(loc, builder, entity);
+        if (cleanup)
+          TODO(loc, "extended value cleanup");
+        ret.emplace_back(exv);
+      } break;
+      }
+    }
+
+    return ret;
+  }
+
+  /// Replace \p op with the value produced by the runtime call.
+  ///
+  /// A trivial result (see fir::isa_trivial) replaces \p op directly. Any
+  /// other result is declared as a temporary and wrapped in an hlfir.as_expr
+  /// whose result replaces \p op.
+  ///
+  /// \param resultExv the value returned by fir::genIntrinsicCall.
+  /// \param mustBeFreed passed to hlfir.as_expr (via builder.createBool).
+  void processReturnValue(mlir::Operation *op,
+                          const fir::ExtendedValue &resultExv, bool mustBeFreed,
+                          fir::FirOpBuilder &builder,
+                          mlir::PatternRewriter &rewriter) const {
+    mlir::Location loc = op->getLoc();
+    mlir::Value firBase = fir::getBase(resultExv);
+
+    if (fir::isa_trivial(firBase.getType())) {
+      rewriter.replaceOp(op, firBase);
+      return;
+    }
+
+    // hlfir.declare yields a variable, never a trivial type, so the declared
+    // temporary always needs to be wrapped in an hlfir.as_expr.
+    hlfir::EntityWithAttributes resultEntity{
+        hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result",
+                          fir::FortranVariableFlagsAttr{})};
+    hlfir::AsExprOp asExpr = builder.create<hlfir::AsExprOp>(
+        loc, resultEntity, builder.createBool(loc, mustBeFreed));
+    rewriter.replaceOp(op, asExpr.getResult());
+  }
+};
+
+/// Convert hlfir.sum into a call to the SUM intrinsic runtime implementation.
+struct SumOpConversion : public HlfirIntrinsicConversion<hlfir::SumOp> {
+  using HlfirIntrinsicConversion<hlfir::SumOp>::HlfirIntrinsicConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::SumOp sum, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    fir::KindMapping kindMapping{rewriter.getContext()};
+    fir::FirOpBuilder builder{rewriter, kindMapping};
+    mlir::Location loc = sum->getLoc();
+    HLFIRListener listener{builder, rewriter};
+    builder.setListener(&listener);
+
+    auto oldInsertPoint = builder.saveInsertionPoint();
+    builder.setInsertionPoint(sum);
+
+    mlir::Type i32 = builder.getI32Type();
+    constexpr int DEFAULT_LOGICAL_KIND = 4;
+    mlir::Type logicalType =
+        fir::LogicalType::get(builder.getContext(), DEFAULT_LOGICAL_KIND);
+
+    // Order matters: lowerArguments lowers inArgs[i] with the i-th argument
+    // lowering rule of SUM; keep the ARRAY, DIM, MASK order.
+    llvm::SmallVector<IntrinsicArgument, 3> inArgs;
+    inArgs.push_back({sum.getArray(), sum.getArray().getType()});
+    inArgs.push_back({sum.getDim(), i32});
+    inArgs.push_back({sum.getMask(), logicalType});
+
+    const fir::IntrinsicArgumentLoweringRules *argLowering =
+        fir::getIntrinsicArgumentLowering("sum");
+    llvm::SmallVector<fir::ExtendedValue, 3> args =
+        lowerArguments(sum, inArgs, rewriter, argLowering);
+
+    // genIntrinsicCall is given the element type of the op result.
+    mlir::Type exprType = sum->getResult(0).getType();
+    mlir::Type scalarResultType = hlfir::getFortranElementType(exprType);
+
+    auto [resultExv, mustBeFreed] =
+        fir::genIntrinsicCall(builder, loc, "sum", scalarResultType, args);
+
+    processReturnValue(sum, resultExv, mustBeFreed, builder, rewriter);
+    builder.restoreInsertionPoint(oldInsertPoint);
+    return mlir::success();
+  }
+};
+
 class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
 public:
   void runOnOperation() override {
@@ -515,7 +673,8 @@ .insert(context); + NoReassocOpConversion, SetLengthOpConversion, SumOpConversion>( + context); mlir::ConversionTarget target(*context); target.addIllegalOp