diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -2430,10 +2430,11 @@ def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> { let summary = "cast from integer type to floating-point"; let description = [{ - Cast from a value interpreted as signed integer to the corresponding - floating-point value. If the value cannot be exactly represented, it is - rounded using the default rounding mode. Only scalars are currently - supported. + Cast from a value interpreted as signed integer or vector of signed + integers to the corresponding scalar or vector floating-point value. + If the value cannot be exactly represented, it is + rounded using the default rounding mode. Scalars and vector types are + currently supported. }]; let extraClassDeclaration = [{ @@ -3111,10 +3112,11 @@ def UIToFPOp : CastOp<"uitofp">, Arguments<(ins AnyType:$in)> { let summary = "cast from unsigned integer type to floating-point"; let description = [{ - Cast from a value interpreted as unsigned integer to the corresponding - floating-point value. If the value cannot be exactly represented, it is - rounded using the default rounding mode. Only scalars are currently - supported. + Cast from a value interpreted as unsigned integer or vector of unsigned + integers to the corresponding scalar or vector floating-point value. + If the value cannot be exactly represented, it is + rounded using the default rounding mode. Scalars and vector types are + currently supported. 
}]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -217,6 +217,25 @@ return success(folded); } +//===----------------------------------------------------------------------===// +// Common cast compatibility check for vector types +//===----------------------------------------------------------------------===// + +/// Returns true if 'a' and 'b' are both vector types that have the same shape +/// and whose element types are cast compatible. Element compatibility is +/// delegated to 'areElementsCastCompatible', which callers set to their own +/// areCastCompatible so the scalar rules of the calling op are reused. Any +/// non-vector operand makes the result false. +static bool areVectorCastSimpleCompatible( + Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) { + if (auto va = a.dyn_cast<VectorType>()) + if (auto vb = b.dyn_cast<VectorType>()) + return va.getShape().equals(vb.getShape()) && + areElementsCastCompatible(va.getElementType(), + vb.getElementType()); + return false; +} + //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// @@ -1848,11 +1867,7 @@ if (auto fa = a.dyn_cast<FloatType>()) if (auto fb = b.dyn_cast<FloatType>()) return fa.getWidth() < fb.getWidth(); - if (auto va = a.dyn_cast<VectorType>()) - if (auto vb = b.dyn_cast<VectorType>()) - return va.getShape().equals(vb.getShape()) && - areCastCompatible(va.getElementType(), vb.getElementType()); - return false; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// @@ -1860,7 +1875,9 @@ //===----------------------------------------------------------------------===// bool FPToSIOp::areCastCompatible(Type a, Type b) { - return a.isa<FloatType>() && b.isSignlessInteger(); + if (a.isa<FloatType>() 
&& b.isSignlessInteger()) + return true; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// @@ -1868,7 +1885,9 @@ //===----------------------------------------------------------------------===// bool FPToUIOp::areCastCompatible(Type a, Type b) { - return a.isa<FloatType>() && b.isSignlessInteger(); + if (a.isa<FloatType>() && b.isSignlessInteger()) + return true; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// @@ -1879,11 +1898,7 @@ if (auto fa = a.dyn_cast<FloatType>()) if (auto fb = b.dyn_cast<FloatType>()) return fa.getWidth() > fb.getWidth(); - if (auto va = a.dyn_cast<VectorType>()) - if (auto vb = b.dyn_cast<VectorType>()) - return va.getShape().equals(vb.getShape()) && - areCastCompatible(va.getElementType(), vb.getElementType()); - return false; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// @@ -2323,7 +2338,9 @@ // sitofp is applicable from integer types to float types. bool SIToFPOp::areCastCompatible(Type a, Type b) { - return a.isSignlessInteger() && b.isa<FloatType>(); + if (a.isSignlessInteger() && b.isa<FloatType>()) + return true; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// @@ -2403,7 +2420,9 @@ // uitofp is applicable from integer types to float types. 
bool UIToFPOp::areCastCompatible(Type a, Type b) { - return a.isSignlessInteger() && b.isa<FloatType>(); + if (a.isSignlessInteger() && b.isa<FloatType>()) + return true; + return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -594,6 +594,24 @@ return } +// Checking conversion of integer vectors to floating point vector types. +// CHECK-LABEL: @sitofp_vector +func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) { +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float> + %0 = sitofp %arg0: vector<2xi16> to vector<2xf32> +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double> + %1 = sitofp %arg0: vector<2xi16> to vector<2xf64> +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float> + %2 = sitofp %arg1: vector<2xi32> to vector<2xf32> +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double> + %3 = sitofp %arg1: vector<2xi32> to vector<2xf64> +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float> + %4 = sitofp %arg2: vector<2xi64> to vector<2xf32> +// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double> + %5 = sitofp %arg2: vector<2xi64> to vector<2xf64> + return +} + // Checking conversion of unsigned integer types to floating point. // CHECK-LABEL: @uitofp func @uitofp(%arg0 : i32, %arg1 : i64) { @@ -646,6 +664,24 @@ return } +// Checking conversion of floating point vectors to integer vector types. 
+// CHECK-LABEL: @fptosi_vector +func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) { +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32> + %0 = fptosi %arg0: vector<2xf16> to vector<2xi32> +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64> + %1 = fptosi %arg0: vector<2xf16> to vector<2xi64> +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32> + %2 = fptosi %arg1: vector<2xf32> to vector<2xi32> +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64> + %3 = fptosi %arg1: vector<2xf32> to vector<2xi64> +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32> + %4 = fptosi %arg2: vector<2xf64> to vector<2xi32> +// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64> + %5 = fptosi %arg2: vector<2xf64> to vector<2xi64> + return +} + // Checking conversion of floating point to integer types. // CHECK-LABEL: @fptoui func @fptoui(%arg0 : f32, %arg1 : f64) { @@ -660,6 +696,42 @@ return } +// Checking conversion of floating point vectors to integer vector types. 
+// CHECK-LABEL: @fptoui_vector +func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) { +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32> + %0 = fptoui %arg0: vector<2xf16> to vector<2xi32> +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64> + %1 = fptoui %arg0: vector<2xf16> to vector<2xi64> +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32> + %2 = fptoui %arg1: vector<2xf32> to vector<2xi32> +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64> + %3 = fptoui %arg1: vector<2xf32> to vector<2xi64> +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32> + %4 = fptoui %arg2: vector<2xf64> to vector<2xi32> +// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64> + %5 = fptoui %arg2: vector<2xf64> to vector<2xi64> + return +} + +// Checking conversion of unsigned integer vectors to floating point vector types. 
+// CHECK-LABEL: @uitofp_vector +func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) { +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float> + %0 = uitofp %arg0: vector<2xi16> to vector<2xf32> +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double> + %1 = uitofp %arg0: vector<2xi16> to vector<2xf64> +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float> + %2 = uitofp %arg1: vector<2xi32> to vector<2xf32> +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double> + %3 = uitofp %arg1: vector<2xi32> to vector<2xf64> +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float> + %4 = uitofp %arg2: vector<2xi64> to vector<2xf32> +// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double> + %5 = uitofp %arg2: vector<2xi64> to vector<2xf64> + return +} + // Checking conversion of integer types to floating point. // CHECK-LABEL: @fptrunc