diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -290,6 +290,7 @@

   let results = (outs AnyFloat:$imaginary);
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -436,6 +437,7 @@

   let results = (outs AnyFloat:$real);
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }

 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -6,9 +6,12 @@
 //
 //===----------------------------------------------------------------------===//

+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"

 using namespace mlir;
 using namespace mlir::complex;
@@ -98,6 +101,40 @@
   return {};
 }

+namespace {
+/// Rewrites a component extraction (`OpKind`) whose operand is a `complex.neg`
+/// of a `complex.create` into an `arith.negf` of the created scalar operand:
+///   re(neg(create(a, b))) -> negf(a)   (ComponentIndex == 0)
+///   im(neg(create(a, b))) -> negf(b)   (ComponentIndex == 1)
+template <typename OpKind, int ComponentIndex>
+struct FoldComponentNeg final : OpRewritePattern<OpKind> {
+  using OpRewritePattern<OpKind>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpKind op,
+                                PatternRewriter &rewriter) const override {
+    auto negOp = op.getOperand().template getDefiningOp<NegOp>();
+    if (!negOp)
+      return failure();
+
+    auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
+    if (!createOp)
+      return failure();
+
+    Type elementType = createOp.getType().getElementType();
+    assert(isa<FloatType>(elementType) && "expected float element type");
+
+    rewriter.replaceOpWithNewOp<arith::NegFOp>(
+        op, elementType, createOp.getOperand(ComponentIndex));
+    return success();
+  }
+};
+} // namespace
+
+void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                       MLIRContext *context) {
+  results.add<FoldComponentNeg<ImOp, 1>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ReOp
 //===----------------------------------------------------------------------===//
@@ -111,6 +148,11 @@
   return {};
 }

+void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                       MLIRContext *context) {
+  results.add<FoldComponentNeg<ReOp, 0>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AddOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -155,3 +155,25 @@
   %sub = complex.sub %complex1, %complex2 : complex<f32>
   return %sub : complex<f32>
 }
+
+// CHECK-LABEL: func @re_neg
+// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+func.func @re_neg(%arg0: f32, %arg1: f32) -> f32 {
+  %create = complex.create %arg0, %arg1: complex<f32>
+  // CHECK: %[[NEG:.*]] = arith.negf %[[ARG0]]
+  %neg = complex.neg %create : complex<f32>
+  %re = complex.re %neg : complex<f32>
+  // CHECK-NEXT: return %[[NEG]]
+  return %re : f32
+}
+
+// CHECK-LABEL: func @im_neg
+// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+func.func @im_neg(%arg0: f32, %arg1: f32) -> f32 {
+  %create = complex.create %arg0, %arg1: complex<f32>
+  // CHECK: %[[NEG:.*]] = arith.negf %[[ARG1]]
+  %neg = complex.neg %create : complex<f32>
+  %im = complex.im %neg : complex<f32>
+  // CHECK-NEXT: return %[[NEG]]
+  return %im : f32
+}