diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -286,12 +286,13 @@ // Convert to the LLVM IR dialect using the converter defined above. OwningRewritePatternList patterns; LLVMTypeConverter converter(&getContext()); - populateStdToLLVMConversionPatterns(converter, patterns); populateComplexToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); - target.addLegalOp(); - if (failed(applyFullConversion(module, target, std::move(patterns)))) + target.addLegalOp(); + target.addLegalOp(); + target.addIllegalDialect(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1375,6 +1375,17 @@ return success(); } + // Complex types are compatible with the two-element structs. + if (auto complexType = type.dyn_cast()) { + auto structType = llvmType.dyn_cast(); + if (!structType || structType.getBody().size() != 2 || + structType.getBody()[0] != structType.getBody()[1] || + structType.getBody()[0] != complexType.getElementType()) + return op->emitOpError("expected 'complex' to map to two-element struct " + "with identical element types"); + return success(); + } + // Everything else is not supported. return op->emitError("unsupported cast"); } diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir @@ -1,14 +1,13 @@ -// RUN: mlir-opt %s -split-input-file -convert-complex-to-llvm | FileCheck %s +// RUN: mlir-opt %s -convert-complex-to-llvm | FileCheck %s -// CHECK-LABEL: llvm.func @complex_numbers() -// CHECK-NEXT: %[[REAL0:.*]] = llvm.mlir.constant(1.200000e+00 : f32) : f32 -// CHECK-NEXT: %[[IMAG0:.*]] = llvm.mlir.constant(3.400000e+00 : f32) : f32 +// CHECK-LABEL: func @complex_numbers +// CHECK-NEXT: %[[REAL0:.*]] = constant 1.200000e+00 : f32 +// CHECK-NEXT: %[[IMAG0:.*]] = constant 3.400000e+00 : f32 // CHECK-NEXT: %[[CPLX0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)> // CHECK-NEXT: %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm.struct<(f32, f32)> // CHECK-NEXT: %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm.struct<(f32, f32)> // CHECK-NEXT: %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2:.*]][0] : !llvm.struct<(f32, f32)> // CHECK-NEXT: %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2:.*]][1] : !llvm.struct<(f32, f32)> -// CHECK-NEXT: llvm.return func @complex_numbers() { %real0 = constant 1.2 : f32 %imag0 = constant 3.4 : f32 @@ -18,9 +17,7 @@ return } -// ----- - -// CHECK-LABEL: llvm.func @complex_addition() +// CHECK-LABEL: func @complex_addition // CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> // CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> // CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)> @@ -41,9 +38,7 @@ return } -// ----- - -// CHECK-LABEL: llvm.func @complex_substraction() +// CHECK-LABEL: func @complex_substraction // CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> // CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> // CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)> @@ -64,18 +59,19 @@ return } -// ----- - -// CHECK-LABEL: llvm.func @complex_div -// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]] +// CHECK-LABEL: func @complex_div +// CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex func @complex_div(%lhs: complex, %rhs: complex) -> complex { %div = complex.div %lhs, %rhs : complex return %div : complex } -// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]] -// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]] -// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]] -// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]] +// CHECK: %[[CASTED_LHS:.*]] = llvm.mlir.cast %[[LHS]] : complex to ![[C_TY:.*>]] +// CHECK: %[[CASTED_RHS:.*]] = llvm.mlir.cast %[[RHS]] : complex to ![[C_TY]] + +// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] +// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]] +// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]] +// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]] // CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]] @@ -95,20 +91,23 @@ // CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] // CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] : f32 // CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]] -// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]] - -// ----- +// +// CHECK: %[[CASTED_RESULT:.*]] = llvm.mlir.cast %[[RESULT_2]] : ![[C_TY]] to complex +// CHECK: return %[[CASTED_RESULT]] : complex -// CHECK-LABEL: llvm.func @complex_mul -// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]] +// CHECK-LABEL: func @complex_mul +// CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex func @complex_mul(%lhs: complex, %rhs: complex) -> complex { %mul = complex.mul %lhs, %rhs : complex return %mul : complex } -// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]] -// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]] -// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]] -// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]] +// CHECK: %[[CASTED_LHS:.*]] = llvm.mlir.cast %[[LHS]] : complex to ![[C_TY:.*>]] +// CHECK: %[[CASTED_RHS:.*]] = llvm.mlir.cast %[[RHS]] : complex to ![[C_TY]] + +// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] +// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]] +// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]] +// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]] // CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]] // CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] : f32 @@ -121,21 +120,22 @@ // CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] // CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] -// CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]] -// ----- +// CHECK: %[[CASTED_RESULT:.*]] = llvm.mlir.cast %[[RESULT_2]] : ![[C_TY]] to complex +// CHECK: return %[[CASTED_RESULT]] : complex -// CHECK-LABEL: llvm.func @complex_abs -// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]]) +// CHECK-LABEL: func @complex_abs +// CHECK-SAME: %[[ARG:.*]]: complex func @complex_abs(%arg: complex) -> f32 { %abs = complex.abs %arg: complex return %abs : f32 } -// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]] -// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]] +// CHECK: %[[CASTED_ARG:.*]] = llvm.mlir.cast %[[ARG]] : complex to ![[C_TY:.*>]] +// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[CASTED_ARG]][0] : ![[C_TY]] +// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[CASTED_ARG]][1] : ![[C_TY]] // CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] : f32 // CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] : f32 // CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] : f32 // CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32 -// CHECK: llvm.return %[[NORM]] : f32 +// CHECK: return %[[NORM]] : f32 diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir copy from mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir copy to mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir --- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/ComplexToLLVM/full-conversion.mlir @@ -1,70 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -convert-complex-to-llvm | FileCheck %s - -// CHECK-LABEL: llvm.func @complex_numbers() -// CHECK-NEXT: %[[REAL0:.*]] = llvm.mlir.constant(1.200000e+00 : f32) : f32 -// CHECK-NEXT: %[[IMAG0:.*]] = llvm.mlir.constant(3.400000e+00 : f32) : f32 -// CHECK-NEXT: %[[CPLX0:.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[CPLX1:.*]] = llvm.insertvalue %[[REAL0]], %[[CPLX0]][0] : !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[CPLX2:.*]] = llvm.insertvalue %[[IMAG0]], %[[CPLX1]][1] : !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[REAL1:.*]] = llvm.extractvalue %[[CPLX2:.*]][0] : !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[IMAG1:.*]] = llvm.extractvalue %[[CPLX2:.*]][1] : !llvm.struct<(f32, f32)> -// CHECK-NEXT: llvm.return -func @complex_numbers() { - %real0 = constant 1.2 : f32 - %imag0 = constant 3.4 : f32 - %cplx2 = complex.create %real0, %imag0 : complex - %real1 = complex.re%cplx2 : complex - %imag1 = complex.im %cplx2 : complex - return -} - -// ----- - -// CHECK-LABEL: llvm.func @complex_addition() -// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)> -// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[C_REAL:.*]] = llvm.fadd %[[A_REAL]], %[[B_REAL]] : f64 -// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fadd %[[A_IMAG]], %[[B_IMAG]] : f64 -// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)> -// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)> -func @complex_addition() { - %a_re = constant 1.2 : f64 - %a_im = constant 3.4 : f64 - %a = complex.create %a_re, %a_im : complex - %b_re = constant 5.6 : f64 - %b_im = constant 7.8 : f64 - %b = complex.create %b_re, %b_im : complex - %c = complex.add %a, %b : complex - return -} - -// ----- - -// CHECK-LABEL: llvm.func @complex_substraction() -// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)> -// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)> -// CHECK-DAG: %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] : f64 -// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] : f64 -// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)> -// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)> -func @complex_substraction() { - %a_re = constant 1.2 : f64 - %a_im = constant 3.4 : f64 - %a = complex.create %a_re, %a_im : complex - %b_re = constant 5.6 : f64 - %b_im = constant 7.8 : f64 - %b = complex.create %b_re, %b_im : complex - %c = complex.sub %a, %b : complex - return -} - -// ----- +// RUN: mlir-opt %s -convert-complex-to-llvm -convert-std-to-llvm | FileCheck %s // CHECK-LABEL: llvm.func @complex_div // CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]] @@ -97,8 +31,6 @@ // CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]] // CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]] -// ----- - // CHECK-LABEL: llvm.func @complex_mul // CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*>]], %[[RHS:.*]]: ![[C_TY]]) -> ![[C_TY]] func @complex_mul(%lhs: complex, %rhs: complex) -> complex { @@ -123,8 +55,6 @@ // CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] // CHECK: llvm.return %[[RESULT_2]] : ![[C_TY]] -// ----- - // CHECK-LABEL: llvm.func @complex_abs // CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]]) func @complex_abs(%arg: complex) -> f32 { diff --git a/mlir/test/Dialect/LLVMIR/dialect-cast.mlir b/mlir/test/Dialect/LLVMIR/dialect-cast.mlir --- a/mlir/test/Dialect/LLVMIR/dialect-cast.mlir +++ b/mlir/test/Dialect/LLVMIR/dialect-cast.mlir @@ -222,3 +222,31 @@ // expected-error@+1 {{expected second element of a memref descriptor to be an !llvm.ptr}} llvm.mlir.cast %0 : memref<*xf32> to !llvm.struct<(i64, f32)> } + +// ----- + +func @mlir_dialect_cast_complex_non_struct(%0: complex) { + // expected-error@+1 {{expected 'complex' to map to two-element struct with identical element types}} + llvm.mlir.cast %0 : complex to f32 +} + +// ----- + +func @mlir_dialect_cast_complex_bad_size(%0: complex) { + // expected-error@+1 {{expected 'complex' to map to two-element struct with identical element types}} + llvm.mlir.cast %0 : complex to !llvm.struct<(f32, f32, f32)> +} + +// ----- + +func @mlir_dialect_cast_complex_mismatching_type_struct(%0: complex) { + // expected-error@+1 {{expected 'complex' to map to two-element struct with identical element types}} + llvm.mlir.cast %0 : complex to !llvm.struct<(f32, f64)> +} + +// ----- + +func @mlir_dialect_cast_complex_mismatching_element(%0: complex) { + // expected-error@+1 {{expected 'complex' to map to two-element struct with identical element types}} + llvm.mlir.cast %0 : complex to !llvm.struct<(f64, f64)> +}