diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -419,6 +419,9 @@ /// config. void setFastMathFlags(Fortran::common::MathOptionsBase options); + /// Get current FastMathFlags value. + mlir::arith::FastMathFlags getFastMathFlags() const { return fastMathFlags; } + /// Dump the current function. (debug) LLVM_DUMP_METHOD void dumpFunc(); diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp --- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp @@ -85,6 +85,35 @@ } // namespace +/// Create FirOpBuilder with the provided \p op insertion point +/// and \p kindMap additionally inheriting FastMathFlags from \p op. +static fir::FirOpBuilder +getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) { + fir::FirOpBuilder builder{op, kindMap}; + auto fmi = mlir::dyn_cast(*op); + if (!fmi) + return builder; + + // Regardless of what default FastMathFlags are used by FirOpBuilder, + // override them with FastMathFlags attached to the operation. + builder.setFastMathFlags(fmi.getFastMathFlagsAttr().getValue()); + return builder; +} + +/// Stringify FastMathFlags set for the given \p builder in a way +/// that the string may be used for mangling a function name. +/// If FastMathFlags are set to 'none', then the result is an empty +/// string. 
+static std::string getFastMathFlagsString(const fir::FirOpBuilder &builder) { + mlir::arith::FastMathFlags flags = builder.getFastMathFlags(); + if (flags == mlir::arith::FastMathFlags::none) + return {}; + + std::string fmfString{mlir::arith::stringifyFastMathFlags(flags)}; + std::replace(fmfString.begin(), fmfString.end(), ',', '_'); + return fmfString; +} + /// Generate function type for the simplified version of RTNAME(Sum) and /// similar functions with a fir.box type returning \p elementType. static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder, @@ -511,7 +540,8 @@ unsigned rank = getDimCount(args[0]); if (dimAndMaskAbsent && rank > 0) { mlir::Location loc = call.getLoc(); - fir::FirOpBuilder builder(call, kindMap); + fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; + std::string fmfString{getFastMathFlagsString(builder)}; // Support only floating point and integer results now. mlir::Type resultType = call.getResult(0).getType(); @@ -535,7 +565,10 @@ // Mangle the function name with the rank value as "x". std::string funcName = (mlir::Twine{callee.getLeafReference().getValue(), "x"} + - mlir::Twine{rank}) + mlir::Twine{rank} + + // We must mangle the generated function name with FastMathFlags + // value. + (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString})) .str(); mlir::func::FuncOp newFunc = getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); @@ -576,7 +609,10 @@ const mlir::Value &v1 = args[0]; const mlir::Value &v2 = args[1]; mlir::Location loc = call.getLoc(); - fir::FirOpBuilder builder(op, kindMap); + fir::FirOpBuilder builder{getSimplificationBuilder(op, kindMap)}; + // Stringify the builder's FastMathFlags for mangling + // the generated function name. + std::string fmfString{getFastMathFlagsString(builder)}; mlir::Type type = call.getResult(0).getType(); if (!type.isa() && !type.isa()) @@ -611,9 +647,13 @@ // of the arguments. 
std::string typedFuncName(funcName); llvm::raw_string_ostream nameOS(typedFuncName); - nameOS << "_"; + // We must mangle the generated function name with FastMathFlags + // value. + if (!fmfString.empty()) + nameOS << '_' << fmfString; + nameOS << '_'; arg1Type->print(nameOS); - nameOS << "_"; + nameOS << '_'; arg2Type->print(nameOS); mlir::func::FuncOp newFunc = getOrCreateFunction( diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir --- a/flang/test/Transforms/simplifyintrinsics.fir +++ b/flang/test/Transforms/simplifyintrinsics.fir @@ -998,3 +998,103 @@ // CHECK-NOT: call{{.*}}_FortranASumInteger8( // CHECK: call @_FortranASumInteger8x2_simplified( // CHECK-NOT: call{{.*}}_FortranASumInteger8( + +// ----- + +func.func @dot_f32_contract_reassoc(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.box> {fir.bindc_name = "b"}) -> f32 { + %0 = fir.alloca f32 {bindc_name = "dot", uniq_name = "_QFdotEdot"} + %1 = fir.address_of(@_QQcl.2E2F646F742E66393000) : !fir.ref> + %c3_i32 = arith.constant 3 : i32 + %2 = fir.convert %arg0 : (!fir.box>) -> !fir.box + %3 = fir.convert %arg1 : (!fir.box>) -> !fir.box + %4 = fir.convert %1 : (!fir.ref>) -> !fir.ref + %5 = fir.call @_FortranADotProductReal4(%2, %3, %4, %c3_i32) fastmath : (!fir.box, !fir.box, !fir.ref, i32) -> f32 + fir.store %5 to %0 : !fir.ref + %6 = fir.load %0 : !fir.ref + return %6 : f32 +} + +func.func @dot_f32_fast(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.box> {fir.bindc_name = "b"}) -> f32 { + %0 = fir.alloca f32 {bindc_name = "dot", uniq_name = "_QFdotEdot"} + %1 = fir.address_of(@_QQcl.2E2F646F742E66393000) : !fir.ref> + %c3_i32 = arith.constant 3 : i32 + %2 = fir.convert %arg0 : (!fir.box>) -> !fir.box + %3 = fir.convert %arg1 : (!fir.box>) -> !fir.box + %4 = fir.convert %1 : (!fir.ref>) -> !fir.ref + %5 = fir.call @_FortranADotProductReal4(%2, %3, %4, %c3_i32) fastmath : (!fir.box, !fir.box, !fir.ref, i32) -> f32 + fir.store %5 to %0 : 
!fir.ref + %6 = fir.load %0 : !fir.ref + return %6 : f32 +} + +func.func private @_FortranADotProductReal4(!fir.box, !fir.box, !fir.ref, i32) -> f32 attributes {fir.runtime} +fir.global linkonce @_QQcl.2E2F646F742E66393000 constant : !fir.char<1,10> { + %0 = fir.string_lit "./dot.f90\00"(10) : !fir.char<1,10> + fir.has_value %0 : !fir.char<1,10> +} + +// CHECK-LABEL: @dot_f32_contract_reassoc +// CHECK: fir.call @_FortranADotProductReal4_reassoc_contract_f32_f32_simplified(%2, %3) fastmath +// CHECK-LABEL: @dot_f32_fast +// CHECK: fir.call @_FortranADotProductReal4_fast_f32_f32_simplified(%2, %3) fastmath +// CHECK-LABEL: func.func private @_FortranADotProductReal4_reassoc_contract_f32_f32_simplified +// CHECK: arith.mulf %{{.*}}, %{{.*}} fastmath : f32 +// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath : f32 +// CHECK-LABEL: func.func private @_FortranADotProductReal4_fast_f32_f32_simplified +// CHECK: arith.mulf %{{.*}}, %{{.*}} fastmath : f32 +// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath : f32 + +// ----- + +func.func @sum_1d_real_contract_reassoc(%arg0: !fir.ref> {fir.bindc_name = "a"}) -> f64 { + %c10 = arith.constant 10 : index + %0 = fir.alloca f64 {bindc_name = "sum_1d_real", uniq_name = "_QFsum_1d_realEsum_1d_real"} + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = fir.embox %arg0(%1) : (!fir.ref>, !fir.shape<1>) -> !fir.box> + %3 = fir.absent !fir.box + %c0 = arith.constant 0 : index + %4 = fir.address_of(@_QQcl.2E2F6973756D5F352E66393000) : !fir.ref> + %c5_i32 = arith.constant 5 : i32 + %5 = fir.convert %2 : (!fir.box>) -> !fir.box + %6 = fir.convert %4 : (!fir.ref>) -> !fir.ref + %7 = fir.convert %c0 : (index) -> i32 + %8 = fir.convert %3 : (!fir.box) -> !fir.box + %9 = fir.call @_FortranASumReal8(%5, %6, %c5_i32, %7, %8) fastmath : (!fir.box, !fir.ref, i32, i32, !fir.box) -> f64 + fir.store %9 to %0 : !fir.ref + %10 = fir.load %0 : !fir.ref + return %10 : f64 +} + +func.func @sum_1d_real_fast(%arg0: !fir.ref> {fir.bindc_name = "a"}) -> f64 { + 
%c10 = arith.constant 10 : index + %0 = fir.alloca f64 {bindc_name = "sum_1d_real", uniq_name = "_QFsum_1d_realEsum_1d_real"} + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = fir.embox %arg0(%1) : (!fir.ref>, !fir.shape<1>) -> !fir.box> + %3 = fir.absent !fir.box + %c0 = arith.constant 0 : index + %4 = fir.address_of(@_QQcl.2E2F6973756D5F352E66393000) : !fir.ref> + %c5_i32 = arith.constant 5 : i32 + %5 = fir.convert %2 : (!fir.box>) -> !fir.box + %6 = fir.convert %4 : (!fir.ref>) -> !fir.ref + %7 = fir.convert %c0 : (index) -> i32 + %8 = fir.convert %3 : (!fir.box) -> !fir.box + %9 = fir.call @_FortranASumReal8(%5, %6, %c5_i32, %7, %8) fastmath : (!fir.box, !fir.ref, i32, i32, !fir.box) -> f64 + fir.store %9 to %0 : !fir.ref + %10 = fir.load %0 : !fir.ref + return %10 : f64 +} + +func.func private @_FortranASumReal8(!fir.box, !fir.ref, i32, i32, !fir.box) -> f64 attributes {fir.runtime} +fir.global linkonce @_QQcl.2E2F6973756D5F352E66393000 constant : !fir.char<1,13> { + %0 = fir.string_lit "./isum_5.f90\00"(13) : !fir.char<1,13> + fir.has_value %0 : !fir.char<1,13> +} + +// CHECK-LABEL: @sum_1d_real_contract_reassoc +// CHECK: fir.call @_FortranASumReal8x1_reassoc_contract_simplified(%5) fastmath +// CHECK-LABEL: @sum_1d_real_fast +// CHECK: fir.call @_FortranASumReal8x1_fast_simplified(%5) fastmath +// CHECK-LABEL: func.func private @_FortranASumReal8x1_reassoc_contract_simplified +// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath : f64 +// CHECK-LABEL: func.func private @_FortranASumReal8x1_fast_simplified +// CHECK: arith.addf %{{.*}}, %{{.*}} fastmath : f64