diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -44,6 +44,49 @@
   }
 };
 
+// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2)), where y = lhs, x = rhs
+struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
+  using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    auto type = op.getType().cast<ComplexType>();
+    Type elementType = type.getElementType();
+
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+
+    Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
+    Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
+    Value rhsSquaredPlusLhsSquared =
+        b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
+    Value sqrtOfRhsSquaredPlusLhsSquared =
+        b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
+    // Numerator x + i * y, built from the constant i = (0, 1).
+    Value zero =
+        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+    Value one = b.create<arith::ConstantOp>(elementType,
+                                            b.getFloatAttr(elementType, 1));
+    Value i = b.create<complex::CreateOp>(type, zero, one);
+    Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
+    Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
+
+    Value divResult =
+        b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
+    Value logResult = b.create<complex::LogOp>(divResult);
+    // Scale the logarithm by -i = (0, -1) to finish the formula above.
+    Value negativeOne = b.create<arith::ConstantOp>(
+        elementType, b.getFloatAttr(elementType, -1));
+    Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
+
+    rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
+    return success();
+  }
+};
+
 template <typename ComparisonOp, arith::CmpFPredicate p>
 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
@@ -700,6 +743,73 @@
   }
 };
 
+// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
+struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
+  using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    auto type = op.getType().cast<ComplexType>();
+    Type elementType = type.getElementType();
+    Value arg = adaptor.getComplex();
+
+    Value zero =
+        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+
+    Value real = b.create<complex::ReOp>(elementType, arg);
+    Value imag = b.create<complex::ImOp>(elementType, arg);
+    // w = sqrt((|re| + |z|) / 2), held in sqrtAddAbs.
+    Value absLhs = b.create<math::AbsOp>(real);
+    Value absArg = b.create<complex::AbsOp>(elementType, arg);
+    Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
+
+    Value half = b.create<arith::ConstantOp>(
+        elementType, b.getFloatAttr(elementType, 0.5));
+    Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
+    Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
+
+    Value realIsNegative =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
+    Value imagIsNegative =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
+    // Root is (w, im/2w) if re >= 0, else (|im|/2w, im < 0 ? -w : w).
+    Value resultReal = sqrtAddAbs;
+
+    Value imagDivTwoResultReal = b.create<arith::DivFOp>(
+        imag, b.create<arith::AddFOp>(resultReal, resultReal));
+
+    Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
+
+    Value resultImag = b.create<arith::SelectOp>(
+        realIsNegative,
+        b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
+                                  resultReal),
+        imagDivTwoResultReal);
+
+    resultReal = b.create<arith::SelectOp>(
+        realIsNegative,
+        b.create<arith::DivFOp>(
+            imag, b.create<arith::AddFOp>(resultImag, resultImag)),
+        resultReal);
+    // z == 0 makes w == 0 and im/2w == 0/0 (NaN), so force the result to 0.
+    Value realIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+    Value imagIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+    Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
+
+    resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
+    resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+                                                   resultImag);
+    return success();
+  }
+};
+
 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
 
@@ -782,6 +892,7 @@
   // clang-format off
   patterns.add<
       AbsOpConversion,
+      Atan2OpConversion,
       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
@@ -796,6 +907,7 @@
       NegOpConversion,
       SignOpConversion,
       SinOpConversion,
+      SqrtOpConversion,
       TanOpConversion,
       TanhOpConversion>(patterns.getContext());
   // clang-format on
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\
+// RUN: FileCheck %s
 
 // CHECK-LABEL: func @complex_abs
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -16,6 +17,15 @@
 
 // -----
 
+// CHECK-LABEL: func @complex_atan2
+func.func @complex_atan2(%lhs: complex<f32>,
+                         %rhs: complex<f32>) -> complex<f32> {
+  %atan2 = complex.atan2 %lhs, %rhs : complex<f32>
+  return %atan2 : complex<f32>
+}
+
+// -----
+
 // CHECK-LABEL: func @complex_add
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -645,3 +655,11 @@
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] : f32
 // CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_sqrt
+func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
+  %sqrt = complex.sqrt %arg : complex<f32>
+  return %sqrt : complex<f32>
+}
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt %s \
+// RUN:   -func-bufferize -tensor-bufferize -arith-bufferize --canonicalize \
+// RUN:   -convert-scf-to-cf --convert-complex-to-standard \
+// RUN:   -convert-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm \
+// RUN:   -convert-vector-to-llvm -convert-complex-to-llvm \
+// RUN:   -convert-func-to-llvm -reconcile-unrealized-casts |\
+// RUN: mlir-cpu-runner \
+// RUN:   -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext |\
+// RUN: FileCheck %s
+
+func.func @test_unary(%input: tensor<?xcomplex<f32>>,
+                      %func: (complex<f32>) -> complex<f32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %size = tensor.dim %input, %c0: tensor<?xcomplex<f32>>
+
+  scf.for %i = %c0 to %size step %c1 {
+    %elem = tensor.extract %input[%i]: tensor<?xcomplex<f32>>
+
+    %val = func.call_indirect %func(%elem) : (complex<f32>) -> complex<f32>
+    %real = complex.re %val : complex<f32>
+    %imag = complex.im %val: complex<f32>
+    vector.print %real : f32
+    vector.print %imag : f32
+    scf.yield
+  }
+  func.return
+}
+
+func.func @sqrt(%arg: complex<f32>) -> complex<f32> {
+  %sqrt = complex.sqrt %arg : complex<f32>
+  func.return %sqrt : complex<f32>
+}
+
+// %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...]
+func.func @test_binary(%input: tensor<?xcomplex<f32>>,
+                       %func: (complex<f32>, complex<f32>) -> complex<f32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %size = tensor.dim %input, %c0: tensor<?xcomplex<f32>>
+
+  scf.for %i = %c0 to %size step %c2 {
+    %lhs = tensor.extract %input[%i]: tensor<?xcomplex<f32>>
+    %i_next = arith.addi %i, %c1 : index
+    %rhs = tensor.extract %input[%i_next]: tensor<?xcomplex<f32>>
+
+    %val = func.call_indirect %func(%lhs, %rhs)
+      : (complex<f32>, complex<f32>) -> complex<f32>
+    %real = complex.re %val : complex<f32>
+    %imag = complex.im %val: complex<f32>
+    vector.print %real : f32
+    vector.print %imag : f32
+    scf.yield
+  }
+  func.return
+}
+
+func.func @atan2(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %atan2 = complex.atan2 %lhs, %rhs : complex<f32>
+  func.return %atan2 : complex<f32>
+}
+
+// Entry point: runs each table below and prints (re, im) for FileCheck.
+func.func @entry() {
+  // complex.sqrt test: each input is followed by the CHECKs for its result.
+  %sqrt_test = arith.constant dense<[
+    (-1.0, -1.0),
+    // CHECK:      0.455
+    // CHECK-NEXT: -1.098
+    (-1.0, 1.0),
+    // CHECK-NEXT: 0.455
+    // CHECK-NEXT: 1.098
+    (0.0, 0.0),
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0
+    (0.0, 1.0),
+    // CHECK-NEXT: 0.707
+    // CHECK-NEXT: 0.707
+    (1.0, -1.0),
+    // CHECK-NEXT: 1.098
+    // CHECK-NEXT: -0.455
+    (1.0, 0.0),
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    (1.0, 1.0)
+    // CHECK-NEXT: 1.098
+    // CHECK-NEXT: 0.455
+  ]> : tensor<7xcomplex<f32>>
+  %sqrt_test_cast = tensor.cast %sqrt_test
+    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+  %sqrt_func = func.constant @sqrt : (complex<f32>) -> complex<f32>
+  call @test_unary(%sqrt_test_cast, %sqrt_func)
+    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
+
+  // complex.atan2 test: consecutive elements form (lhs, rhs) pairs.
+  %atan2_test = arith.constant dense<[
+    (1.0, 2.0), (2.0, 1.0),
+    // CHECK:      0.785
+    // CHECK-NEXT: 0.346
+    (1.0, 1.0), (1.0, 0.0),
+    // CHECK-NEXT: 1.017
+    // CHECK-NEXT: 0.402
+    (1.0, 1.0), (1.0, 1.0)
+    // CHECK-NEXT: 0.785
+    // CHECK-NEXT: 0
+  ]> : tensor<6xcomplex<f32>>
+  %atan2_test_cast = tensor.cast %atan2_test
+    :  tensor<6xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+  %atan2_func = func.constant @atan2 : (complex<f32>, complex<f32>)
+    -> complex<f32>
+  call @test_binary(%atan2_test_cast, %atan2_func)
+    : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
+      -> complex<f32>) -> ()
+  func.return
+}