diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -906,6 +906,105 @@ } }; +// Converts (a+bi)^(c+di) to +// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), +// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) +static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, + ComplexType type, Value a, Value b, Value c, + Value d) { + auto elementType = type.getElementType().cast(); + Value aa_p_bb = builder.create( + builder.create(a, a), builder.create(b, b)); + Value zero = builder.create( + elementType, builder.getFloatAttr(elementType, 0)); + Value half = builder.create( + elementType, builder.getFloatAttr(elementType, 0.5)); + Value one = builder.create( + elementType, builder.getFloatAttr(elementType, 1)); + Value half_c = builder.create(half, c); + Value aa_p_bb_to_half_c = builder.create(aa_p_bb, half_c); + + Value neg_d = builder.create(d); + Value arg_lhs = builder.create(b, a); + Value neg_d_arg_lhs = builder.create(neg_d, arg_lhs); + Value e_to_neg_d_arg_lhs = builder.create(neg_d_arg_lhs); + + Value coeff = + builder.create(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + + Value ln_aa_p_bb = builder.create(aa_p_bb); + Value half_d = builder.create(half, d); + Value q = builder.create( + builder.create(c, arg_lhs), + builder.create(half_d, ln_aa_p_bb)); + + Value cos_q = builder.create(q); + Value sin_q = builder.create(q); + + Value lhs_eq_zero = + builder.create(arith::CmpFPredicate::OEQ, aa_p_bb, zero); + Value rhs_ge_zero = builder.create( + builder.create(arith::CmpFPredicate::OGE, c, zero), + builder.create(arith::CmpFPredicate::OEQ, d, zero)); + Value rhs_eq_zero = + builder.create(arith::CmpFPredicate::OEQ, c, zero); + Value complex_zero = builder.create(type, zero, zero); + Value complex_one = builder.create(type, one, zero); 
+ Value complex_other = builder.create( + type, builder.create(coeff, cos_q), + builder.create(coeff, sin_q)); + + // lhs^rhs is 0 if lhs is 0 and rhs is real and > 0. 0^0 is defined to be + // 1.0, see Branch Cuts for Complex Elementary Functions or Much Ado About + // Nothing's Sign Bit, W. Kahan, Section 10. + return builder.create( + builder.create(lhs_eq_zero, rhs_ge_zero), + builder.create(rhs_eq_zero, complex_one, complex_zero), + complex_other); +} + +struct PowOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::PowOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); + auto type = adaptor.getLhs().getType().cast(); + auto elementType = type.getElementType().cast(); + + Value a = builder.create(elementType, adaptor.getLhs()); + Value b = builder.create(elementType, adaptor.getLhs()); + Value c = builder.create(elementType, adaptor.getRhs()); + Value d = builder.create(elementType, adaptor.getRhs()); + + rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)}); + return success(); + } +}; + +struct RsqrtOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); + auto type = adaptor.getComplex().getType().cast(); + auto elementType = type.getElementType().cast(); + + Value a = builder.create(elementType, adaptor.getComplex()); + Value b = builder.create(elementType, adaptor.getComplex()); + Value c = builder.create( + elementType, builder.getFloatAttr(elementType, -0.5)); + Value d = builder.create( + elementType, builder.getFloatAttr(elementType, 0)); + + rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)}); + return success(); + } +}; + +} // 
namespace void mlir::populateComplexToStandardConversionPatterns( @@ -931,7 +1030,9 @@ SinOpConversion, SqrtOpConversion, TanOpConversion, - TanhOpConversion + TanhOpConversion, + PowOpConversion, + RsqrtOpConversion >(patterns.getContext()); // clang-format on } diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -676,4 +676,100 @@ // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex // CHECK: %[[NEG_IMAG:.*]] = arith.negf %[[IMAG]] : f32 // CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[NEG_IMAG]] : complex -// CHECK: return %[[RESULT]] : complex \ No newline at end of file +// CHECK: return %[[RESULT]] : complex + +// ----- + +// CHECK-LABEL: func.func @complex_pow( +// CHECK-SAME: %[[VAL_0:.*]]: complex, +// CHECK-SAME: %[[VAL_1:.*]]: complex) -> complex { +func.func @complex_pow(%lhs: complex, + %rhs: complex) -> complex { + %pow = complex.pow %lhs, %rhs : complex + return %pow : complex +} +// CHECK: %[[VAL_2:.*]] = complex.re %[[VAL_0]] : complex +// CHECK: %[[VAL_3:.*]] = complex.im %[[VAL_0]] : complex +// CHECK: %[[VAL_4:.*]] = complex.re %[[VAL_1]] : complex +// CHECK: %[[VAL_5:.*]] = complex.im %[[VAL_1]] : complex +// CHECK: %[[VAL_6:.*]] = arith.mulf %[[VAL_2]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_7:.*]] = arith.mulf %[[VAL_3]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_8:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32 +// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_10:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK: %[[VAL_11:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[VAL_12:.*]] = arith.mulf %[[VAL_10]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_13:.*]] = math.powf %[[VAL_8]], %[[VAL_12]] : f32 +// CHECK: %[[VAL_14:.*]] = arith.negf %[[VAL_5]] : f32 +// CHECK: %[[VAL_15:.*]] = 
math.atan2 %[[VAL_3]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: %[[VAL_17:.*]] = math.exp %[[VAL_16]] : f32 +// CHECK: %[[VAL_18:.*]] = arith.mulf %[[VAL_13]], %[[VAL_17]] : f32 +// CHECK: %[[VAL_19:.*]] = math.log %[[VAL_8]] : f32 +// CHECK: %[[VAL_20:.*]] = arith.mulf %[[VAL_10]], %[[VAL_5]] : f32 +// CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_4]], %[[VAL_15]] : f32 +// CHECK: %[[VAL_22:.*]] = arith.mulf %[[VAL_20]], %[[VAL_19]] : f32 +// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_21]], %[[VAL_22]] : f32 +// CHECK: %[[VAL_24:.*]] = math.cos %[[VAL_23]] : f32 +// CHECK: %[[VAL_25:.*]] = math.sin %[[VAL_23]] : f32 +// CHECK: %[[VAL_26:.*]] = arith.cmpf oeq, %[[VAL_8]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_27:.*]] = arith.cmpf oge, %[[VAL_4]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_28:.*]] = arith.cmpf oeq, %[[VAL_5]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1 +// CHECK: %[[VAL_30:.*]] = arith.cmpf oeq, %[[VAL_4]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_31:.*]] = complex.create %[[VAL_9]], %[[VAL_9]] : complex +// CHECK: %[[VAL_32:.*]] = complex.create %[[VAL_11]], %[[VAL_9]] : complex +// CHECK: %[[VAL_33:.*]] = arith.mulf %[[VAL_18]], %[[VAL_24]] : f32 +// CHECK: %[[VAL_34:.*]] = arith.mulf %[[VAL_18]], %[[VAL_25]] : f32 +// CHECK: %[[VAL_35:.*]] = complex.create %[[VAL_33]], %[[VAL_34]] : complex +// CHECK: %[[VAL_36:.*]] = arith.andi %[[VAL_26]], %[[VAL_29]] : i1 +// CHECK: %[[VAL_37:.*]] = arith.select %[[VAL_30]], %[[VAL_32]], %[[VAL_31]] : complex +// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_37]], %[[VAL_35]] : complex +// CHECK: return %[[VAL_38]] : complex + +// ----- + +// CHECK-LABEL: func.func @complex_rsqrt( +// CHECK-SAME: %[[VAL_0:.*]]: complex) -> complex { +func.func @complex_rsqrt(%arg: complex) -> complex { + %rsqrt = complex.rsqrt %arg : complex + return %rsqrt : complex +} +// CHECK: %[[VAL_1:.*]] = complex.re %[[VAL_0]] : complex +// CHECK: 
%[[VAL_2:.*]] = complex.im %[[VAL_0]] : complex +// CHECK: %[[VAL_3:.*]] = arith.constant -5.000000e-01 : f32 +// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_1]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_6:.*]] = arith.mulf %[[VAL_2]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32 +// CHECK: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_9:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK: %[[VAL_10:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[VAL_11:.*]] = arith.mulf %[[VAL_9]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_12:.*]] = math.powf %[[VAL_7]], %[[VAL_11]] : f32 +// CHECK: %[[VAL_13:.*]] = arith.negf %[[VAL_4]] : f32 +// CHECK: %[[VAL_14:.*]] = math.atan2 %[[VAL_2]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_15:.*]] = arith.mulf %[[VAL_13]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_16:.*]] = math.exp %[[VAL_15]] : f32 +// CHECK: %[[VAL_17:.*]] = arith.mulf %[[VAL_12]], %[[VAL_16]] : f32 +// CHECK: %[[VAL_18:.*]] = math.log %[[VAL_7]] : f32 +// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_9]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_20:.*]] = arith.mulf %[[VAL_3]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_18]] : f32 +// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32 +// CHECK: %[[VAL_23:.*]] = math.cos %[[VAL_22]] : f32 +// CHECK: %[[VAL_24:.*]] = math.sin %[[VAL_22]] : f32 +// CHECK: %[[VAL_25:.*]] = arith.cmpf oeq, %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_26:.*]] = arith.cmpf oge, %[[VAL_3]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_27:.*]] = arith.cmpf oeq, %[[VAL_4]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_28:.*]] = arith.andi %[[VAL_26]], %[[VAL_27]] : i1 +// CHECK: %[[VAL_29:.*]] = arith.cmpf oeq, %[[VAL_3]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_30:.*]] = complex.create %[[VAL_8]], %[[VAL_8]] : complex +// CHECK: %[[VAL_31:.*]] = complex.create %[[VAL_10]], %[[VAL_8]] : complex +// CHECK: %[[VAL_32:.*]] = 
arith.mulf %[[VAL_17]], %[[VAL_23]] : f32 +// CHECK: %[[VAL_33:.*]] = arith.mulf %[[VAL_17]], %[[VAL_24]] : f32 +// CHECK: %[[VAL_34:.*]] = complex.create %[[VAL_32]], %[[VAL_33]] : complex +// CHECK: %[[VAL_35:.*]] = arith.andi %[[VAL_25]], %[[VAL_28]] : i1 +// CHECK: %[[VAL_36:.*]] = arith.select %[[VAL_29]], %[[VAL_31]], %[[VAL_30]] : complex +// CHECK: %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_36]], %[[VAL_34]] : complex +// CHECK: return %[[VAL_37]] : complex \ No newline at end of file diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir --- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir +++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir @@ -38,6 +38,11 @@ func.return %tanh : complex } +func.func @rsqrt(%arg: complex) -> complex { + %sqrt = complex.rsqrt %arg : complex + func.return %sqrt : complex +} + // %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...] 
func.func @test_binary(%input: tensor>, %func: (complex, complex) -> complex) { @@ -67,6 +72,10 @@ func.return %atan2 : complex } +func.func @pow(%lhs: complex, %rhs: complex) -> complex { + %pow = complex.pow %lhs, %rhs : complex + func.return %pow : complex +} func.func @entry() { // complex.sqrt test @@ -121,6 +130,30 @@ : (tensor>, (complex, complex) -> complex) -> () + // complex.pow test + %pow_test = arith.constant dense<[ + (0.0, 0.0), (0.0, 0.0), + // CHECK: 1 + // CHECK-NEXT: 0 + (0.0, 0.0), (1.0, 0.0), + // CHECK-NEXT: 0 + // CHECK-NEXT: 0 + (0.0, 0.0), (-1.0, 0.0), + // CHECK-NEXT: -nan + // CHECK-NEXT: -nan + (1.0, 1.0), (1.0, 1.0) + // CHECK-NEXT: 0.273 + // CHECK-NEXT: 0.583 + ]> : tensor<8xcomplex> + %pow_test_cast = tensor.cast %pow_test + : tensor<8xcomplex> to tensor> + + %pow_func = func.constant @pow : (complex, complex) + -> complex + call @test_binary(%pow_test_cast, %pow_func) + : (tensor>, (complex, complex) + -> complex) -> () + // complex.tanh test %tanh_test = arith.constant dense<[ (-1.0, -1.0), @@ -152,5 +185,36 @@ call @test_unary(%tanh_test_cast, %tanh_func) : (tensor>, (complex) -> complex) -> () + // complex.rsqrt test + %rsqrt_test = arith.constant dense<[ + (-1.0, -1.0), + // CHECK: 0.321 + // CHECK-NEXT: 0.776 + (-1.0, 1.0), + // CHECK-NEXT: 0.321 + // CHECK-NEXT: -0.776 + (0.0, 0.0), + // CHECK-NEXT: nan + // CHECK-NEXT: nan + (0.0, 1.0), + // CHECK-NEXT: 0.707 + // CHECK-NEXT: -0.707 + (1.0, -1.0), + // CHECK-NEXT: 0.776 + // CHECK-NEXT: 0.321 + (1.0, 0.0), + // CHECK-NEXT: 1 + // CHECK-NEXT: 0 + (1.0, 1.0) + // CHECK-NEXT: 0.776 + // CHECK-NEXT: -0.321 + ]> : tensor<7xcomplex> + %rsqrt_test_cast = tensor.cast %rsqrt_test + : tensor<7xcomplex> to tensor> + + %rsqrt_func = func.constant @rsqrt : (complex) -> complex + call @test_unary(%rsqrt_test_cast, %rsqrt_func) + : (tensor>, (complex) -> complex) -> () + func.return }