diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -584,4 +584,26 @@
   let results = (outs Complex<AnyFloat>:$result);
 }
 
+//===----------------------------------------------------------------------===//
+// ArgOp
+//===----------------------------------------------------------------------===//
+
+def ArgOp : ComplexUnaryOp<"arg",
+    [TypesMatchWith<"complex element type matches result type",
+                    "complex", "result",
+                    "$_self.cast<ComplexType>().getElementType()">]> {
+  let summary = "computes argument value of a complex number";
+  let description = [{
+    The `arg` op takes a single complex number and computes its argument value
+    with a branch cut along the negative real axis.
+
+    Example:
+
+    ```mlir
+    %a = complex.arg %b : complex<f32>
+    ```
+  }];
+  let results = (outs AnyFloat:$result);
+}
+
 #endif // COMPLEX_OPS
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -1009,6 +1009,27 @@
   }
 };
 
+/// Lowers `complex.arg z` to `math.atan2(complex.im z, complex.re z)`.
+struct ArgOpConversion : public OpConversionPattern<complex::ArgOp> {
+  using OpConversionPattern<complex::ArgOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::ArgOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto type = op.getType();
+
+    Value real =
+        rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
+    Value imag =
+        rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
+
+    rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
@@ -1036,7 +1057,8 @@
       TanOpConversion,
       TanhOpConversion,
       PowOpConversion,
-      RsqrtOpConversion
+      RsqrtOpConversion,
+      ArgOpConversion
   >(patterns.getContext());
   // clang-format on
 }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -694,3 +694,16 @@
   %rsqrt = complex.rsqrt %arg : complex<f32>
   return %rsqrt : complex<f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @complex_arg
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_arg(%arg: complex<f32>) -> f32 {
+  %angle = complex.arg %arg : complex<f32>
+  return %angle : f32
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] : f32
+// CHECK: return %[[RESULT]] : f32
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -82,6 +82,27 @@
   func.return %pow : complex<f32>
 }
 
+func.func @test_element(%input: tensor<?xcomplex<f32>>,
+                        %func: (complex<f32>) -> f32) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %size = tensor.dim %input, %c0: tensor<?xcomplex<f32>>
+
+  scf.for %i = %c0 to %size step %c1 {
+    %elem = tensor.extract %input[%i]: tensor<?xcomplex<f32>>
+
+    %val = func.call_indirect %func(%elem) : (complex<f32>) -> f32
+    vector.print %val : f32
+    scf.yield
+  }
+  func.return
+}
+
+func.func @arg(%arg: complex<f32>) -> f32 {
+  %angle = complex.arg %arg : complex<f32>
+  func.return %angle : f32
+}
+
 func.func @entry() {
   // complex.sqrt test
   %sqrt_test = arith.constant dense<[
@@ -251,6 +272,30 @@
   %conj_func = func.constant @conj : (complex<f32>) -> complex<f32>
   call @test_unary(%conj_test_cast, %conj_func)
     : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
-  
+
+  // complex.arg test
+  %arg_test = arith.constant dense<[
+    (-1.0, -1.0),
+    // CHECK:      -2.356
+    (-1.0, 1.0),
+    // CHECK-NEXT: 2.356
+    (0.0, 0.0),
+    // CHECK-NEXT: 0
+    (0.0, 1.0),
+    // CHECK-NEXT: 1.570
+    (1.0, -1.0),
+    // CHECK-NEXT: -0.785
+    (1.0, 0.0),
+    // CHECK-NEXT: 0
+    (1.0, 1.0)
+    // CHECK-NEXT: 0.785
+  ]> : tensor<7xcomplex<f32>>
+  %arg_test_cast = tensor.cast %arg_test
+    : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+  %arg_func = func.constant @arg : (complex<f32>) -> f32
+  call @test_element(%arg_test_cast, %arg_func)
+    : (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()
+
   func.return
 }