diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -217,6 +217,21 @@
   return success(folded);
 }
 
+//===----------------------------------------------------------------------===//
+// Common cast compatibility check for vector types
+//===----------------------------------------------------------------------===//
+
+/// True if `a` and `b` are same-shape vectors with cast-compatible elements.
+static bool areVectorCastSimpleCompatible(
+    Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
+  if (auto va = a.dyn_cast<VectorType>())
+    if (auto vb = b.dyn_cast<VectorType>())
+      return va.getShape().equals(vb.getShape()) &&
+             areElementsCastCompatible(va.getElementType(),
+                                       vb.getElementType());
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
@@ -1848,11 +1863,7 @@
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() < fb.getWidth();
-  if (auto va = a.dyn_cast<VectorType>())
-    if (auto vb = b.dyn_cast<VectorType>())
-      return va.getShape().equals(vb.getShape()) &&
-             areCastCompatible(va.getElementType(), vb.getElementType());
-  return false;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1860,7 +1871,9 @@
 //===----------------------------------------------------------------------===//
 
 bool FPToSIOp::areCastCompatible(Type a, Type b) {
-  return a.isa<FloatType>() && b.isSignlessInteger();
+  if (a.isa<FloatType>() && b.isSignlessInteger())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1868,7 +1881,9 @@
 //===----------------------------------------------------------------------===//
 
 bool FPToUIOp::areCastCompatible(Type a, Type b) {
-  return a.isa<FloatType>() && b.isSignlessInteger();
+  if (a.isa<FloatType>() && b.isSignlessInteger())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1879,11 +1894,7 @@
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() > fb.getWidth();
-  if (auto va = a.dyn_cast<VectorType>())
-    if (auto vb = b.dyn_cast<VectorType>())
-      return va.getShape().equals(vb.getShape()) &&
-             areCastCompatible(va.getElementType(), vb.getElementType());
-  return false;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2323,7 +2334,9 @@
 
 // sitofp is applicable from integer types to float types.
 bool SIToFPOp::areCastCompatible(Type a, Type b) {
-  return a.isSignlessInteger() && b.isa<FloatType>();
+  if (a.isSignlessInteger() && b.isa<FloatType>())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2403,7 +2416,9 @@
 
 // uitofp is applicable from integer types to float types.
 bool UIToFPOp::areCastCompatible(Type a, Type b) {
-  return a.isSignlessInteger() && b.isa<FloatType>();
+  if (a.isSignlessInteger() && b.isa<FloatType>())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -594,6 +594,24 @@
   return
 }
 
+// Checking conversion of signed integer vectors to floating point vector types.
+// CHECK-LABEL: @sitofp_vector
+func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
+  %0 = sitofp %arg0: vector<2xi16> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
+  %1 = sitofp %arg0: vector<2xi16> to vector<2xf64>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
+  %2 = sitofp %arg1: vector<2xi32> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
+  %3 = sitofp %arg1: vector<2xi32> to vector<2xf64>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
+  %4 = sitofp %arg2: vector<2xi64> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
+  %5 = sitofp %arg2: vector<2xi64> to vector<2xf64>
+  return
+}
+
 // Checking conversion of unsigned integer types to floating point.
 // CHECK-LABEL: @uitofp
 func @uitofp(%arg0 : i32, %arg1 : i64) {
@@ -646,6 +664,24 @@
   return
 }
 
+// Checking conversion of floating point vectors to signed integer vector types.
+// CHECK-LABEL: @fptosi_vector
+func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
+  %0 = fptosi %arg0: vector<2xf16> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
+  %1 = fptosi %arg0: vector<2xf16> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
+  %2 = fptosi %arg1: vector<2xf32> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
+  %3 = fptosi %arg1: vector<2xf32> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
+  %4 = fptosi %arg2: vector<2xf64> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
+  %5 = fptosi %arg2: vector<2xf64> to vector<2xi64>
+  return
+}
+
 // Checking conversion of floating point to integer types.
 // CHECK-LABEL: @fptoui
 func @fptoui(%arg0 : f32, %arg1 : f64) {
@@ -660,6 +696,41 @@
   return
 }
 
+// Checking conversion of floating point vectors to unsigned integer vector types.
+// CHECK-LABEL: @fptoui_vector
+func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
+  %0 = fptoui %arg0: vector<2xf16> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
+  %1 = fptoui %arg0: vector<2xf16> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
+  %2 = fptoui %arg1: vector<2xf32> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
+  %3 = fptoui %arg1: vector<2xf32> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
+  %4 = fptoui %arg2: vector<2xf64> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
+  %5 = fptoui %arg2: vector<2xf64> to vector<2xi64>
+  return
+}
+
+// Checking conversion of unsigned integer vectors to floating point vector types.
+// CHECK-LABEL: @uitofp_vector
+func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
+  %0 = uitofp %arg0: vector<2xi16> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
+  %1 = uitofp %arg0: vector<2xi16> to vector<2xf64>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
+  %2 = uitofp %arg1: vector<2xi32> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
+  %3 = uitofp %arg1: vector<2xi32> to vector<2xf64>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
+  %4 = uitofp %arg2: vector<2xi64> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
+  %5 = uitofp %arg2: vector<2xi64> to vector<2xf64>
+  return
+}
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc