diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -549,6 +549,9 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
+  // negf(negf(x)) -> x: floating-point negation is an involution.
+  if (auto negOp = getOperand().getDefiningOp<arith::NegFOp>())
+    return negOp.getOperand();
   return constFoldUnaryOp<FloatAttr>(operands,
                                      [](const APFloat &a) { return -a; });
 }
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1320,6 +1320,15 @@
   return %0: f32
 }
 
+// CHECK-LABEL: @test_negf1(
+// CHECK-SAME: %[[arg0:.+]]:
+// CHECK: return %[[arg0]]
+func.func @test_negf1(%f : f32) -> (f32) {
+  %0 = arith.negf %f : f32
+  %1 = arith.negf %0 : f32
+  return %1: f32
+}
+
 // -----
 
 // CHECK-LABEL: @test_remui(