Summary: let spv.ConvertFToS, spv.ConvertFToU, spv.ConvertSToF and
spv.ConvertUToF accept operand and result types whose bit widths differ.

* SPIRVCastOps.td: each of the four ops overrides its verifier with
  `verifyCastOp(this->getOperation(), false, true)`.
* SPIRVOps.cpp: verifyCastOp gains a trailing `skipBitWidthCheck`
  parameter. It defaults to false, so callers that do not pass it keep
  the existing checks. When it is true, the function returns success()
  before reading the operand or result type.
* Serialization/cast-ops.mlir: adds round-trip cases f64 -> i32
  (ConvertFToS, ConvertFToU) and i64 -> f32 (ConvertSToF, ConvertUToF).
* ops.mlir: adds test cases for the same mixed-width scalar conversions.
  It also adds vector<3x...> cases for ConvertFToS, ConvertSToF and
  ConvertUToF; the ConvertFToU vector case already exists as context.
  Finally, it deletes the negative test that expected an f16 -> i32
  ConvertFToU to be rejected with a "same bit widths" error.

NOTE(review): once skipBitWidthCheck is true, the requireSameBitWidth
argument is never consulted. The `false` passed by the .td verifiers
therefore has no effect; consider dropping it at the call sites or
documenting why it is there.
NOTE(review): in this copy the patch's line breaks and indentation appear
to have been collapsed. It will not apply as-is; the original newlines and
whitespace must be restored first.

diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td @@ -122,6 +122,8 @@ %3 = spv.ConvertFToS %2 : vector<3xf32> to vector<3xi32> ``` }]; + + let verifier = [{ return verifyCastOp(this->getOperation(), false, true); }]; } // ----- @@ -155,6 +157,8 @@ %3 = spv.ConvertFToU %2 : vector<3xf32> to vector<3xi32> ``` }]; + + let verifier = [{ return verifyCastOp(this->getOperation(), false, true); }]; } // ----- @@ -186,6 +190,8 @@ %3 = spv.ConvertSToF %2 : vector<3xi32> to vector<3xf32> ``` }]; + + let verifier = [{ return verifyCastOp(this->getOperation(), false, true); }]; } // ----- @@ -217,6 +223,8 @@ %3 = spv.ConvertUToF %2 : vector<3xi32> to vector<3xf32> ``` }]; + + let verifier = [{ return verifyCastOp(this->getOperation(), false, true); }]; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -305,7 +305,12 @@ } static LogicalResult verifyCastOp(Operation *op, - bool requireSameBitWidth = true) { + bool requireSameBitWidth = true, + bool skipBitWidthCheck = false) { + // Some CastOps have no limit on bit widths for result and operand type.
+ if (skipBitWidthCheck) + return success(); + Type operandType = op->getOperand(0).getType(); Type resultType = op->getResult(0).getType(); diff --git a/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir @@ -20,21 +20,41 @@ %0 = spv.ConvertFToS %arg0 : f32 to i32 spv.ReturnValue %0 : i32 } + spv.func @convert_f64_to_s32(%arg0 : f64) -> i32 "None" { + // CHECK: {{%.*}} = spv.ConvertFToS {{%.*}} : f64 to i32 + %0 = spv.ConvertFToS %arg0 : f64 to i32 + spv.ReturnValue %0 : i32 + } spv.func @convert_f_to_u(%arg0 : f32) -> i32 "None" { // CHECK: {{%.*}} = spv.ConvertFToU {{%.*}} : f32 to i32 %0 = spv.ConvertFToU %arg0 : f32 to i32 spv.ReturnValue %0 : i32 } + spv.func @convert_f64_to_u32(%arg0 : f64) -> i32 "None" { + // CHECK: {{%.*}} = spv.ConvertFToU {{%.*}} : f64 to i32 + %0 = spv.ConvertFToU %arg0 : f64 to i32 + spv.ReturnValue %0 : i32 + } spv.func @convert_s_to_f(%arg0 : i32) -> f32 "None" { // CHECK: {{%.*}} = spv.ConvertSToF {{%.*}} : i32 to f32 %0 = spv.ConvertSToF %arg0 : i32 to f32 spv.ReturnValue %0 : f32 } + spv.func @convert_s64_to_f32(%arg0 : i64) -> f32 "None" { + // CHECK: {{%.*}} = spv.ConvertSToF {{%.*}} : i64 to f32 + %0 = spv.ConvertSToF %arg0 : i64 to f32 + spv.ReturnValue %0 : f32 + } spv.func @convert_u_to_f(%arg0 : i32) -> f32 "None" { // CHECK: {{%.*}} = spv.ConvertUToF {{%.*}} : i32 to f32 %0 = spv.ConvertUToF %arg0 : i32 to f32 spv.ReturnValue %0 : f32 } + spv.func @convert_u64_to_f32(%arg0 : i64) -> f32 "None" { + // CHECK: {{%.*}} = spv.ConvertUToF {{%.*}} : i64 to f32 + %0 = spv.ConvertUToF %arg0 : i64 to f32 + spv.ReturnValue %0 : f32 + } spv.func @f_convert(%arg0 : f32) -> f64 "None" { // CHECK: {{%.*}} = spv.FConvert {{%.*}} : f32 to f64 %0 = spv.FConvert %arg0 : f32 to f64 diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- 
a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -335,6 +335,22 @@ // ----- +func @convert_f64_to_s32_scalar(%arg0 : f64) -> i32 { + // CHECK: {{%.*}} = spv.ConvertFToS {{%.*}} : f64 to i32 + %0 = spv.ConvertFToS %arg0 : f64 to i32 + spv.ReturnValue %0 : i32 +} + +// ----- + +func @convert_f_to_s_vector(%arg0 : vector<3xf32>) -> vector<3xi32> { + // CHECK: {{%.*}} = spv.ConvertFToS {{%.*}} : vector<3xf32> to vector<3xi32> + %0 = spv.ConvertFToS %arg0 : vector<3xf32> to vector<3xi32> + spv.ReturnValue %0 : vector<3xi32> +} + +// ----- + //===----------------------------------------------------------------------===// // spv.ConvertFToU //===----------------------------------------------------------------------===// @@ -347,6 +363,14 @@ // ----- +func @convert_f64_to_u32_scalar(%arg0 : f64) -> i32 { + // CHECK: {{%.*}} = spv.ConvertFToU {{%.*}} : f64 to i32 + %0 = spv.ConvertFToU %arg0 : f64 to i32 + spv.ReturnValue %0 : i32 +} + +// ----- + func @convert_f_to_u_vector(%arg0 : vector<3xf32>) -> vector<3xi32> { // CHECK: {{%.*}} = spv.ConvertFToU {{%.*}} : vector<3xf32> to vector<3xi32> %0 = spv.ConvertFToU %arg0 : vector<3xf32> to vector<3xi32> @@ -363,14 +387,6 @@ // ----- -func @convert_f_to_u_scalar_invalid(%arg0 : f16) -> i32 { - // expected-error @+1 {{expected the same bit widths for operand type and result type, but provided 'f16' and 'i32'}} - %0 = spv.ConvertFToU %arg0 : f16 to i32 - spv.ReturnValue %0 : i32 -} - -// ----- - //===----------------------------------------------------------------------===// // spv.ConvertSToF //===----------------------------------------------------------------------===// @@ -383,6 +399,22 @@ // ----- +func @convert_s64_to_f32_scalar(%arg0 : i64) -> f32 { + // CHECK: {{%.*}} = spv.ConvertSToF {{%.*}} : i64 to f32 + %0 = spv.ConvertSToF %arg0 : i64 to f32 + spv.ReturnValue %0 : f32 +} + +// ----- + +func @convert_s_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> { + // CHECK: {{%.*}} = 
spv.ConvertSToF {{%.*}} : vector<3xi32> to vector<3xf32> + %0 = spv.ConvertSToF %arg0 : vector<3xi32> to vector<3xf32> + spv.ReturnValue %0 : vector<3xf32> +} + +// ----- + //===----------------------------------------------------------------------===// // spv.ConvertUToF //===----------------------------------------------------------------------===// @@ -395,6 +427,22 @@ // ----- +func @convert_u64_to_f32_scalar(%arg0 : i64) -> f32 { + // CHECK: {{%.*}} = spv.ConvertUToF {{%.*}} : i64 to f32 + %0 = spv.ConvertUToF %arg0 : i64 to f32 + spv.ReturnValue %0 : f32 +} + +// ----- + +func @convert_u_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> { + // CHECK: {{%.*}} = spv.ConvertUToF {{%.*}} : vector<3xi32> to vector<3xf32> + %0 = spv.ConvertUToF %arg0 : vector<3xi32> to vector<3xf32> + spv.ReturnValue %0 : vector<3xf32> +} + +// ----- + //===----------------------------------------------------------------------===// // spv.FConvert //===----------------------------------------------------------------------===//