diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -144,6 +144,7 @@
   let results = (outs AnyFloat:$imaginary);
 
   let assemblyFormat = "$complex attr-dict `:` type($complex)";
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -185,6 +186,7 @@
   let results = (outs AnyFloat:$real);
 
   let assemblyFormat = "$complex attr-dict `:` type($complex)";
+  let hasFolder = 1;
 }
 
 
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -17,3 +17,25 @@
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
+
+// Folds `complex.re` of a constant complex value, represented as a two-element
+// ArrayAttr [real, imag], to the attribute holding the real part.
+OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary op takes 1 operand");
+  ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
+  // Only fold well-formed complex constants; indexing a shorter array would
+  // be out of bounds.
+  if (arrayAttr && arrayAttr.size() == 2)
+    return arrayAttr[0];
+  return {};
+}
+
+// Folds `complex.im` of a constant complex value, represented as a two-element
+// ArrayAttr [real, imag], to the attribute holding the imaginary part.
+OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary op takes 1 operand");
+  ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
+  if (arrayAttr && arrayAttr.size() == 2)
+    return arrayAttr[1];
+  return {};
+}
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @real_of_const(
+func @real_of_const() -> f32 {
+  // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
+  // CHECK-NEXT: return %[[CST]] : f32
+  %complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %1 = complex.re %complex : complex<f32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: func @imag_of_const(
+func @imag_of_const() -> f32 {
+  // CHECK: %[[CST:.*]] = constant 0.000000e+00 : f32
+  // CHECK-NEXT: return %[[CST]] : f32
+  %complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %1 = complex.im %complex : complex<f32>
+  return %1 : f32
+}