diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -1,6 +1,17 @@ // RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\ // RUN: FileCheck %s --dump-input=always +// Run Complex2Standard and structural conversion of func ops in both orders +// with the same check prefix to check that they produce the same result. + +// RUN: mlir-opt %s -split-input-file \ +// RUN: -convert-complex-to-standard -test-one-to-n-type-conversion="convert-func-ops" \ +// RUN: | FileCheck %s --dump-input=always --check-prefix=CHECK-FUNC + +// RUN: mlir-opt %s -split-input-file \ +// RUN: -test-one-to-n-type-conversion="convert-func-ops" -convert-complex-to-standard \ +// RUN: | FileCheck %s --dump-input=always --check-prefix=CHECK-FUNC + // CHECK-LABEL: func @complex_abs // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_abs(%arg: complex) -> f32 { @@ -15,6 +26,18 @@ // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 // CHECK: return %[[NORM]] : f32 +// CHECK-FUNC-LABEL: func @complex_abs( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> f32 +// CHECK-FUNC: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 +// CHECK-FUNC: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 +// CHECK-FUNC: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32 +// CHECK-FUNC: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK-FUNC: return %[[NORM]] : f32 + // ----- // CHECK-LABEL: func @complex_atan2 @@ -24,6 +47,14 @@ return %atan2 : complex } +// CHECK-FUNC-LABEL: func @complex_atan2( +// 
CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG2:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG3:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[V0:.*]] = complex.create %[[ARG2]], %[[ARG3]] : complex +// CHECK-FUNC-DAG: %[[V1:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex + // ----- // CHECK-LABEL: func @complex_add @@ -41,6 +72,24 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// CHECK-FUNC-LABEL: func @complex_add( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG2:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG3:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[RHS:.*]] = complex.create %[[ARG2]], %[[ARG3]] : complex +// CHECK-FUNC-DAG: %[[LHS:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK-FUNC: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK-FUNC: %[[RESULT_REAL:.*]] = arith.addf %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK-FUNC: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK-FUNC: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK-FUNC: %[[RESULT_IMAG:.*]] = arith.addf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK-FUNC: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK-FUNC: %[[RESULT_RE:.*]] = complex.re %[[RESULT]] : complex +// CHECK-FUNC: %[[RESULT_IM:.*]] = complex.im %[[RESULT]] : complex +// CHECK-FUNC: return %[[RESULT_RE]], %[[RESULT_IM]] : f32, f32 + // ----- // CHECK-LABEL: func @complex_cos @@ -175,6 +224,14 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex // CHECK: return %[[RESULT]] : complex +// CHECK-FUNC-LABEL: func @complex_div( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32, 
+// CHECK-FUNC-SAME: %[[ARG2:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG3:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[RHS:.*]] = complex.create %[[ARG2]], %[[ARG3]] : complex +// CHECK-FUNC-DAG: %[[LHS:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex + // ----- // CHECK-LABEL: func @complex_eq @@ -192,6 +249,22 @@ // CHECK: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1 // CHECK: return %[[EQUAL]] : i1 +// CHECK-FUNC-LABEL: func @complex_eq( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG2:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG3:[^:]*]]: f32) -> i1 +// CHECK-FUNC-DAG: %[[RHS:.*]] = complex.create %[[ARG2]], %[[ARG3]] : complex +// CHECK-FUNC-DAG: %[[LHS:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK-FUNC: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK-FUNC: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK-FUNC: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK-FUNC-DAG: %[[REAL_EQUAL:.*]] = arith.cmpf oeq, %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK-FUNC-DAG: %[[IMAG_EQUAL:.*]] = arith.cmpf oeq, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK-FUNC: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1 +// CHECK-FUNC: return %[[EQUAL]] : i1 + // ----- // CHECK-LABEL: func @complex_exp @@ -210,6 +283,22 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// CHECK-FUNC-LABEL: func @complex_exp( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-NEXT: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK-FUNC: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32 +// CHECK-FUNC-DAG: 
%[[EXP_REAL:.*]] = math.exp %[[REAL]] : f32 +// CHECK-FUNC-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32 +// CHECK-FUNC-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] : f32 +// CHECK-FUNC-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32 +// CHECK-FUNC: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK-FUNC: %[[RESULT_RE:.*]] = complex.re %[[RESULT]] : complex +// CHECK-FUNC: %[[RESULT_IM:.*]] = complex.im %[[RESULT]] : complex +// CHECK-FUNC: return %[[RESULT_RE]], %[[RESULT_IM]] : f32, f32 + // ----- // CHECK-LABEL: func.func @complex_expm1( @@ -233,6 +322,27 @@ // CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex // CHECK: return %[[RES]] : complex +// CHECK-FUNC-LABEL: func @complex_expm1( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC-DAG: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[EXP:.*]] = math.exp %[[REAL_I]] : f32 +// CHECK-FUNC-DAG: %[[COS:.*]] = math.cos %[[IMAG_I]] : f32 +// CHECK-FUNC-DAG: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] : f32 +// CHECK-FUNC-DAG: %[[SIN:.*]] = math.sin %[[IMAG_I]] : f32 +// CHECK-FUNC-DAG: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] : f32 +// CHECK-FUNC-DAG: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// XXX: If ComplexToStandard runs first, some ops are folded, so we leave some +// leeway for the matching...
+// CHECK-FUNC: %[[REAL_M1:.*]] = arith.subf %[[RES_REAL:.*]], %[[ONE]] : f32 +// CHECK-FUNC: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[RES_IMAG:.*]] : complex +// CHECK-FUNC: %[[RESULT_RE:.*]] = complex.re %[[RES]] : complex +// CHECK-FUNC: %[[RESULT_IM:.*]] = complex.im %[[RES]] : complex +// CHECK-FUNC: return %[[RESULT_RE]], %[[RESULT_IM]] : f32, f32 + + // ----- // CHECK-LABEL: func @complex_log @@ -254,6 +364,25 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// CHECK-FUNC-LABEL: func @complex_log( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK-FUNC: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK-FUNC: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 +// CHECK-FUNC: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 +// CHECK-FUNC: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 +// CHECK-FUNC: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK-FUNC: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32 +// CHECK-FUNC: %[[REAL2:.*]] = complex.re %[[ARG]] : complex +// CHECK-FUNC: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex +// CHECK-FUNC: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32 +// CHECK-FUNC: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK-FUNC: %[[RESULT_RE:.*]] = complex.re %[[RESULT]] : complex +// CHECK-FUNC: %[[RESULT_IM:.*]] = complex.im %[[RESULT]] : complex +// CHECK-FUNC: return %[[RESULT_RE]], %[[RESULT_IM]] : f32, f32 + // ----- // CHECK-LABEL: func @complex_log1p @@ -280,6 +409,28 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// CHECK-FUNC-LABEL: func @complex_log1p( +// CHECK-FUNC-SAME: 
%[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC-DAG: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-FUNC-DAG: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-FUNC-DAG: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-FUNC-DAG: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 +// CHECK-FUNC-DAG: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] : f32 +// CHECK-FUNC-DAG: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] : f32 +// CHECK-FUNC-DAG: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 +// CHECK-FUNC-DAG: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] : f32 +// CHECK-FUNC-DAG: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] : f32 +// CHECK-FUNC-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] : f32 +// CHECK-FUNC-DAG: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] : f32 +// CHECK-FUNC: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] : f32 +// CHECK-FUNC: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK-FUNC: %[[RESULT_RE:.*]] = complex.re %[[RESULT]] : complex +// CHECK-FUNC: %[[RESULT_IM:.*]] = complex.im %[[RESULT]] : complex +// CHECK-FUNC: return %[[RESULT_RE]], %[[RESULT_IM]] : f32, f32 // ----- // CHECK-LABEL: func @complex_mul @@ -400,6 +551,14 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// CHECK-FUNC-LABEL: func @complex_mul( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG2:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG3:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[RHS:.*]] = complex.create %[[ARG2]], 
%[[ARG3]] : complex +// CHECK-FUNC-DAG: %[[LHS:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex + // ----- // CHECK-LABEL: func @complex_neg @@ -415,6 +574,19 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// CHECK-FUNC-LABEL: func @complex_neg( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK-FUNC: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[NEG_REAL:.*]] = arith.negf %[[REAL]] : f32 +// CHECK-FUNC-DAG: %[[NEG_IMAG:.*]] = arith.negf %[[IMAG]] : f32 +// CHECK-FUNC: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex +// CHECK-FUNC: %[[RESULT_RE:.*]] = complex.re %[[RESULT]] : complex +// CHECK-FUNC: %[[RESULT_IM:.*]] = complex.im %[[RESULT]] : complex +// CHECK-FUNC: return %[[RESULT_RE]], %[[RESULT_IM]] : f32, f32 + // ----- // CHECK-LABEL: func @complex_neq @@ -432,6 +604,21 @@ // CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1 // CHECK: return %[[NOT_EQUAL]] : i1 +// CHECK-FUNC-LABEL: func @complex_neq( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG2:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG3:[^:]*]]: f32) -> i1 +// CHECK-FUNC-DAG: %[[RHS:.*]] = complex.create %[[ARG2]], %[[ARG3]] : complex +// CHECK-FUNC-DAG: %[[LHS:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK-FUNC: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK-FUNC: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK-FUNC: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK-FUNC-DAG: %[[REAL_NOT_EQUAL:.*]] = arith.cmpf une, %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK-FUNC-DAG: 
%[[IMAG_NOT_EQUAL:.*]] = arith.cmpf une, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK-FUNC: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1 +// CHECK-FUNC: return %[[NOT_EQUAL]] : i1 // ----- // CHECK-LABEL: func @complex_sin @@ -455,6 +642,28 @@ // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] +// CHECK-FUNC-LABEL: func @complex_sin( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC-DAG: %[[REAL:.*]] = complex.re %[[ARG]] +// CHECK-FUNC-DAG: %[[IMAG:.*]] = complex.im %[[ARG]] +// CHECK-FUNC-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-FUNC-DAG: %[[EXP:.*]] = math.exp %[[IMAG]] : f32 +// XXX: If ComplexToStandard runs first, the operands appear in reversed order, +// so we leave some leeway in the matching... +// CHECK-FUNC-DAG: %[[HALF_EXP:.*]] = arith.mulf %{{.*}}, %{{.*}} +// CHECK-FUNC-DAG: %[[HALF_REXP:.*]] = arith.divf %[[HALF]], %[[EXP]] +// CHECK-FUNC-DAG: %[[SIN:.*]] = math.sin %[[REAL]] : f32 +// CHECK-FUNC-DAG: %[[COS:.*]] = math.cos %[[REAL]] : f32 +// CHECK-FUNC-DAG: %[[EXP_SUM:.*]] = arith.addf %[[HALF_EXP]], %[[HALF_REXP]] +// CHECK-FUNC-DAG: %[[RESULT_REAL:.*]] = arith.mulf %[[EXP_SUM]], %[[SIN]] +// CHECK-FUNC-DAG: %[[EXP_DIFF:.*]] = arith.subf %[[HALF_EXP]], %[[HALF_REXP]] +// CHECK-FUNC-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_DIFF]], %[[COS]] +// CHECK-FUNC-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK-FUNC: %[[RESULT_RE:.*]] = complex.re %[[RESULT]] : complex +// CHECK-FUNC: %[[RESULT_IM:.*]] = complex.im %[[RESULT]] : complex +// CHECK-FUNC: return %[[RESULT_RE]], %[[RESULT_IM]] : f32, f32 // ----- // CHECK-LABEL: func @complex_sign @@ -481,6 +690,30 @@ // CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] :
complex // CHECK: return %[[RESULT]] : complex +// CHECK-FUNC-LABEL: func @complex_sign( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC-DAG: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-FUNC-DAG: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32 +// CHECK-FUNC-DAG: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32 +// CHECK-FUNC-DAG: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1 +// CHECK-FUNC-DAG: %[[REAL2:.*]] = complex.re %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex +// CHECK-FUNC-DAG: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32 +// CHECK-FUNC-DAG: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32 +// CHECK-FUNC-DAG: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32 +// CHECK-FUNC-DAG: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 +// CHECK-FUNC-DAG: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32 +// CHECK-FUNC-DAG: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32 +// CHECK-FUNC-DAG: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex +// CHECK-FUNC-DAG: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex +// CHECK-FUNC: %[[RESULT_RE:.*]] = complex.re %[[RESULT]] : complex +// CHECK-FUNC: %[[RESULT_IM:.*]] = complex.im %[[RESULT]] : complex +// CHECK-FUNC: return %[[RESULT_RE]], %[[RESULT_IM]] : f32, f32 + // ----- // CHECK-LABEL: func @complex_sub @@ -498,6 +731,24 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// CHECK-FUNC-LABEL: func @complex_sub( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// 
CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG2:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG3:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[RHS:.*]] = complex.create %[[ARG2]], %[[ARG3]] : complex +// CHECK-FUNC-DAG: %[[LHS:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK-FUNC: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK-FUNC: %[[RESULT_REAL:.*]] = arith.subf %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK-FUNC: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK-FUNC: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK-FUNC: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK-FUNC: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK-FUNC: %[[RESULT_RE:.*]] = complex.re %[[RESULT]] : complex +// CHECK-FUNC: %[[RESULT_IM:.*]] = complex.im %[[RESULT]] : complex +// CHECK-FUNC: return %[[RESULT_RE]], %[[RESULT_IM]] : f32, f32 + // ----- // CHECK-LABEL: func @complex_tan @@ -637,6 +888,10 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex // CHECK: return %[[RESULT]] : complex +// CHECK-FUNC-LABEL: func @complex_tan( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex // ----- // CHECK-LABEL: func @complex_tanh @@ -656,6 +911,11 @@ // CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] : f32 // CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex +// CHECK-FUNC-LABEL: func @complex_tanh( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex + // ----- // CHECK-LABEL: func @complex_sqrt @@ -664,6 +924,11 @@ return %sqrt : complex } +// 
CHECK-FUNC-LABEL: func @complex_sqrt( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex + // ----- // CHECK-LABEL: func @complex_conj @@ -678,6 +943,19 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[NEG_IMAG]] : complex // CHECK: return %[[RESULT]] : complex + +// CHECK-FUNC-LABEL: func @complex_conj( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex +// CHECK-FUNC: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK-FUNC: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK-FUNC: %[[NEG_IMAG:.*]] = arith.negf %[[IMAG]] : f32 +// CHECK-FUNC: %[[RESULT:.*]] = complex.create %[[REAL]], %[[NEG_IMAG]] : complex +// CHECK-FUNC: %[[RESULT_RE:.*]] = complex.re %[[RESULT]] : complex +// CHECK-FUNC: %[[RESULT_IM:.*]] = complex.im %[[RESULT]] : complex +// CHECK-FUNC: return %[[RESULT_RE]], %[[RESULT_IM]] : f32, f32 + // ----- // CHECK-LABEL: func.func @complex_pow @@ -687,6 +965,14 @@ return %pow : complex } +// CHECK-FUNC-LABEL: func @complex_pow( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG2:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG3:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[RHS:.*]] = complex.create %[[ARG2]], %[[ARG3]] : complex +// CHECK-FUNC-DAG: %[[LHS:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex + // ----- // CHECK-LABEL: func.func @complex_rsqrt @@ -695,6 +981,11 @@ return %rsqrt : complex } +// CHECK-FUNC-LABEL: func @complex_rsqrt( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> (f32, f32) +// CHECK-FUNC-DAG: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex + // ----- // CHECK-LABEL: func.func @complex_angle @@ -707,3 +998,8 @@ // CHECK: %[[IMAG:.*]] = 
complex.im %[[ARG]] : complex // CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] : f32 // CHECK: return %[[RESULT]] : f32 + +// CHECK-FUNC-LABEL: func @complex_angle( +// CHECK-FUNC-SAME: %[[ARG0:[^:]*]]: f32, +// CHECK-FUNC-SAME: %[[ARG1:[^:]*]]: f32) -> f32 +// CHECK-FUNC: %[[ARG:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt --- a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt @@ -6,6 +6,7 @@ EXCLUDE_FROM_LIBMLIR LINK_LIBS PUBLIC + MLIRComplexDialect MLIRFuncDialect MLIRIR MLIRTestDialect diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp --- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp @@ -9,6 +9,7 @@ #include "OneToNTypeConversion.h" #include "OneToNTypeConversionFunc.h" #include "TestDialect.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Pass/Pass.h" using namespace mlir; @@ -208,14 +209,34 @@ return builder.create<::test::MakeTupleOp>(loc, resultType, elements); } +/* Materialization callback for ComplexType: rebuilds one complex value of type `resultType` from its two scalar parts, inputs[0] (real) and inputs[1] (imaginary). The part count is only checked by the assert. NOTE(review): in NDEBUG builds a mismatched count would index out of range. */ static std::optional buildCreateComplexOp(OpBuilder &builder, + ComplexType resultType, + ValueRange inputs, + Location loc) { + assert(inputs.size() == 2 && "expected two inputs to create complex number"); + Value re = inputs[0]; + Value im = inputs[1]; + return builder.create(loc, resultType, re, im); +} + +/* Target materialization for ComplexType: splits `input` into two values, real part first, then imaginary part. `resultTypes` is not consulted; the result types come from the ops that are built. */ static std::optional> +buildComplexReImOps(OpBuilder &builder, TypeRange resultTypes, Value input, + Location loc) { + auto re = builder.create(loc, input); + auto im = builder.create(loc, input); + SmallVector resultValues = {re, im}; + return resultValues; +} + void
TestOneToNTypeConversionPass::runOnOperation() { ModuleOp module = getOperation(); auto *context = &getContext(); // Assemble type converter. OneToNTypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); + + // Type conversion and materializations for tuples. typeConverter.addConversion( [](TupleType tupleType, SmallVectorImpl &types) { tupleType.getFlattenedTypes(types); @@ -226,6 +247,18 @@ typeConverter.addSourceMaterialization(buildMakeTupleOp); typeConverter.addTargetMaterialization(buildGetTupleElementOps); + // Type conversion and materializations for complex. + typeConverter.addConversion( + [](ComplexType complexType, SmallVectorImpl &types) { + Type elementType = complexType.getElementType(); + types.append(2, elementType); // re and im + return success(); + }); + + typeConverter.addArgumentMaterialization(buildCreateComplexOp); + typeConverter.addSourceMaterialization(buildCreateComplexOp); + typeConverter.addTargetMaterialization(buildComplexReImOps); + // Assemble patterns. RewritePatternSet patterns(context); if (convertTupleOps)