diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -35,13 +35,18 @@ auto loc = op.getLoc(); auto type = op.getType(); + arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); + Value real = rewriter.create(loc, type, adaptor.getComplex()); Value imag = rewriter.create(loc, type, adaptor.getComplex()); - Value realSqr = rewriter.create(loc, real, real); - Value imagSqr = rewriter.create(loc, imag, imag); - Value sqNorm = rewriter.create(loc, realSqr, imagSqr); + Value realSqr = + rewriter.create(loc, real, real, fmf.getValue()); + Value imagSqr = + rewriter.create(loc, imag, imag, fmf.getValue()); + Value sqNorm = + rewriter.create(loc, realSqr, imagSqr, fmf.getValue()); rewriter.replaceOpWithNewOp(op, sqNorm); return success(); diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -707,3 +707,19 @@ // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex // CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] : f32 // CHECK: return %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @complex_abs_with_fmf +// CHECK-SAME: %[[ARG:.*]]: complex +func.func @complex_abs_with_fmf(%arg: complex) -> f32 { + %abs = complex.abs %arg fastmath : complex + return %abs : f32 +} +// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath : f32 +// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath : f32 +// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath : f32 +// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK: return %[[NORM]] : f32 \ No newline at end of file