diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -24,6 +24,9 @@
   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
   if (arrayAttr && arrayAttr.size() == 2)
     return arrayAttr[0];
+  // re(create(a, b)) folds to the SSA value a (operand 0 of complex.create).
+  if (auto createOp = getOperand().getDefiningOp<CreateOp>())
+    return createOp.getOperand(0);
   return {};
 }
 
@@ -32,5 +35,8 @@
   ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
   if (arrayAttr && arrayAttr.size() == 2)
     return arrayAttr[1];
+  // im(create(a, b)) folds to the SSA value b (operand 1 of complex.create).
+  if (auto createOp = getOperand().getDefiningOp<CreateOp>())
+    return createOp.getOperand(1);
   return {};
 }
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -1,19 +1,24 @@
 // RUN: mlir-opt %s -convert-complex-to-llvm | FileCheck %s
 
-// CHECK-LABEL: func @complex_numbers
-// CHECK-NEXT:    %[[REAL0:.*]] = constant 1.200000e+00 : f32
-// CHECK-NEXT:    %[[IMAG0:.*]] = constant 3.400000e+00 : f32
+// CHECK-LABEL: func @complex_create
+// CHECK-SAME:    (%[[REAL0:.*]]: f32, %[[IMAG0:.*]]: f32)
 // CHECK-NEXT:    %[[CPLX0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)>
 // CHECK-NEXT:    %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm.struct<(f32, f32)>
 // CHECK-NEXT:    %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm.struct<(f32, f32)>
-// CHECK-NEXT:    %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2:.*]][0] : !llvm.struct<(f32, f32)>
-// CHECK-NEXT:    %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2:.*]][1] : !llvm.struct<(f32, f32)>
-func @complex_numbers() {
-  %real0 = constant 1.2 : f32
-  %imag0 = constant 3.4 : f32
-  %cplx2 = complex.create %real0, %imag0 : complex<f32>
-  %real1 = complex.re%cplx2 : complex<f32>
-  %imag1 = complex.im %cplx2 : complex<f32>
+func @complex_create(%real: f32, %imag: f32) -> complex<f32> {
+  %cplx2 = complex.create %real, %imag : complex<f32>
+  return %cplx2 : complex<f32>
+}
+
+// CHECK-LABEL: func @complex_extract
+// CHECK-SAME:    (%[[CPLX:.*]]: complex<f32>)
+// CHECK-NEXT:    %[[CAST0:.*]] = llvm.mlir.cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[REAL:.*]] = llvm.extractvalue %[[CAST0]][0] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[CAST1:.*]] = llvm.mlir.cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)>
+// CHECK-NEXT:    %[[IMAG:.*]] = llvm.extractvalue %[[CAST1]][1] : !llvm.struct<(f32, f32)>
+func @complex_extract(%cplx: complex<f32>) {
+  %real1 = complex.re %cplx : complex<f32>
+  %imag1 = complex.im %cplx : complex<f32>
   return
 }
 
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -9,6 +9,17 @@
   return %1 : f32
 }
 
+// CHECK-LABEL: func @real_of_create_op(
+func @real_of_create_op() -> f32 {
+  // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
+  // CHECK-NEXT: return %[[CST]] : f32
+  %real = constant 1.0 : f32
+  %imag = constant 0.0 : f32
+  %complex = complex.create %real, %imag : complex<f32>
+  %1 = complex.re %complex : complex<f32>
+  return %1 : f32
+}
+
 // CHECK-LABEL: func @imag_of_const(
 func @imag_of_const() -> f32 {
   // CHECK: %[[CST:.*]] = constant 0.000000e+00 : f32
@@ -17,3 +28,14 @@
   %1 = complex.im %complex : complex<f32>
   return %1 : f32
 }
+
+// CHECK-LABEL: func @imag_of_create_op(
+func @imag_of_create_op() -> f32 {
+  // CHECK: %[[CST:.*]] = constant 0.000000e+00 : f32
+  // CHECK-NEXT: return %[[CST]] : f32
+  %real = constant 1.0 : f32
+  %imag = constant 0.0 : f32
+  %complex = complex.create %real, %imag : complex<f32>
+  %1 = complex.im %complex : complex<f32>
+  return %1 : f32
+}