diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp --- a/flang/lib/Evaluate/fold-integer.cpp +++ b/flang/lib/Evaluate/fold-integer.cpp @@ -237,8 +237,9 @@ } // COUNT() -template +template static Expr FoldCount(FoldingContext &context, FunctionRef &&ref) { + using LogicalResult = Type; static_assert(T::category == TypeCategory::Integer); ActualArguments &arg{ref.arguments()}; if (const Constant *mask{arg.empty() @@ -546,7 +547,18 @@ cx->u); } } else if (name == "count") { - return FoldCount(context, std::move(funcRef)); + int maskKind = args[0]->GetType()->kind(); + switch (maskKind) { + SWITCH_COVERS_ALL_CASES + case 1: + return FoldCount(context, std::move(funcRef)); + case 2: + return FoldCount(context, std::move(funcRef)); + case 4: + return FoldCount(context, std::move(funcRef)); + case 8: + return FoldCount(context, std::move(funcRef)); + } } else if (name == "digits") { if (const auto *cx{UnwrapExpr>(args[0])}) { return Expr{common::visit( diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp --- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp @@ -662,7 +662,7 @@ auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType, mlir::Value elem1, mlir::Value elem2) -> mlir::Value { - auto zero32 = builder.createIntegerConstant(loc, builder.getI32Type(), 0); + auto zero32 = builder.createIntegerConstant(loc, elementType, 0); auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0); auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1); diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir --- a/flang/test/Transforms/simplifyintrinsics.fir +++ b/flang/test/Transforms/simplifyintrinsics.fir @@ -1161,6 +1161,54 @@ // CHECK: return %[[RES:.*]] : i64 // CHECK: } +// ----- +// Ensure count is properly simplified for different mask kind + +func.func @_QPdiffkind(%arg0: !fir.ref>> {fir.bindc_name = "mask"}) -> i32 { + %0 = fir.alloca i32 {bindc_name = "diffkind", uniq_name = "_QFdiffkindEdiffkind"} + %c10 = arith.constant 10 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = fir.embox %arg0(%1) : (!fir.ref>>, !fir.shape<1>) -> !fir.box>> + %c0 = arith.constant 0 : index + %3 = fir.address_of(@_QQcl.916d74b25894ddf7881ff7f913a677f5) : !fir.ref> + %c5_i32 = arith.constant 5 : i32 + %4 = fir.convert %2 : (!fir.box>>) -> !fir.box + %5 = fir.convert %3 : (!fir.ref>) -> !fir.ref + %6 = fir.convert %c0 : (index) -> i32 + %7 = fir.call @_FortranACount(%4, %5, %c5_i32, %6) fastmath : (!fir.box, !fir.ref, i32, i32) -> i64 + %8 = fir.convert %7 : (i64) -> i32 + fir.store %8 to %0 : !fir.ref + %9 = fir.load %0 : !fir.ref + return %9 : i32 +} + +// CHECK-LABEL: func.func @_QPdiffkind( +// CHECK-SAME: %[[A:.*]]: !fir.ref>> {fir.bindc_name = "mask"}) -> i32 { +// CHECK: %[[res:.*]] = fir.call @_FortranACountLogical2x1_simplified({{.*}}) fastmath : (!fir.box) -> i64 + +// CHECK-LABEL: func.func private @_FortranACountLogical2x1_simplified( +// CHECK-SAME: %[[ARR:.*]]: !fir.box) -> i64 attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[C_INDEX0:.*]] = arith.constant 0 : index +// CHECK: %[[ARR_BOX_I16:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> +// CHECK: %[[IZERO:.*]] = arith.constant 0 : i64 +// CHECK: %[[C_INDEX1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARR_BOX_I16]], %[[DIMIDX_0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[C_INDEX1]] : index +// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[C_INDEX0]] to %[[EXTENT]] step %[[C_INDEX1]] iter_args(%[[COUNT:.*]] = %[[IZERO]]) -> (i64) { +// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I16]], %[[ITER]] : (!fir.box>, index) -> !fir.ref +// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref +// CHECK: %[[I16_0:.*]] = arith.constant 0 : i16 +// CHECK: %[[I64_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[I64_1:.*]] = arith.constant 1 : i64 +// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[ITEM_VAL]], %[[I16_0]] : i16 +// CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %[[I64_0]], %[[I64_1]] : i64 +// CHECK: %[[NEW_COUNT:.*]] = arith.addi %[[SELECT]], %[[COUNT]] : i64 +// CHECK: fir.result %[[NEW_COUNT]] : i64 +// CHECK: } +// CHECK: return %[[RES:.*]] : i64 +// CHECK: } + // ----- // Ensure count isn't simplified when given dim argument