diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -274,6 +274,12 @@
     $array $optionals attr-dict `:` functional-type(operands, results)
   }];
 
+  // dim and mask can be NULL, array must not be.
+  let builders = [OpBuilder<(ins "mlir::Value":$array,
+                                 "mlir::Value":$dim,
+                                 "mlir::Value":$mask,
+                                 "mlir::Type":$stmtResultType)>];
+
   let hasVerifier = 1;
 }
 
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -548,7 +548,8 @@
                            const fir::ExtendedValue &exv,
                            llvm::StringRef name) {
   mlir::Value firBase = fir::getBase(exv);
-  if (fir::isa_trivial(firBase.getType()))
+  mlir::Type firBaseTy = firBase.getType();
+  if (fir::isa_trivial(firBaseTy) || firBaseTy.isa<hlfir::ExprType>())
     return hlfir::EntityWithAttributes{firBase};
   return hlfir::genDeclare(loc, builder, exv, name,
                            fir::FortranVariableFlagsAttr{});
@@ -768,6 +769,47 @@
   return resultEntity;
 }
 
+/// Lower calls to intrinsic procedures with actual arguments that have been
+/// pre-lowered but have not yet been prepared according to the interface.
+static hlfir::EntityWithAttributes genHLFIRIntrinsicRefCore(
+    PreparedActualArguments &loweredActuals,
+    const Fortran::evaluate::SpecificIntrinsic &intrinsic,
+    const Fortran::lower::IntrinsicArgumentLoweringRules *argLowering,
+    CallContext &callContext) {
+  llvm::SmallVector<mlir::Value> operands;
+  operands.reserve(loweredActuals.size());
+  for (const auto &loweredActual : loweredActuals) {
+    if (!loweredActual) {
+      operands.emplace_back();
+      continue;
+    }
+    hlfir::Entity actual = loweredActual->actual;
+    operands.emplace_back(actual.getBase());
+  }
+  fir::FirOpBuilder &builder = callContext.getBuilder();
+  mlir::Location loc = callContext.loc;
+
+  if (intrinsic.name == "sum") {
+    assert(operands.size() == 3 && "SUM expects ARRAY, DIM and MASK slots");
+    mlir::Value array = operands[0];
+    assert(array && "ARRAY argument of SUM is not optional");
+    mlir::Value dim = operands[1];
+    mlir::Value mask = operands[2];
+    // dim, mask can be NULL if these arguments were not given
+    hlfir::SumOp sumOp = builder.create<hlfir::SumOp>(loc, array, dim, mask,
+                                                      *callContext.resultType);
+    hlfir::EntityWithAttributes resultEntity = extendedValueToHlfirEntity(
+        loc, builder, sumOp.getResult(), ".tmp.intrinsic_result");
+    return resultEntity;
+  }
+
+  // TODO add hlfir operations for other transformational intrinsics here
+
+  // fallback to calling the intrinsic via fir.call
+  return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
+                             callContext);
+}
+
 namespace {
 template <typename ElementalCallBuilderImpl>
 class ElementalCallBuilder {
@@ -928,8 +970,8 @@
   std::optional<hlfir::Entity>
   genElementalKernel(PreparedActualArguments &loweredActuals,
                      CallContext &callContext) {
-    return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
-                               callContext);
+    return genHLFIRIntrinsicRefCore(loweredActuals, intrinsic, argLowering,
+                                    callContext);
   }
   // Elemental intrinsic functions cannot modify their arguments.
   bool argMayBeModifiedByCall(int) const { return !isFunction; }
@@ -1002,8 +1044,8 @@
         .genElementalCall(loweredActuals, /*isImpure=*/!isFunction, callContext)
         .value();
   }
-  hlfir::EntityWithAttributes result =
-      genIntrinsicRefCore(loweredActuals, intrinsic, argLowering, callContext);
+  hlfir::EntityWithAttributes result = genHLFIRIntrinsicRefCore(
+      loweredActuals, intrinsic, argLowering, callContext);
   if (result.getType().isa<hlfir::ExprType>()) {
     fir::FirOpBuilder *bldr = &callContext.getBuilder();
     callContext.stmtCtx.attachCleanup(
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -544,6 +544,46 @@
   return mlir::success();
 }
 
+void hlfir::SumOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
+                         mlir::Value array, mlir::Value dim, mlir::Value mask,
+                         mlir::Type stmtResultType) {
+  assert(array && "array argument is not optional");
+  mlir::Type arrayTy = unwrapType(array.getType());
+  mlir::Type numTy;
+  if (auto arrayExprTy = arrayTy.dyn_cast<hlfir::ExprType>()) {
+    numTy = arrayExprTy.getEleTy();
+  } else if (auto arraySeqTy = arrayTy.dyn_cast<fir::SequenceType>()) {
+    numTy = arraySeqTy.getEleTy();
+  } else {
+    llvm_unreachable("bad array type");
+  }
+  assert(hlfir::isFortranScalarNumericalType(numTy) && "non-numeric ARRAY");
+
+  // get the result shape from the statement context
+  hlfir::ExprType::Shape resultShape;
+  if (stmtResultType) {
+    if (auto stmtSeqTy = stmtResultType.dyn_cast<fir::SequenceType>()) {
+      assert(stmtSeqTy.getEleTy() == numTy && "Unexpected array result type");
+      assert(dim && "Non-scalar result type without DIM argument");
+      resultShape = hlfir::ExprType::Shape{stmtSeqTy.getShape()};
+    } else {
+      assert(stmtResultType == numTy && "Scalar result should match ARRAY arg");
+      // we might get a scalar result even with a DIM argument if the array
+      // has rank 1, so don't assert !dim
+    }
+  }
+  mlir::Type resultType = hlfir::ExprType::get(
+      builder.getContext(), resultShape, numTy, /*polymorphic=*/false);
+
+  llvm::SmallVector<mlir::Value> optionals;
+  for (mlir::Value arg : {dim, mask}) {
+    if (arg)
+      optionals.push_back(arg);
+  }
+
+  build(builder, result, resultType, array, optionals);
+}
+
 //===----------------------------------------------------------------------===//
 // AssociateOp
 //===----------------------------------------------------------------------===//