diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -351,6 +351,8 @@
     %a = complex.mul %b, %c : complex<f32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -248,6 +248,36 @@
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// MulOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `complex.mul(a, c)` to `a` when `c` is defined by a
+/// `complex.constant` whose real part is exactly 1.0 and whose imaginary part
+/// is zero (APFloat::isZero accepts both +0.0 and -0.0). Only the right-hand
+/// operand is inspected; a constant on the left-hand side is not folded.
+///
+/// NOTE(review): this treats `a * (1 + 0i) == a` as an exact identity; confirm
+/// that is acceptable for signed zeros and non-finite components of `a`.
+OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
+  auto constant = getRhs().getDefiningOp<ConstantOp>();
+  if (!constant)
+    return {};
+
+  // The constant's value is a two-element [real, imag] array of FloatAttrs.
+  ArrayAttr arrayAttr = constant.getValue();
+  APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
+  APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
+
+  if (!imag.isZero())
+    return {};
+
+  // complex.mul(a, complex.constant<1.0, 0.0>) -> a
+  if (real == APFloat(real.getSemantics(), 1))
+    return getLhs();
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -177,3 +177,58 @@
   // CHECK-NEXT: return %[[NEG]]
   return %im : f32
 }
+
+// CHECK-LABEL: func @mul_one_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16, %[[ARG1:.*]]: f16) -> complex<f16>
+func.func @mul_one_f16(%arg0: f16, %arg1: f16) -> complex<f16> {
+  %create = complex.create %arg0, %arg1: complex<f16>
+  %one = complex.constant [1.0 : f16, 0.0 : f16] : complex<f16>
+  %mul = complex.mul %create, %one : complex<f16>
+  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f16>
+  // CHECK-NEXT: return %[[CREATE]]
+  return %mul : complex<f16>
+}
+
+// CHECK-LABEL: func @mul_one_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> complex<f32>
+func.func @mul_one_f32(%arg0: f32, %arg1: f32) -> complex<f32> {
+  %create = complex.create %arg0, %arg1: complex<f32>
+  %one = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %mul = complex.mul %create, %one : complex<f32>
+  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f32>
+  // CHECK-NEXT: return %[[CREATE]]
+  return %mul : complex<f32>
+}
+
+// CHECK-LABEL: func @mul_one_f64
+// CHECK-SAME: (%[[ARG0:.*]]: f64, %[[ARG1:.*]]: f64) -> complex<f64>
+func.func @mul_one_f64(%arg0: f64, %arg1: f64) -> complex<f64> {
+  %create = complex.create %arg0, %arg1: complex<f64>
+  %one = complex.constant [1.0 : f64, 0.0 : f64] : complex<f64>
+  %mul = complex.mul %create, %one : complex<f64>
+  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f64>
+  // CHECK-NEXT: return %[[CREATE]]
+  return %mul : complex<f64>
+}
+
+// CHECK-LABEL: func @mul_one_f80
+// CHECK-SAME: (%[[ARG0:.*]]: f80, %[[ARG1:.*]]: f80) -> complex<f80>
+func.func @mul_one_f80(%arg0: f80, %arg1: f80) -> complex<f80> {
+  %create = complex.create %arg0, %arg1: complex<f80>
+  %one = complex.constant [1.0 : f80, 0.0 : f80] : complex<f80>
+  %mul = complex.mul %create, %one : complex<f80>
+  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f80>
+  // CHECK-NEXT: return %[[CREATE]]
+  return %mul : complex<f80>
+}
+
+// CHECK-LABEL: func @mul_one_f128
+// CHECK-SAME: (%[[ARG0:.*]]: f128, %[[ARG1:.*]]: f128) -> complex<f128>
+func.func @mul_one_f128(%arg0: f128, %arg1: f128) -> complex<f128> {
+  %create = complex.create %arg0, %arg1: complex<f128>
+  %one = complex.constant [1.0 : f128, 0.0 : f128] : complex<f128>
+  %mul = complex.mul %create, %one : complex<f128>
+  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f128>
+  // CHECK-NEXT: return %[[CREATE]]
+  return %mul : complex<f128>
+}