[mlir][arith] Canonicalize mulf/divf whose operands are both negf

Add canonicalization patterns:
  arith.mulf(arith.negf(x), arith.negf(y)) -> arith.mulf(x, y)
  arith.divf(arith.negf(x), arith.negf(y)) -> arith.divf(x, y)
The two sign flips cancel: IEEE-754 multiplication and division give a
result whose sign is the XOR of the operand signs and whose magnitude does
not depend on them, so the rewrite is exact. Only the sign bit of a NaN
result may differ, which IEEE 754 does not interpret.

diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -764,6 +764,7 @@
     math, contraction, rounding mode, and other controls.
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -773,6 +774,7 @@
 def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
   let summary = "floating point division operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticCanonicalization.td
@@ -187,4 +187,22 @@
     Pat<(Arith_OrIOp (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)), (Arith_ExtSIOp (Arith_OrIOp $x, $y)),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
+//===----------------------------------------------------------------------===//
+// MulFOp
+//===----------------------------------------------------------------------===//
+
+// mulf(negf(x), negf(y)) -> mulf(x,y)
+def MulFOfNegF :
+    Pat<(Arith_MulFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_MulFOp $x, $y),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+
+//===----------------------------------------------------------------------===//
+// DivFOp
+//===----------------------------------------------------------------------===//
+
+// divf(negf(x), negf(y)) -> divf(x,y)
+def DivFOfNegF :
+    Pat<(Arith_DivFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_DivFOp $x, $y),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+
 #endif // ARITHMETIC_PATTERNS
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -743,6 +743,11 @@
       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
 }
 
+void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<MulFOfNegF>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // DivFOp
 //===----------------------------------------------------------------------===//
@@ -756,6 +761,11 @@
       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
 }
 
+void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<DivFOfNegF>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // RemFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -948,6 +948,16 @@
   return %0, %1, %2, %3 : f32, f32, f32, f32
 }
 
+// CHECK-LABEL: @test_mulf1(
+func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) {
+  // CHECK-NEXT: %[[X:.+]] = arith.mulf %arg0, %arg1 : f32
+  // CHECK-NEXT: return %[[X]]
+  %0 = arith.negf %arg0 : f32
+  %1 = arith.negf %arg1 : f32
+  %2 = arith.mulf %0, %1 : f32
+  return %2 : f32
+}
+
 // -----
 
 // CHECK-LABEL: @test_divf(
@@ -961,6 +971,16 @@
   return %0, %1 : f64, f64
 }
 
+// CHECK-LABEL: @test_divf1(
+func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
+  // CHECK-NEXT: %[[X:.+]] = arith.divf %arg0, %arg1 : f32
+  // CHECK-NEXT: return %[[X]]
+  %0 = arith.negf %arg0 : f32
+  %1 = arith.negf %arg1 : f32
+  %2 = arith.divf %0, %1 : f32
+  return %2 : f32
+}
+
 // -----
 
 // CHECK-LABEL: @test_cmpf(