NOTE(review): this patch was reconstructed from a whitespace-collapsed copy in
which every `<...>` span had been stripped. Template arguments were restored
from the surrounding names; the `complex<f32>` element type in the tests is
assumed -- confirm against the original change before applying.

diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -88,7 +88,7 @@
 
   let summary = "complex number creation operation";
   let description = [{
-    The `complex.complex` operation creates a complex number from two
+    The `complex.create` operation creates a complex number from two
     floating-point operands, the real and the imaginary part.
 
     Example:
@@ -102,6 +102,7 @@
   let results = (outs Complex:$complex);
 
   let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)";
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -19,22 +19,35 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
 
-OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary op takes two operands");
+  // Fold complex.create(complex.re(op), complex.im(op)).
+  if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
+    if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
+      if (reOp.getOperand() == imOp.getOperand()) {
+        return reOp.getOperand();
+      }
+    }
+  }
+  return {};
+}
+
+OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1 && "unary op takes 1 operand");
   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
   if (arrayAttr && arrayAttr.size() == 2)
-    return arrayAttr[0];
+    return arrayAttr[1];
   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
-    return createOp.getOperand(0);
+    return createOp.getOperand(1);
   return {};
 }
 
-OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 1 && "unary op takes 1 operand");
   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
   if (arrayAttr && arrayAttr.size() == 2)
-    return arrayAttr[1];
+    return arrayAttr[0];
   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
-    return createOp.getOperand(1);
+    return createOp.getOperand(0);
   return {};
 }
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -1,5 +1,28 @@
 // RUN: mlir-opt %s -canonicalize | FileCheck %s
 
+// CHECK-LABEL: func @create_of_real_and_imag
+// CHECK-SAME: (%[[CPLX:.*]]: complex<f32>)
+func @create_of_real_and_imag(%cplx: complex<f32>) -> complex<f32> {
+  // CHECK-NEXT: return %[[CPLX]] : complex<f32>
+  %real = complex.re %cplx : complex<f32>
+  %imag = complex.im %cplx : complex<f32>
+  %complex = complex.create %real, %imag : complex<f32>
+  return %complex : complex<f32>
+}
+
+// CHECK-LABEL: func @create_of_real_and_imag_different_operand
+// CHECK-SAME: (%[[CPLX:.*]]: complex<f32>, %[[CPLX2:.*]]: complex<f32>)
+func @create_of_real_and_imag_different_operand(
+    %cplx: complex<f32>, %cplx2 : complex<f32>) -> complex<f32> {
+  // CHECK-NEXT: %[[REAL:.*]] = complex.re %[[CPLX]] : complex<f32>
+  // CHECK-NEXT: %[[IMAG:.*]] = complex.im %[[CPLX2]] : complex<f32>
+  // CHECK-NEXT: %[[COMPLEX:.*]] = complex.create %[[REAL]], %[[IMAG]] : complex<f32>
+  %real = complex.re %cplx : complex<f32>
+  %imag = complex.im %cplx2 : complex<f32>
+  %complex = complex.create %real, %imag : complex<f32>
+  return %complex: complex<f32>
+}
+
 // CHECK-LABEL: func @real_of_const(
 func @real_of_const() -> f32 {
   // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32