diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td @@ -486,6 +486,39 @@ // ----- +def SPIRV_CLRintOp : SPIRV_CLUnaryArithmeticOp<"rint", 53, SPIRV_Float> { + let summary = [{ + Round x to integral value (using round to nearest even rounding mode) in + floating-point format. + }]; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + rint-op ::= ssa-id `=` `spirv.CL.rint` ssa-use `:` + float-scalar-vector-type + ``` + + #### Example: + + ```mlir + %2 = spirv.CL.rint %0 : f32 + %3 = spirv.CL.rint %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPIRV_CLRsqrtOp : SPIRV_CLUnaryArithmeticOp<"rsqrt", 56, SPIRV_Float> { let summary = "Compute inverse square root of x."; @@ -688,6 +721,8 @@ }]; } +// ----- + def SPIRV_CLSMinOp : SPIRV_CLBinaryArithmeticOp<"s_min", 158, SPIRV_Integer> { let summary = "Return minimum of two signed integer operands"; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td @@ -456,10 +456,13 @@ // ----- def SPIRV_GLRoundOp: SPIRV_GLUnaryArithmeticOp<"Round", 1, SPIRV_Float> { - let summary = "Rounds to the whole number"; + let summary = "Rounds to the nearest whole number"; let description = [{ - Result is the value equal to the nearest whole number. + Result is the value equal to the nearest whole number to x. The fraction + 0.5 will round in a direction chosen by the implementation, presumably + the direction that is fastest. 
This includes the possibility that + Round x is the same value as RoundEven x for all values of x. The operand x must be a scalar or vector whose component type is floating-point. @@ -471,7 +474,7 @@ ``` float-scalar-vector-type ::= float-type | `vector<` integer-literal `x` float-type `>` - floor-op ::= ssa-id `=` `spirv.GL.Round` ssa-use `:` + round-op ::= ssa-id `=` `spirv.GL.Round` ssa-use `:` float-scalar-vector-type ``` #### Example: @@ -485,6 +488,38 @@ // ----- +def SPIRV_GLRoundEvenOp: SPIRV_GLUnaryArithmeticOp<"RoundEven", 2, SPIRV_Float> { + let summary = "Rounds to the nearest even whole number"; + + let description = [{ + Result is the value equal to the nearest whole number to x. A fractional + part of 0.5 will round toward the nearest even whole number. (Both 3.5 and + 4.5 for x will be 4.0.) + + The operand x must be a scalar or vector whose component type is + floating-point. + + Result Type and the type of x must be the same type. Results are computed + per component. + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + round-even-op ::= ssa-id `=` `spirv.GL.RoundEven` ssa-use `:` + float-scalar-vector-type + ``` + #### Example: + + ```mlir + %2 = spirv.GL.RoundEven %0 : f32 + %3 = spirv.GL.RoundEven %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPIRV_GLInverseSqrtOp : SPIRV_GLUnaryArithmeticOp<"InverseSqrt", 32, SPIRV_Float> { let summary = "Reciprocal of sqrt(operand)"; diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -295,6 +295,7 @@ spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, @@ -312,6 +313,7 @@ spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, 
spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir @@ -20,20 +20,22 @@ // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}} // CHECK: spirv.GL.Log %[[ADDONE]] %4 = math.log1p %arg0 : f32 + // CHECK: spirv.GL.RoundEven %{{.*}}: f32 + %5 = math.roundeven %arg0 : f32 // CHECK: spirv.GL.InverseSqrt %{{.*}}: f32 - %5 = math.rsqrt %arg0 : f32 + %6 = math.rsqrt %arg0 : f32 // CHECK: spirv.GL.Sqrt %{{.*}}: f32 - %6 = math.sqrt %arg0 : f32 + %7 = math.sqrt %arg0 : f32 // CHECK: spirv.GL.Tanh %{{.*}}: f32 - %7 = math.tanh %arg0 : f32 + %8 = math.tanh %arg0 : f32 // CHECK: spirv.GL.Sin %{{.*}}: f32 - %8 = math.sin %arg0 : f32 + %9 = math.sin %arg0 : f32 // CHECK: spirv.GL.FAbs %{{.*}}: f32 - %9 = math.absf %arg0 : f32 + %10 = math.absf %arg0 : f32 // CHECK: spirv.GL.Ceil %{{.*}}: f32 - %10 = math.ceil %arg0 : f32 + %11 = math.ceil %arg0 : f32 // CHECK: spirv.GL.Floor %{{.*}}: f32 - %11 = math.floor %arg0 : f32 + %12 = math.floor %arg0 : f32 return } @@ -53,14 +55,16 @@ // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}} // CHECK: spirv.GL.Log %[[ADDONE]] %4 = math.log1p %arg0 : vector<3xf32> + // CHECK: spirv.GL.RoundEven %{{.*}}: vector<3xf32> + %5 = math.roundeven %arg0 : vector<3xf32> // CHECK: spirv.GL.InverseSqrt %{{.*}}: vector<3xf32> - %5 = math.rsqrt %arg0 : vector<3xf32> + %6 = math.rsqrt %arg0 : vector<3xf32> // CHECK: spirv.GL.Sqrt %{{.*}}: vector<3xf32> - %6 = math.sqrt %arg0 : vector<3xf32> + %7 = math.sqrt %arg0 : vector<3xf32> // CHECK: spirv.GL.Tanh %{{.*}}: vector<3xf32> - %7 = math.tanh %arg0 : vector<3xf32> + %8 = math.tanh %arg0 : vector<3xf32> // CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32> - %8 = math.sin %arg0 : 
vector<3xf32> + %9 = math.sin %arg0 : vector<3xf32> return } diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir @@ -288,6 +288,22 @@ return } +//===----------------------------------------------------------------------===// +// spirv.GL.RoundEven +//===----------------------------------------------------------------------===// + +func.func @round_even(%arg0 : f32) -> () { + // CHECK: spirv.GL.RoundEven {{%.*}} : f32 + %2 = spirv.GL.RoundEven %arg0 : f32 + return +} + +func.func @round_even_vec(%arg0 : vector<3xf16>) -> () { + // CHECK: spirv.GL.RoundEven {{%.*}} : vector<3xf16> + %2 = spirv.GL.RoundEven %arg0 : vector<3xf16> + return +} + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -227,3 +227,23 @@ %4 = spirv.CL.u_min %arg0, %arg1 : i32 return } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.CL.rint +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func.func @rint( +func.func @rint(%arg0 : f32) -> () { + // CHECK: spirv.CL.rint {{%.*}} : f32 + %0 = spirv.CL.rint %arg0 : f32 + return +} + +// CHECK-LABEL: func.func @rintvec( +func.func @rintvec(%arg0 : vector<3xf16>) -> () { + // CHECK: spirv.CL.rint {{%.*}} : vector<3xf16> + %0 = spirv.CL.rint %arg0 : vector<3xf16> + return +}