diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -109,6 +109,26 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// CosOp
+//===----------------------------------------------------------------------===//
+
+def CosOp : ComplexUnaryOp<"cos", [SameOperandsAndResultType]> {
+  let summary = "computes cosine of a complex number";
+  let description = [{
+    The `cos` op takes a single complex number and computes the cosine of
+    it, i.e. `cos(x)`, where `x` is the input value.
+
+    Example:
+
+    ```mlir
+    %a = complex.cos %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // CreateOp
 //===----------------------------------------------------------------------===//
@@ -369,6 +389,26 @@
   let results = (outs Complex<AnyFloat>:$result);
 }
 
+//===----------------------------------------------------------------------===//
+// SinOp
+//===----------------------------------------------------------------------===//
+
+def SinOp : ComplexUnaryOp<"sin", [SameOperandsAndResultType]> {
+  let summary = "computes sine of a complex number";
+  let description = [{
+    The `sin` op takes a single complex number and computes the sine of
+    it, i.e. `sin(x)`, where `x` is the input value.
+
+    Example:
+
+    ```mlir
+    %a = complex.sin %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // SubOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -103,6 +103,73 @@
   }
 };
 
+/// Common lowering for complex trigonometric ops. For an operand z = x + iy it
+/// materializes 0.5 * exp(y), 0.5 / exp(y), sin(x) and cos(x), then calls the
+/// CRTP hook `ConcreteType::combine` to build the {real, imag} result parts.
+/// NOTE: exp(y) over/underflows for large |y|, so results may then be inf/NaN.
+template <typename TrigonometricOp, typename ConcreteType>
+struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
+  using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
+
+  using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto type = adaptor.getComplex().getType().template cast<ComplexType>();
+    auto elementType = type.getElementType().template cast<FloatType>();
+
+    Value real =
+        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+    Value imag =
+        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
+
+    // Trigonometric ops use a set of common building blocks to convert to real
+    // ops. Here we create these building blocks and call through CRTP into an
+    // op-specific implementation to combine them.
+    Value half = rewriter.create<arith::ConstantOp>(
+        loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
+    Value exp = rewriter.create<math::ExpOp>(loc, imag);
+    Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
+    Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
+    Value sin = rewriter.create<math::SinOp>(loc, real);
+    Value cos = rewriter.create<math::CosOp>(loc, real);
+
+    auto resultPair = static_cast<const ConcreteType *>(this)->combine(
+        loc, scaledExp, reciprocalExp, sin, cos, rewriter);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
+                                                   resultPair.second);
+    return success();
+  }
+};
+
+struct CosOpConversion
+    : public TrigonometricOpConversion<complex::CosOp, CosOpConversion> {
+  using TrigonometricOpConversion<complex::CosOp,
+                                  CosOpConversion>::TrigonometricOpConversion;
+
+  std::pair<Value, Value> combine(Location loc, Value scaledExp,
+                                  Value reciprocalExp, Value sin, Value cos,
+                                  ConversionPatternRewriter &rewriter) const {
+    // Complex cosine is defined as:
+    //   cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
+    // Plugging in:
+    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
+    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
+    // and defining t := exp(y)
+    // We get:
+    //   Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
+    //   Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
+    Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
+    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
+    Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
+    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
+    return {resultReal, resultImag};
+  }
+};
+
 struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
   using OpConversionPattern<complex::DivOp>::OpConversionPattern;
 
@@ -588,6 +655,31 @@
   }
 };
 
+struct SinOpConversion
+    : public TrigonometricOpConversion<complex::SinOp, SinOpConversion> {
+  using TrigonometricOpConversion<complex::SinOp,
+                                  SinOpConversion>::TrigonometricOpConversion;
+
+  std::pair<Value, Value> combine(Location loc, Value scaledExp,
+                                  Value reciprocalExp, Value sin, Value cos,
+                                  ConversionPatternRewriter &rewriter) const {
+    // Complex sine is defined as:
+    //   sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
+    // Plugging in:
+    //   exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
+    //   exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
+    // and defining t := exp(y)
+    // We get:
+    //   Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
+    //   Im(sin(x + iy)) = (0.5*t - 0.5/t) * cos x
+    Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
+    Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
+    Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
+    Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
+    return {resultReal, resultImag};
+  }
+};
+
 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
 
@@ -627,13 +719,15 @@
       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
       BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
+      CosOpConversion,
       DivOpConversion,
       ExpOpConversion,
       LogOpConversion,
       Log1pOpConversion,
       MulOpConversion,
       NegOpConversion,
-      SignOpConversion>(patterns.getContext());
+      SignOpConversion,
+      SinOpConversion>(patterns.getContext());
   // clang-format on
 }
 
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -29,6 +29,27 @@
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// CHECK-LABEL: func @complex_cos
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
+  %cos = complex.cos %arg : complex<f32>
+  return %cos : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[EXP:.*]] = math.exp %[[IMAG]] : f32
+// CHECK: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] : f32
+// CHECK: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[REAL]] : f32
+// CHECK: %[[COS:.*]] = math.cos %[[REAL]] : f32
+// CHECK: %[[EXP_SUM:.*]] = arith.addf %[[HALF_REXP]], %[[HALF_EXP]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[COS]] : f32
+// CHECK: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_REXP]], %[[HALF_EXP]] : f32
+// CHECK: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[SIN]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
 // CHECK-LABEL: func @complex_div
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -358,6 +379,27 @@
 // CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
 // CHECK: return %[[NOT_EQUAL]] : i1
 
+// CHECK-LABEL: func @complex_sin
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
+  %sin = complex.sin %arg : complex<f32>
+  return %sin : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[EXP:.*]] = math.exp %[[IMAG]] : f32
+// CHECK: %[[HALF_EXP:.*]] = arith.mulf %[[HALF]], %[[EXP]] : f32
+// CHECK: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[REAL]] : f32
+// CHECK: %[[COS:.*]] = math.cos %[[REAL]] : f32
+// CHECK: %[[EXP_SUM:.*]] = arith.addf %[[HALF_EXP]], %[[HALF_REXP]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[SIN]] : f32
+// CHECK: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_EXP]], %[[HALF_REXP]] : f32
+// CHECK: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
 // CHECK-LABEL: func @complex_sign
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -26,6 +26,9 @@
   // CHECK: complex.add %[[C]], %[[C]] : complex<f32>
   %sum = complex.add %complex, %complex : complex<f32>
 
+  // CHECK: complex.cos %[[C]] : complex<f32>
+  %cos = complex.cos %complex : complex<f32>
+
   // CHECK: complex.div %[[C]], %[[C]] : complex<f32>
   %div = complex.div %complex, %complex : complex<f32>
 
@@ -53,6 +56,9 @@
   // CHECK: complex.sign %[[C]] : complex<f32>
   %sign = complex.sign %complex : complex<f32>
 
+  // CHECK: complex.sin %[[C]] : complex<f32>
+  %sin = complex.sin %complex : complex<f32>
+
   // CHECK: complex.sub %[[C]], %[[C]] : complex<f32>
   %diff = complex.sub %complex, %complex : complex<f32>
   return