diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2443,10 +2443,10 @@
 def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from integer type to floating-point";
   let description = [{
-    Cast from a value interpreted as signed integer to the corresponding
-    floating-point value. If the value cannot be exactly represented, it is
-    rounded using the default rounding mode. Only scalars are currently
-    supported.
+    Cast from a value interpreted as a signed integer or vector of signed
+    integers to the corresponding floating-point scalar or vector value. If the
+    value cannot be exactly represented, it is rounded using the default
+    rounding mode. Scalar and vector types are currently supported.
   }];
 
   let extraClassDeclaration = [{
@@ -3124,10 +3124,10 @@
 def UIToFPOp : CastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from unsigned integer type to floating-point";
   let description = [{
-    Cast from a value interpreted as unsigned integer to the corresponding
-    floating-point value. If the value cannot be exactly represented, it is
-    rounded using the default rounding mode. Only scalars are currently
-    supported.
+    Cast from a value interpreted as an unsigned integer or vector of unsigned
+    integers to the corresponding floating-point scalar or vector value. If the
+    value cannot be exactly represented, it is rounded using the default
+    rounding mode. Scalar and vector types are currently supported.
   }];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -217,6 +217,26 @@
   return success(folded);
 }
 
+//===----------------------------------------------------------------------===//
+// Common cast compatibility check for vector types.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether 'a' and 'b' are both vector types of identical shape whose
+/// element types are accepted by 'areElementsCastCompatible'. Cast ops use
+/// this as the vector fallback of their 'areCastCompatible' hook, passing
+/// the hook itself so the scalar rule is applied to the element types.
+/// Returns 'true' only if both types are vectors, their shapes are equal and
+/// the element types are cast compatible, and 'false' otherwise.
+static bool areVectorCastSimpleCompatible(
+    Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
+  if (auto va = a.dyn_cast<VectorType>())
+    if (auto vb = b.dyn_cast<VectorType>())
+      return va.getShape().equals(vb.getShape()) &&
+             areElementsCastCompatible(va.getElementType(),
+                                       vb.getElementType());
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
@@ -1816,11 +1836,7 @@
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() < fb.getWidth();
-  if (auto va = a.dyn_cast<VectorType>())
-    if (auto vb = b.dyn_cast<VectorType>())
-      return va.getShape().equals(vb.getShape()) &&
-             areCastCompatible(va.getElementType(), vb.getElementType());
-  return false;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1828,7 +1844,9 @@
 //===----------------------------------------------------------------------===//
 
 bool FPToSIOp::areCastCompatible(Type a, Type b) {
-  return a.isa<FloatType>() && b.isSignlessInteger();
+  if (a.isa<FloatType>() && b.isSignlessInteger())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1836,7 +1854,9 @@
 //===----------------------------------------------------------------------===//
 
 bool FPToUIOp::areCastCompatible(Type a, Type b) {
-  return a.isa<FloatType>() && b.isSignlessInteger();
+  if (a.isa<FloatType>() && b.isSignlessInteger())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1847,11 +1867,7 @@
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() > fb.getWidth();
-  if (auto va = a.dyn_cast<VectorType>())
-    if (auto vb = b.dyn_cast<VectorType>())
-      return va.getShape().equals(vb.getShape()) &&
-             areCastCompatible(va.getElementType(), vb.getElementType());
-  return false;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2291,7 +2307,9 @@
 
 // sitofp is applicable from integer types to float types.
 bool SIToFPOp::areCastCompatible(Type a, Type b) {
-  return a.isSignlessInteger() && b.isa<FloatType>();
+  if (a.isSignlessInteger() && b.isa<FloatType>())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2371,7 +2389,9 @@
 
 // uitofp is applicable from integer types to float types.
 bool UIToFPOp::areCastCompatible(Type a, Type b) {
-  return a.isSignlessInteger() && b.isa<FloatType>();
+  if (a.isSignlessInteger() && b.isa<FloatType>())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -594,6 +594,24 @@
   return
 }
 
+// Checking conversion of signed integer vectors to floating point vectors.
+// CHECK-LABEL: @sitofp_vector
+func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
+  %0 = sitofp %arg0: vector<2xi16> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
+  %1 = sitofp %arg0: vector<2xi16> to vector<2xf64>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
+  %2 = sitofp %arg1: vector<2xi32> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
+  %3 = sitofp %arg1: vector<2xi32> to vector<2xf64>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
+  %4 = sitofp %arg2: vector<2xi64> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
+  %5 = sitofp %arg2: vector<2xi64> to vector<2xf64>
+  return
+}
+
 // Checking conversion of unsigned integer types to floating point.
 // CHECK-LABEL: @uitofp
 func @uitofp(%arg0 : i32, %arg1 : i64) {
@@ -646,6 +664,24 @@
   return
 }
 
+// Checking conversion of floating point vectors to signed integer vectors.
+// CHECK-LABEL: @fptosi_vector
+func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
+  %0 = fptosi %arg0: vector<2xf16> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
+  %1 = fptosi %arg0: vector<2xf16> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
+  %2 = fptosi %arg1: vector<2xf32> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
+  %3 = fptosi %arg1: vector<2xf32> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
+  %4 = fptosi %arg2: vector<2xf64> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
+  %5 = fptosi %arg2: vector<2xf64> to vector<2xi64>
+  return
+}
+
 // Checking conversion of floating point to integer types.
 // CHECK-LABEL: @fptoui
 func @fptoui(%arg0 : f32, %arg1 : f64) {
@@ -660,6 +696,41 @@
   return
 }
 
+// Checking conversion of floating point vectors to unsigned integer vectors.
+// CHECK-LABEL: @fptoui_vector
+func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
+  %0 = fptoui %arg0: vector<2xf16> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
+  %1 = fptoui %arg0: vector<2xf16> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
+  %2 = fptoui %arg1: vector<2xf32> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
+  %3 = fptoui %arg1: vector<2xf32> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
+  %4 = fptoui %arg2: vector<2xf64> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
+  %5 = fptoui %arg2: vector<2xf64> to vector<2xi64>
+  return
+}
+
+// Checking conversion of unsigned integer vectors to floating point vectors.
+// CHECK-LABEL: @uitofp_vector
+func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
+  %0 = uitofp %arg0: vector<2xi16> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
+  %1 = uitofp %arg0: vector<2xi16> to vector<2xf64>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
+  %2 = uitofp %arg1: vector<2xi32> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
+  %3 = uitofp %arg1: vector<2xi32> to vector<2xf64>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
+  %4 = uitofp %arg2: vector<2xi64> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
+  %5 = uitofp %arg2: vector<2xi64> to vector<2xf64>
+  return
+}
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc