diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -351,6 +351,8 @@
     %a = complex.mul %b, %c : complex<f32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -248,6 +248,38 @@
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// MulOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `complex.mul(a, complex.constant [1.0, 0.0])` to `a`.
+///
+/// Only a constant right-hand side is matched. Multiplication by a complex
+/// zero is intentionally NOT folded to zero: with IEEE-754 arithmetic
+/// `nan * 0` and `inf * 0` yield NaN and the signs of zero results depend
+/// on `a`, so `a * 0` is not always `0`.
+OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
+  auto constant = getRhs().getDefiningOp<ConstantOp>();
+  if (!constant)
+    return {};
+
+  ArrayAttr arrayAttr = constant.getValue();
+  APFloat real = llvm::cast<FloatAttr>(arrayAttr[0]).getValue();
+  APFloat imag = llvm::cast<FloatAttr>(arrayAttr[1]).getValue();
+
+  if (!imag.isZero())
+    return {};
+
+  // complex.mul(a, complex.constant<1.0, 0.0>) -> a
+  // Build `1.0` in the constant's own float semantics so every element type
+  // (f16, bf16, f32, f64, ...) is handled; APFloat equality must not mix
+  // semantics.
+  if (real == APFloat(real.getSemantics(), 1))
+    return getLhs();
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -177,3 +177,48 @@
   // CHECK-NEXT: return %[[NEG]]
   return %im : f32
 }
+
+// CHECK-LABEL: func @mul_zero
+// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> complex<f32>
+func.func @mul_zero(%arg0: f32, %arg1: f32) -> complex<f32> {
+  %create = complex.create %arg0, %arg1: complex<f32>
+  %zero = complex.constant [0.0 : f32, 0.0 : f32] : complex<f32>
+  %mul = complex.mul %create, %zero : complex<f32>
+  // Not folded: NaN/Inf components of the lhs must still propagate.
+  // CHECK: %[[MUL:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
+  // CHECK-NEXT: return %[[MUL]]
+  return %mul : complex<f32>
+}
+
+// CHECK-LABEL: func @mul_one_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16, %[[ARG1:.*]]: f16) -> complex<f16>
+func.func @mul_one_f16(%arg0: f16, %arg1: f16) -> complex<f16> {
+  %create = complex.create %arg0, %arg1: complex<f16>
+  %one = complex.constant [1.0 : f16, 0.0 : f16] : complex<f16>
+  %mul = complex.mul %create, %one : complex<f16>
+  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f16>
+  // CHECK-NEXT: return %[[CREATE]]
+  return %mul : complex<f16>
+}
+
+// CHECK-LABEL: func @mul_one_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> complex<f32>
+func.func @mul_one_f32(%arg0: f32, %arg1: f32) -> complex<f32> {
+  %create = complex.create %arg0, %arg1: complex<f32>
+  %one = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %mul = complex.mul %create, %one : complex<f32>
+  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f32>
+  // CHECK-NEXT: return %[[CREATE]]
+  return %mul : complex<f32>
+}
+
+// CHECK-LABEL: func @mul_one_f64
+// CHECK-SAME: (%[[ARG0:.*]]: f64, %[[ARG1:.*]]: f64) -> complex<f64>
+func.func @mul_one_f64(%arg0: f64, %arg1: f64) -> complex<f64> {
+  %create = complex.create %arg0, %arg1: complex<f64>
+  %one = complex.constant [1.0 : f64, 0.0 : f64] : complex<f64>
+  %mul = complex.mul %create, %one : complex<f64>
+  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f64>
+  // CHECK-NEXT: return %[[CREATE]]
+  return %mul : complex<f64>
+}