diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -240,6 +240,8 @@
   }];
 
   let results = (outs Complex<AnyFloat>:$result);
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -309,6 +311,8 @@
   }];
 
   let results = (outs Complex<AnyFloat>:$result);
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -138,6 +138,46 @@
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// LogOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `complex.log(complex.exp(a))` to `a`.
+///
+/// Returns the operand of the defining `complex.exp` when there is one, and
+/// a null OpFoldResult (no fold) otherwise.
+///
+/// NOTE(review): exp is periodic in the imaginary part, so this is only an
+/// exact identity when log yields the principal branch and im(a) lies in
+/// (-pi, pi] -- confirm this rewrite is acceptable for all callers.
+OpFoldResult LogOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary op takes 1 operand");
+
+  // complex.log(complex.exp(a)) -> a
+  if (auto expOp = getOperand().getDefiningOp<ExpOp>())
+    return expOp.getOperand();
+
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// ExpOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `complex.exp(complex.log(a))` to `a`.
+///
+/// Returns the operand of the defining `complex.log` when there is one, and
+/// a null OpFoldResult (no fold) otherwise.
+OpFoldResult ExpOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary op takes 1 operand");
+
+  // complex.exp(complex.log(a)) -> a
+  if (auto logOp = getOperand().getDefiningOp<LogOp>())
+    return logOp.getOperand();
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -93,4 +93,26 @@
   %neg1 = complex.neg %complex1 : complex<f32>
   %neg2 = complex.neg %neg1 : complex<f32>
   return %neg2 : complex<f32>
+}
+
+// complex.log(complex.exp(a)) must fold to a: the constant is returned as is.
+// CHECK-LABEL: func @complex_log_exp
+func.func @complex_log_exp() -> complex<f32> {
+  %complex1 = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  // CHECK: %[[CPLX:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
+  // CHECK-NEXT: return %[[CPLX]] : complex<f32>
+  %exp = complex.exp %complex1 : complex<f32>
+  %log = complex.log %exp : complex<f32>
+  return %log : complex<f32>
+}
+
+// complex.exp(complex.log(a)) must fold to a: the constant is returned as is.
+// CHECK-LABEL: func @complex_exp_log
+func.func @complex_exp_log() -> complex<f32> {
+  %complex1 = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  // CHECK: %[[CPLX:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
+  // CHECK-NEXT: return %[[CPLX]] : complex<f32>
+  %log = complex.log %complex1 : complex<f32>
+  %exp = complex.exp %log : complex<f32>
+  return %exp : complex<f32>
 }
\ No newline at end of file