diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -280,6 +280,26 @@
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SignOp
+//===----------------------------------------------------------------------===//
+
+def SignOp : ComplexUnaryOp<"sign", [SameOperandsAndResultType]> {
+  let summary = "computes sign of a complex number";
+  let description = [{
+    The `sign` op takes a single complex number and computes the sign of
+    it, i.e. `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`.
+
+    Example:
+
+    ```mlir
+    %a = complex.sign %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // SubOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -335,15 +336,51 @@
     return success();
   }
 };
+
+// Lowers complex.sign to sign(z) = z / |z|. For z == 0 the division would
+// compute 0 / 0 = NaN, so the (zero) input is selected instead. Note that an
+// infinite component still yields inf / inf = NaN.
+struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
+  using OpConversionPattern<complex::SignOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::SignOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::SignOp::Adaptor transformed(operands);
+    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    Value real = b.create<complex::ReOp>(elementType, transformed.complex());
+    Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
+    Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
+    Value realIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, real, zero);
+    Value imagIsZero = b.create<CmpFOp>(CmpFPredicate::OEQ, imag, zero);
+    Value isZero = b.create<AndOp>(realIsZero, imagIsZero);
+    // |z| is emitted as complex.abs and lowered by AbsOpConversion.
+    Value abs = b.create<complex::AbsOp>(elementType, transformed.complex());
+    Value realSign = b.create<DivFOp>(real, abs);
+    Value imagSign = b.create<DivFOp>(imag, abs);
+    Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
+    rewriter.replaceOpWithNewOp<SelectOp>(op, isZero, transformed.complex(),
+                                          sign);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<AbsOpConversion,
-               ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
-               ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
-               DivOpConversion, ExpOpConversion, NegOpConversion>(
-      patterns.getContext());
+  // clang-format off
+  patterns.add<
+      AbsOpConversion,
+      ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
+      ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
+      DivOpConversion,
+      ExpOpConversion,
+      NegOpConversion,
+      SignOpConversion>(patterns.getContext());
+  // clang-format on
 }
 
 namespace {
@@ -363,7 +400,8 @@
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
-                      complex::ExpOp, complex::NotEqualOp, complex::NegOp>();
+                      complex::ExpOp, complex::NotEqualOp, complex::NegOp,
+                      complex::SignOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -181,3 +181,27 @@
 // CHECK-DAG: %[[IMAG_NOT_EQUAL:.*]] = cmpf une, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
 // CHECK: %[[NOT_EQUAL:.*]] = or %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
 // CHECK: return %[[NOT_EQUAL]] : i1
+
+// CHECK-LABEL: func @complex_sign
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_sign(%arg: complex<f32>) -> complex<f32> {
+  %sign = complex.sign %arg : complex<f32>
+  return %sign : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ZERO:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[REAL_IS_ZERO:.*]] = cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_IS_ZERO:.*]] = cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IS_ZERO:.*]] = and %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
+// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[SQR_REAL:.*]] = mulf %[[REAL2]], %[[REAL2]] : f32
+// CHECK: %[[SQR_IMAG:.*]] = mulf %[[IMAG2]], %[[IMAG2]] : f32
+// CHECK: %[[SQ_NORM:.*]] = addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
+// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[REAL_SIGN:.*]] = divf %[[REAL]], %[[NORM]] : f32
+// CHECK: %[[IMAG_SIGN:.*]] = divf %[[IMAG]], %[[NORM]] : f32
+// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
+// CHECK: %[[RESULT:.*]] = select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>