diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3075,6 +3075,7 @@
     SPV_AnyCooperativeMatrix, SPV_AnyMatrix
   ]>;
 
+def SPV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
 def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
 
 class SPV_CoopMatrixOfType<list<Type> allowedTypes> :
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td
@@ -77,6 +77,28 @@
                                  list<OpTrait> traits = []> :
   SPV_GLSLBinaryOp<mnemonic, opcode, type, type, traits>;
 
+// Base class for GLSL ternary ops whose operands and result share one type.
+class SPV_GLSLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
+                                  list<OpTrait> traits = []> :
+  SPV_GLSLOp<mnemonic, opcode, !listconcat([NoSideEffect], traits)> {
+
+  let arguments = (ins
+    SPV_ScalarOrVectorOf<type>:$x,
+    SPV_ScalarOrVectorOf<type>:$y,
+    SPV_ScalarOrVectorOf<type>:$z
+  );
+
+  let results = (outs
+    SPV_ScalarOrVectorOf<type>:$result
+  );
+
+  let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
+
+  let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
+
+  let verifier = [{ return success(); }];
+}
+
 // -----
 
 def SPV_GLSLFAbsOp : SPV_GLSLUnaryArithmeticOp<"FAbs", 4, SPV_Float> {
@@ -862,4 +884,92 @@
   }];
 }
 
