diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -295,6 +295,14 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
+  // addf(x, -0.0) -> x. Folding on +0.0 would be wrong: (-0.0) + (+0.0) is
+  // +0.0, not -0.0, so only a negative-zero rhs is an identity for every x.
+  if (FloatAttr rhs = operands[1].dyn_cast_or_null<FloatAttr>()) {
+    const APFloat rhsValue = rhs.getValue();
+    if (rhsValue.isZero() && rhsValue.isNegative())
+      return lhs();
+  }
+
   return constFoldBinaryOp<FloatAttr>(
       operands, [](APFloat a, APFloat b) { return a + b; });
 }
@@ -2756,6 +2764,15 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
+  // subf(x, +0.0) -> x. A -0.0 rhs is not an identity: (-0.0) - (-0.0) is
+  // +0.0. subf(x, x) is deliberately not folded to 0 because the result is
+  // NaN when x is NaN or an infinity.
+  if (FloatAttr rhs = operands[1].dyn_cast_or_null<FloatAttr>()) {
+    const APFloat rhsValue = rhs.getValue();
+    if (rhsValue.isZero() && !rhsValue.isNegative())
+      return lhs();
+  }
+
   return constFoldBinaryOp<FloatAttr>(
       operands, [](APFloat a, APFloat b) { return a - b; });
 }
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -252,3 +252,50 @@
   %res = subtensor_insert %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor into tensor<4x6x16x32xi8>
   return %res : tensor<4x6x16x32xi8>
 }
+
+// -----
+
+// CHECK-LABEL: func @addf_zero
+// CHECK-SAME: %[[ARG:.[a-z0-9A-Z_]+]]: f32
+// CHECK: return %[[ARG]] : f32
+func @addf_zero(%arg0: f32) -> f32 {
+  %cst = constant -0.0 : f32
+  %0 = addf %arg0, %cst : f32
+  return %0 : f32
+}
+
+// -----
+
+// addf(x, +0.0) must not fold to x: (-0.0) + (+0.0) is +0.0.
+// CHECK-LABEL: func @addf_pos_zero
+// CHECK-SAME: %[[ARG:.[a-z0-9A-Z_]+]]: f32
+// CHECK: %[[RES:.+]] = addf %[[ARG]], %{{.+}} : f32
+// CHECK: return %[[RES]] : f32
+func @addf_pos_zero(%arg0: f32) -> f32 {
+  %cst = constant 0.0 : f32
+  %0 = addf %arg0, %cst : f32
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @subf_zero
+// CHECK-SAME: %[[ARG:.[a-z0-9A-Z_]+]]: f32
+// CHECK: return %[[ARG]] : f32
+func @subf_zero(%arg0: f32) -> f32 {
+  %cst = constant 0.0 : f32
+  %0 = subf %arg0, %cst : f32
+  return %0 : f32
+}
+
+// -----
+
+// subf(x, x) must not fold to 0: the result is NaN for NaN or infinite x.
+// CHECK-LABEL: func @subf_same
+// CHECK-SAME: %[[ARG:.[a-z0-9A-Z_]+]]: f32
+// CHECK: %[[RES:.+]] = subf %[[ARG]], %[[ARG]] : f32
+// CHECK: return %[[RES]] : f32
+func @subf_same(%arg0: f32) -> f32 {
+  %0 = subf %arg0, %arg0 : f32
+  return %0 : f32
+}