diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp --- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp @@ -73,14 +73,22 @@ void getDependentDialects(mlir::DialectRegistry &registry) const override; private: - /// Helper function to replace a reduction type of call with its + /// Helper functions to replace a reduction type of call with its /// simplified form. The actual function is generated using a callback /// function. /// \p call is the call to be replaced /// \p kindMap is used to create FIROpBuilder /// \p genBodyFunc is the callback that builds the replacement function - void simplifyReduction(fir::CallOp call, const fir::KindMapping &kindMap, - GenReductionBodyTy genBodyFunc); + void simplifyIntOrFloatReduction(fir::CallOp call, + const fir::KindMapping &kindMap, + GenReductionBodyTy genBodyFunc); + void simplifyLogicalReduction(fir::CallOp call, + const fir::KindMapping &kindMap, + GenReductionBodyTy genBodyFunc); + void simplifyReductionBody(fir::CallOp call, const fir::KindMapping &kindMap, + GenReductionBodyTy genBodyFunc, + fir::FirOpBuilder &builder, + const mlir::StringRef &basename); }; } // namespace @@ -131,17 +139,18 @@ /// Generate the reduction loop into \p funcOp. /// +/// \p elementType is the type of the elements in the input array, +/// which may be different to the return type. /// \p initVal is a function, called to get the initial value for /// the reduction value /// \p genBody is called to fill in the actual reduciton operation /// for example add for SUM, MAX for MAXVAL, etc. /// \p rank is the rank of the input argument. 
-static void genReductionLoop(fir::FirOpBuilder &builder, +static void genReductionLoop(fir::FirOpBuilder &builder, mlir::Type elementType, mlir::func::FuncOp &funcOp, InitValGeneratorTy initVal, BodyOpGeneratorTy genBody, unsigned rank) { auto loc = mlir::UnknownLoc::get(builder.getContext()); - mlir::Type elementType = funcOp.getResultTypes()[0]; builder.setInsertionPointToEnd(funcOp.addEntryBlock()); mlir::IndexType idxTy = builder.getIndexType(); @@ -156,7 +165,8 @@ mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType); mlir::Type boxArrTy = fir::BoxType::get(arrTy); mlir::Value array = builder.create(loc, boxArrTy, arg); - mlir::Value init = initVal(builder, loc, elementType); + mlir::Type resultType = funcOp.getResultTypes()[0]; + mlir::Value init = initVal(builder, loc, resultType); llvm::SmallVector bounds; @@ -265,7 +275,9 @@ return {}; }; - genReductionLoop(builder, funcOp, zero, genBodyOp, rank); + mlir::Type elementType = funcOp.getResultTypes()[0]; + + genReductionLoop(builder, elementType, funcOp, zero, genBodyOp, rank); } static void genRuntimeMaxvalBody(fir::FirOpBuilder &builder, @@ -293,7 +305,38 @@ llvm_unreachable("unsupported type"); return {}; }; - genReductionLoop(builder, funcOp, init, genBodyOp, rank); + + mlir::Type elementType = funcOp.getResultTypes()[0]; + + genReductionLoop(builder, elementType, funcOp, init, genBodyOp, rank); +} + +static void genRuntimeCountBody(fir::FirOpBuilder &builder, + mlir::func::FuncOp &funcOp, unsigned rank) { + auto zero = [](fir::FirOpBuilder builder, mlir::Location loc, + mlir::Type elementType) { + unsigned bits = elementType.getIntOrFloatBitWidth(); + int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); + return builder.createIntegerConstant(loc, elementType, zeroInt); + }; + + auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc, + mlir::Type elementType, mlir::Value elem1, + mlir::Value elem2) -> mlir::Value { + auto zero32 = builder.createIntegerConstant(loc, 
builder.getI32Type(), 0); + auto zero64 = builder.createIntegerConstant(loc, builder.getI64Type(), 0); + auto one64 = builder.createIntegerConstant(loc, builder.getI64Type(), 1); + + auto compare = builder.create( + loc, mlir::arith::CmpIPredicate::eq, elem1, zero32); + auto select = + builder.create(loc, compare, zero64, one64); + return builder.create(loc, select, elem2); + }; + + mlir::Type elementType = builder.getI32Type(); + + genReductionLoop(builder, elementType, funcOp, zero, genBodyOp, rank); } /// Generate function type for the simplified version of RTNAME(DotProduct) @@ -526,58 +569,99 @@ } while (true); } -void SimplifyIntrinsicsPass::simplifyReduction(fir::CallOp call, - const fir::KindMapping &kindMap, - GenReductionBodyTy genBodyFunc) { - mlir::SymbolRefAttr callee = call.getCalleeAttr(); - mlir::Operation::operand_range args = call.getArgs(); +void SimplifyIntrinsicsPass::simplifyIntOrFloatReduction( + fir::CallOp call, const fir::KindMapping &kindMap, + GenReductionBodyTy genBodyFunc) { // args[1] and args[2] are source filename and line number, ignored. + mlir::Operation::operand_range args = call.getArgs(); + const mlir::Value &dim = args[3]; const mlir::Value &mask = args[4]; // dim is zero when it is absent, which is an implementation // detail in the runtime library. + bool dimAndMaskAbsent = isZero(dim) && isOperandAbsent(mask); unsigned rank = getDimCount(args[0]); - if (dimAndMaskAbsent && rank > 0) { - mlir::Location loc = call.getLoc(); - fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; - std::string fmfString{getFastMathFlagsString(builder)}; - - // Support only floating point and integer results now. 
- mlir::Type resultType = call.getResult(0).getType(); - if (!resultType.isa() && - !resultType.isa()) - return; - - auto argType = getArgElementType(args[0]); - if (!argType) - return; - assert(*argType == resultType && - "Argument/result types mismatch in reduction"); - - auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) { - return genNoneBoxType(builder, resultType); - }; - auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder, - mlir::func::FuncOp &funcOp) { - genBodyFunc(builder, funcOp, rank); - }; - // Mangle the function name with the rank value as "x". - std::string funcName = - (mlir::Twine{callee.getLeafReference().getValue(), "x"} + - mlir::Twine{rank} + - // We must mangle the generated function name with FastMathFlags - // value. - (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString})) - .str(); - mlir::func::FuncOp newFunc = - getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); - auto newCall = - builder.create(loc, newFunc, mlir::ValueRange{args[0]}); - call->replaceAllUsesWith(newCall.getResults()); - call->dropAllReferences(); - call->erase(); - } + + if (!(dimAndMaskAbsent && rank > 0)) + return; + + mlir::Type resultType = call.getResult(0).getType(); + + if (!resultType.isa() && + !resultType.isa()) + return; + + auto argType = getArgElementType(args[0]); + if (!argType) + return; + assert(*argType == resultType && + "Argument/result types mismatch in reduction"); + + mlir::SymbolRefAttr callee = call.getCalleeAttr(); + + fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; + std::string fmfString{getFastMathFlagsString(builder)}; + std::string funcName = + (mlir::Twine{callee.getLeafReference().getValue(), "x"} + + mlir::Twine{rank} + + // We must mangle the generated function name with FastMathFlags + // value. + (fmfString.empty() ? 
mlir::Twine{} : mlir::Twine{"_", fmfString})) + .str(); + + simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName); +} + +void SimplifyIntrinsicsPass::simplifyLogicalReduction( + fir::CallOp call, const fir::KindMapping &kindMap, + GenReductionBodyTy genBodyFunc) { + + mlir::Operation::operand_range args = call.getArgs(); + const mlir::Value &dim = args[3]; + + if (!isZero(dim)) + return; + + unsigned rank = getDimCount(args[0]); + mlir::SymbolRefAttr callee = call.getCalleeAttr(); + + fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; + std::string funcName = + (mlir::Twine{callee.getLeafReference().getValue(), "x"} + + mlir::Twine{rank}) + .str(); + + simplifyReductionBody(call, kindMap, genBodyFunc, builder, funcName); +} + +void SimplifyIntrinsicsPass::simplifyReductionBody( + fir::CallOp call, const fir::KindMapping &kindMap, + GenReductionBodyTy genBodyFunc, fir::FirOpBuilder &builder, + const mlir::StringRef &funcName) { + + mlir::Operation::operand_range args = call.getArgs(); + + mlir::Type resultType = call.getResult(0).getType(); + unsigned rank = getDimCount(args[0]); + + mlir::Location loc = call.getLoc(); + + auto typeGenerator = [&resultType](fir::FirOpBuilder &builder) { + return genNoneBoxType(builder, resultType); + }; + auto bodyGenerator = [&rank, &genBodyFunc](fir::FirOpBuilder &builder, + mlir::func::FuncOp &funcOp) { + genBodyFunc(builder, funcOp, rank); + }; + // Mangle the function name with the rank value as "x". 
+ mlir::func::FuncOp newFunc = + getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); + auto newCall = + builder.create(loc, newFunc, mlir::ValueRange{args[0]}); + call->replaceAllUsesWith(newCall.getResults()); + call->dropAllReferences(); + call->erase(); } void SimplifyIntrinsicsPass::runOnOperation() { @@ -598,7 +682,7 @@ // int dim, const Descriptor *mask) // if (funcName.startswith(RTNAME_STRING(Sum))) { - simplifyReduction(call, kindMap, genRuntimeSumBody); + simplifyIntOrFloatReduction(call, kindMap, genRuntimeSumBody); return; } if (funcName.startswith(RTNAME_STRING(DotProduct))) { @@ -669,7 +753,11 @@ return; } if (funcName.startswith(RTNAME_STRING(Maxval))) { - simplifyReduction(call, kindMap, genRuntimeMaxvalBody); + simplifyIntOrFloatReduction(call, kindMap, genRuntimeMaxvalBody); + return; + } + if (funcName.startswith(RTNAME_STRING(Count))) { + simplifyLogicalReduction(call, kindMap, genRuntimeCountBody); return; } } diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir --- a/flang/test/Transforms/simplifyintrinsics.fir +++ b/flang/test/Transforms/simplifyintrinsics.fir @@ -1098,3 +1098,119 @@ // CHECK: arith.addf %{{.*}}, %{{.*}} fastmath : f64 // CHECK-LABEL: func.func private @_FortranASumReal8x1_fast_simplified // CHECK: arith.addf %{{.*}}, %{{.*}} fastmath : f64 + +// ----- +// Ensure count is simplified in valid case + +func.func @_QMtestPcount_generate_mask(%arg0: !fir.ref {fir.bindc_name = "a"}) -> i32 { + %0 = fir.alloca i32 {bindc_name = "count_generate_mask", uniq_name = "_QMtestFcount_generate_maskEcount_generate_mask"} + %c10 = arith.constant 10 : index + %1 = fir.alloca !fir.array<10x!fir.logical<4>> {bindc_name = "mask", uniq_name = "_QMtestFcount_generate_maskEmask"} + %2 = fir.shape %c10 : (index) -> !fir.shape<1> + %3 = fir.embox %1(%2) : (!fir.ref>>, !fir.shape<1>) -> !fir.box>> + %c0 = arith.constant 0 : index + %4 = 
fir.address_of(@_QQcl.2E2F746573746661696C2E66393000) : !fir.ref> + %c10_i32 = arith.constant 10 : i32 + %5 = fir.convert %3 : (!fir.box>>) -> !fir.box + %6 = fir.convert %4 : (!fir.ref>) -> !fir.ref + %7 = fir.convert %c0 : (index) -> i32 + %8 = fir.call @_FortranACount(%5, %6, %c10_i32, %7) fastmath : (!fir.box, !fir.ref, i32, i32) -> i64 + %9 = fir.convert %8 : (i64) -> i32 + fir.store %9 to %0 : !fir.ref + %10 = fir.load %0 : !fir.ref + return %10 : i32 +} +func.func private @_FortranACount(!fir.box, !fir.ref, i32, i32) -> i64 attributes {fir.runtime} +fir.global linkonce @_QQcl.2E2F746573746661696C2E66393000 constant : !fir.char<1,15> { + %0 = fir.string_lit "./test.f90\00"(15) : !fir.char<1,15> + fir.has_value %0 : !fir.char<1,15> +} + +// CHECK-LABEL: func.func @_QMtestPcount_generate_mask( +// CHECK-SAME: %[[A:.*]]: !fir.ref {fir.bindc_name = "a"}) -> i32 { +// CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1> +// CHECK: %[[A_BOX_LOGICAL:.*]] = fir.embox %{{.*}}(%[[SHAPE]]) : (!fir.ref>>, !fir.shape<1>) -> !fir.box>> +// CHECK: %[[A_BOX_NONE:.*]] = fir.convert %[[A_BOX_LOGICAL]] : (!fir.box>>) -> !fir.box +// CHECK-NOT: fir.call @_FortranACount({{.*}}) +// CHECK: %[[RES:.*]] = fir.call @_FortranACountx1_simplified(%[[A_BOX_NONE]]) fastmath : (!fir.box) -> i64 +// CHECK-NOT: fir.call @_FortranACount({{.*}}) +// CHECK: return %{{.*}} : i32 +// CHECK: } +// CHECK: func.func private @_FortranACount(!fir.box, !fir.ref, i32, i32) -> i64 attributes {fir.runtime} + +// CHECK-LABEL: func.func private @_FortranACountx1_simplified( +// CHECK-SAME: %[[ARR:.*]]: !fir.box) -> i64 attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[C_INDEX0:.*]] = arith.constant 0 : index +// CHECK: %[[ARR_BOX_I32:.*]] = fir.convert %[[ARR]] : (!fir.box) -> !fir.box> +// CHECK: %[[IZERO:.*]] = arith.constant 0 : i64 +// CHECK: %[[C_INDEX1:.*]] = arith.constant 1 : index +// CHECK: %[[DIMIDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims 
%[[ARR_BOX_I32]], %[[DIMIDX_0]] : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[EXTENT:.*]] = arith.subi %[[DIMS]]#1, %[[C_INDEX1]] : index +// CHECK: %[[RES:.*]] = fir.do_loop %[[ITER:.*]] = %[[C_INDEX0]] to %[[EXTENT]] step %[[C_INDEX1]] iter_args(%[[COUNT:.*]] = %[[IZERO]]) -> (i64) { +// CHECK: %[[ITEM:.*]] = fir.coordinate_of %[[ARR_BOX_I32]], %[[ITER]] : (!fir.box>, index) -> !fir.ref +// CHECK: %[[ITEM_VAL:.*]] = fir.load %[[ITEM]] : !fir.ref +// CHECK: %[[I32_0:.*]] = arith.constant 0 : i32 +// CHECK: %[[I64_0:.*]] = arith.constant 0 : i64 +// CHECK: %[[I64_1:.*]] = arith.constant 1 : i64 +// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[ITEM_VAL]], %[[I32_0]] : i32 +// CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %[[I64_0]], %[[I64_1]] : i64 +// CHECK: %[[NEW_COUNT:.*]] = arith.addi %[[SELECT]], %[[COUNT]] : i64 +// CHECK: fir.result %[[NEW_COUNT]] : i64 +// CHECK: } +// CHECK: return %[[RES:.*]] : i64 +// CHECK: } + +// ----- +// Ensure count isn't simplified when given dim argument + +func.func @_QMtestPcount_generate_mask(%arg0: !fir.ref>> {fir.bindc_name = "mask"}) -> !fir.array<10xi32> { + %0 = fir.alloca !fir.box>> + %c10 = arith.constant 10 : index + %c10_0 = arith.constant 10 : index + %c10_1 = arith.constant 10 : index + %1 = fir.alloca !fir.array<10xi32> {bindc_name = "res", uniq_name = "_QMtestFcount_generate_maskEres"} + %2 = fir.shape %c10_1 : (index) -> !fir.shape<1> + %3 = fir.array_load %1(%2) : (!fir.ref>, !fir.shape<1>) -> !fir.array<10xi32> + %c2_i32 = arith.constant 2 : i32 + %4 = fir.shape %c10, %c10_0 : (index, index) -> !fir.shape<2> + %5 = fir.embox %arg0(%4) : (!fir.ref>>, !fir.shape<2>) -> !fir.box>> + %c4 = arith.constant 4 : index + %6 = fir.zero_bits !fir.heap> + %c0 = arith.constant 0 : index + %7 = fir.shape %c0 : (index) -> !fir.shape<1> + %8 = fir.embox %6(%7) : (!fir.heap>, !fir.shape<1>) -> !fir.box>> + fir.store %8 to %0 : !fir.ref>>> + %9 = fir.address_of(@_QQcl.2E2F746573746661696C2E66393000) : !fir.ref> + 
%c11_i32 = arith.constant 11 : i32 + %10 = fir.convert %0 : (!fir.ref>>>) -> !fir.ref> + %11 = fir.convert %5 : (!fir.box>>) -> !fir.box + %12 = fir.convert %c4 : (index) -> i32 + %13 = fir.convert %9 : (!fir.ref>) -> !fir.ref + %14 = fir.call @_FortranACountDim(%10, %11, %c2_i32, %12, %13, %c11_i32) fastmath : (!fir.ref>, !fir.box, i32, i32, !fir.ref, i32) -> none + %15 = fir.load %0 : !fir.ref>>> + %c0_2 = arith.constant 0 : index + %16:3 = fir.box_dims %15, %c0_2 : (!fir.box>>, index) -> (index, index, index) + %17 = fir.box_addr %15 : (!fir.box>>) -> !fir.heap> + %18 = fir.shape_shift %16#0, %16#1 : (index, index) -> !fir.shapeshift<1> + %19 = fir.array_load %17(%18) : (!fir.heap>, !fir.shapeshift<1>) -> !fir.array + %c1 = arith.constant 1 : index + %c0_3 = arith.constant 0 : index + %20 = arith.subi %c10_1, %c1 : index + %21 = fir.do_loop %arg1 = %c0_3 to %20 step %c1 unordered iter_args(%arg2 = %3) -> (!fir.array<10xi32>) { + %23 = fir.array_fetch %19, %arg1 : (!fir.array, index) -> i32 + %24 = fir.array_update %arg2, %23, %arg1 : (!fir.array<10xi32>, i32, index) -> !fir.array<10xi32> + fir.result %24 : !fir.array<10xi32> + } + fir.array_merge_store %3, %21 to %1 : !fir.array<10xi32>, !fir.array<10xi32>, !fir.ref> + fir.freemem %17 : !fir.heap> + %22 = fir.load %1 : !fir.ref> + return %22 : !fir.array<10xi32> +} +func.func private @_FortranACountDim(!fir.ref>, !fir.box, i32, i32, !fir.ref, i32) -> none attributes {fir.runtime} + +// CHECK-LABEL: func.func @_QMtestPcount_generate_mask( +// CHECK-SAME: %[[A:.*]]: !fir.ref>> {fir.bindc_name = "mask"}) -> !fir.array<10xi32> { +// CHECK-NOT: fir.call @_FortranACountDim_simplified({{.*}}) +// CHECK: %[[RES:.*]] = fir.call @_FortranACountDim({{.*}}) fastmath : (!fir.ref>, !fir.box, i32, i32, !fir.ref, i32) -> none +// CHECK-NOT: fir.call @_FortranACountDim_simplified({{.*}})