+// -----
+
+def SPV_GLSLFClampOp : SPV_GLSLTernaryArithmeticOp<"FClamp", 43, SPV_Float> {
+  let summary = "Clamp x between min and max values.";
+
+  let description = [{
+    Result is min(max(x, minVal), maxVal). The resulting value is undefined if
+    minVal > maxVal. The semantics used by min() and max() are those of FMin and
+    FMax.
+
+    The operands must all be a scalar or vector whose component type is
+    floating-point.
+
+    Result Type and the type of all operands must be the same type. Results are
+    computed per component.
+
+    <!-- End of AutoGen section -->
+    ```
+    fclamp-op ::= ssa-id `=` `spv.GLSL.FClamp` ssa-use, ssa-use, ssa-use `:`
+               float-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %2 = spv.GLSL.FClamp %x, %min, %max : f32
+    %3 = spv.GLSL.FClamp %x, %min, %max : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_GLSLUClampOp : SPV_GLSLTernaryArithmeticOp<"UClamp", 44, SPV_SignlessOrUnsignedInt> {
+  let summary = "Clamp x between min and max values.";
+
+  let description = [{
+    Result is min(max(x, minVal), maxVal), where x, minVal and maxVal are
+    interpreted as unsigned integers. The resulting value is undefined if
+    minVal > maxVal.
+
+    Result Type and the type of the operands must both be integer scalar or
+    integer vector types. Result Type and operand types must have the same number
+    of components with the same component width. Results are computed per
+    component.
+
+    <!-- End of AutoGen section -->
+    ```
+    uclamp-op ::= ssa-id `=` `spv.GLSL.UClamp` ssa-use, ssa-use, ssa-use `:`
+               unsigned-signless-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %2 = spv.GLSL.UClamp %x, %min, %max : i32
+    %3 = spv.GLSL.UClamp %x, %min, %max : vector<3xui16>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_GLSLSClampOp : SPV_GLSLTernaryArithmeticOp<"SClamp", 45, SPV_SignedInt> {
+  let summary = "Clamp x between min and max values.";
+
+  let description = [{
+    Result is min(max(x, minVal), maxVal), where x, minVal and maxVal are
+    interpreted as signed integers. The resulting value is undefined if
+    minVal > maxVal.
+
+    Result Type and the type of the operands must both be integer scalar or
+    integer vector types. Result Type and operand types must have the same number
+    of components with the same component width. Results are computed per
+    component.
+
+    <!-- End of AutoGen section -->
+    ```
+    sclamp-op ::= ssa-id `=` `spv.GLSL.SClamp` ssa-use, ssa-use, ssa-use `:`
+               signed-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %2 = spv.GLSL.SClamp %x, %min, %max : si32
+    %3 = spv.GLSL.SClamp %x, %min, %max : vector<3xsi16>
+    ```
+  }];
+}
+
 #endif // SPIRV_GLSL_OPS
diff --git a/mlir/test/Dialect/SPIRV/Serialization/glsl-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/glsl-ops.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/glsl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/glsl-ops.mlir
@@ -30,4 +30,22 @@
     %12 = spv.GLSL.Round %arg0 : f32
     spv.Return
   }
+
+  spv.func @fclamp(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" {
+    // CHECK: spv.GLSL.FClamp {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
+    %13 = spv.GLSL.FClamp %arg0, %arg1, %arg2 : f32
+    spv.Return
+  }
+
+  spv.func @uclamp(%arg0 : ui32, %arg1 : ui32, %arg2 : ui32) "None" {
+    // CHECK: spv.GLSL.UClamp {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : i32
+    %13 = spv.GLSL.UClamp %arg0, %arg1, %arg2 : ui32
+    spv.Return
+  }
+
+  spv.func @sclamp(%arg0 : si32, %arg1 : si32, %arg2 : si32) "None" {
+    // CHECK: spv.GLSL.SClamp {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : si32
+    %13 = spv.GLSL.SClamp %arg0, %arg1, %arg2 : si32
+    spv.Return
+  }
 }
diff --git a/mlir/test/Dialect/SPIRV/glslops.mlir b/mlir/test/Dialect/SPIRV/glslops.mlir
--- a/mlir/test/Dialect/SPIRV/glslops.mlir
+++ b/mlir/test/Dialect/SPIRV/glslops.mlir
@@ -269,3 +269,79 @@
   %2 = spv.GLSL.Round %arg0 : vector<3xf16>
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.FClamp
+//===----------------------------------------------------------------------===//
+
+func @fclamp(%arg0 : f32, %min : f32, %max : f32) -> () {
+  // CHECK: spv.GLSL.FClamp {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
+  %2 = spv.GLSL.FClamp %arg0, %min, %max : f32
+  return
+}
+
+// -----
+
+func @fclampvec(%arg0 : vector<3xf32>, %min : vector<3xf32>, %max : vector<3xf32>) -> () {
+  // CHECK: spv.GLSL.FClamp {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32>
+  %2 = spv.GLSL.FClamp %arg0, %min, %max : vector<3xf32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.UClamp
+//===----------------------------------------------------------------------===//
+
+func @uclamp(%arg0 : ui32, %min : ui32, %max : ui32) -> () {
+  // CHECK: spv.GLSL.UClamp {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : ui32
+  %2 = spv.GLSL.UClamp %arg0, %min, %max : ui32
+  return
+}
+
+// -----
+
+func @uclampvec(%arg0 : vector<4xi32>, %min : vector<4xi32>, %max : vector<4xi32>) -> () {
+  // CHECK: spv.GLSL.UClamp {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<4xi32>
+  %2 = spv.GLSL.UClamp %arg0, %min, %max : vector<4xi32>
+  return
+}
+
+// -----
+
+func @uclamp_signed_operand(%arg0 : si32, %min : si32, %max : si32) -> () {
+  // expected-error @+1 {{must be 8/16/32/64-bit signless/unsigned integer or vector}}
+  %2 = spv.GLSL.UClamp %arg0, %min, %max : si32
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.SClamp
+//===----------------------------------------------------------------------===//
+
+func @sclamp(%arg0 : si32, %min : si32, %max : si32) -> () {
+  // CHECK: spv.GLSL.SClamp {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : si32
+  %2 = spv.GLSL.SClamp %arg0, %min, %max : si32
+  return
+}
+
+// -----
+
+func @sclampvec(%arg0 : vector<4xsi32>, %min : vector<4xsi32>, %max : vector<4xsi32>) -> () {
+  // CHECK: spv.GLSL.SClamp {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<4xsi32>
+  %2 = spv.GLSL.SClamp %arg0, %min, %max : vector<4xsi32>
+  return
+}
+
+// -----
+
+func @sclamp_signless_operand(%arg0 : i32, %min : i32, %max : i32) -> () {
+  // expected-error @+1 {{must be 8/16/32/64-bit signed integer or vector}}
+  %2 = spv.GLSL.SClamp %arg0, %min, %max : i32
+  return
+}