diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp --- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp @@ -180,10 +180,12 @@ // end function RTNAME(Sum)_simplified auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType) { - return elementType.isa() - ? builder.createRealConstant(loc, elementType, - llvm::APFloat(0.0)) - : builder.createIntegerConstant(loc, elementType, 0); + if (auto ty = elementType.dyn_cast()) { + const llvm::fltSemantics &sem = ty.getFloatSemantics(); + return builder.createRealConstant(loc, elementType, + llvm::APFloat::getZero(sem)); + } + return builder.createIntegerConstant(loc, elementType, 0); }; auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, @@ -464,17 +466,22 @@ unsigned rank = getDimCount(args[0]); if (dimAndMaskAbsent && rank == 1) { mlir::Location loc = call.getLoc(); - mlir::Type type; fir::FirOpBuilder builder(call, kindMap); - if (funcName.endswith("Integer4")) { - type = mlir::IntegerType::get(builder.getContext(), 32); - } else if (funcName.endswith("Real8")) { - type = mlir::FloatType::getF64(builder.getContext()); - } else { + + // Support only floating point and integer results now. + mlir::Type resultType = call.getResult(0).getType(); + if (!resultType.isa() && + !resultType.isa()) return; - } - auto typeGenerator = [&type](fir::FirOpBuilder &builder) { - return genNoneBoxType(builder, type); + + auto argType = getArgElementType(args[0]); + if (!argType) + return; + assert(*argType == resultType && + "Argument/result types mismatch in reduction"); + + auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) { + return genNoneBoxType(builder, resultType); }; mlir::func::FuncOp newFunc = getOrCreateFunction(builder, funcName, typeGenerator, genBodyFunc); diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir --- a/flang/test/Transforms/simplifyintrinsics.fir +++ b/flang/test/Transforms/simplifyintrinsics.fir @@ -153,6 +153,65 @@ // ----- +// Call to SUM with 1D F32 is replaced. +module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.target_triple = "native"} { + func.func @sum_1d_real(%arg0: !fir.ref> {fir.bindc_name = "a"}) -> f32 { + %c10 = arith.constant 10 : index + %0 = fir.alloca f32 {bindc_name = "sum_1d_real", uniq_name = "_QFsum_1d_realEsum_1d_real"} + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = fir.embox %arg0(%1) : (!fir.ref>, !fir.shape<1>) -> !fir.box> + %3 = fir.absent !fir.box + %c0 = arith.constant 0 : index + %4 = fir.address_of(@_QQcl.2E2F6973756D5F352E66393000) : !fir.ref> + %c5_i32 = arith.constant 5 : i32 + %5 = fir.convert %2 : (!fir.box>) -> !fir.box + %6 = fir.convert %4 : (!fir.ref>) -> !fir.ref + %7 = fir.convert %c0 : (index) -> i32 + %8 = fir.convert %3 : (!fir.box) -> !fir.box + %9 = fir.call @_FortranASumReal4(%5, %6, %c5_i32, %7, %8) : (!fir.box, !fir.ref, i32, i32, !fir.box) -> f32 + fir.store %9 to %0 : !fir.ref + %10 = fir.load %0 : !fir.ref + return %10 : f32 + } + func.func private @_FortranASumReal4(!fir.box, !fir.ref, i32, i32, !fir.box) -> f32 attributes {fir.runtime} + fir.global linkonce @_QQcl.2E2F6973756D5F352E66393000 constant : !fir.char<1,13> { + %0 = fir.string_lit "./isum_5.f90\00"(13) : !fir.char<1,13> + fir.has_value %0 : !fir.char<1,13> + } +} + + +// CHECK-LABEL: func.func @sum_1d_real( +// CHECK-SAME: %[[A:.*]]: !fir.ref> {fir.bindc_name = "a"}) -> f32 { +// CHECK: %[[CINDEX_10:.*]] = arith.constant 10 : index +// CHECK: %[[SHAPE:.*]] = fir.shape %[[CINDEX_10]] : (index) -> !fir.shape<1> +// CHECK: %[[A_BOX_F32:.*]] = fir.embox %[[A]](%[[SHAPE]]) : (!fir.ref>, !fir.shape<1>) -> !fir.box> +// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_F32]] : (!fir.box>) -> !fir.box +// CHECK-NOT: fir.call @_FortranASumReal4({{.*}}) +// CHECK: %[[RES:.*]] = fir.call @_FortranASumReal4_simplified(%[[A_BOX_NONE]]) : (!fir.box) -> f32 +// CHECK-NOT: fir.call @_FortranASumReal4({{.*}}) +// CHECK: return %{{.*}} : f32 +// CHECK: } + +// CHECK-LABEL: func.func private @_FortranASumReal4_simplified( +// CHECK-SAME: %[[ARR:.*]]: !fir.box) -> f32 attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[CINDEX_0:.*]] = arith.constant 0 : index +// CHECK: %[[ARR_BOX_F32:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_F32]], %[[CINDEX_0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[CINDEX_1:.*]] = arith.constant 1 : index +// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[CINDEX_1]] : index +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[CINDEX_0]] to %[[EXTENT]] step %[[CINDEX_1]] iter_args(%[[SUM]] = %[[ZERO]]) -> (f32) { +// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_F32]], %[[ITER]] : (!fir.box>, index) -> !fir.ref +// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref +// CHECK: %[[NEW_SUM:.*]] = arith.addf %[[ITEM_VAL]], %[[SUM]] : f32 +// CHECK: fir.result %[[NEW_SUM]] : f32 +// CHECK: } +// CHECK: return %[[RES]] : f32 +// CHECK: } + +// ----- + // Call to SUM with 1D COMPLEX array is not replaced. module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.target_triple = "native"} { func.func @sum_1d_complex(%arg0: !fir.ref>> {fir.bindc_name = "a"}) -> !fir.complex<4> {