diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td @@ -332,6 +332,67 @@ // ----- +def SPV_CLFMaxOp : SPV_CLBinaryArithmeticOp<"fmax", 27, SPV_Float> { + let summary = "Return maximum of two floating-point operands"; + + let description = [{ + Returns y if x < y, otherwise it returns x. If one argument is a NaN, + Fmax returns the other argument. If both arguments are NaNs, Fmax returns a NaN. + + Result Type, x and y must be floating-point or vector(2,3,4,8,16) + of floating-point values. + + All of the operands, including the Result Type operand, + must be of the same type. + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fmax-op ::= ssa-id `=` `spv.CL.fmax` ssa-use `:` + float-scalar-vector-type + ``` + #### Example: + + ```mlir + %2 = spv.CL.fmax %0, %1 : f32 + %3 = spv.CL.fmax %0, %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_CLFMinOp : SPV_CLBinaryArithmeticOp<"fmin", 28, SPV_Float> { + let summary = "Return minimum of two floating-point operands"; + + let description = [{ + Returns y if y < x, otherwise it returns x. If one argument is a NaN, Fmin returns the other argument. + If both arguments are NaNs, Fmin returns a NaN. + + Result Type, x and y must be floating-point or vector(2,3,4,8,16) of floating-point values. + + All of the operands, including the Result Type operand, must be of the same type. 
+ + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fmin-op ::= ssa-id `=` `spv.CL.fmin` ssa-use `:` + float-scalar-vector-type + ``` + #### Example: + + ```mlir + %2 = spv.CL.fmin %0, %1 : f32 + %3 = spv.CL.fmin %0, %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_CLLogOp : SPV_CLUnaryArithmeticOp<"log", 37, SPV_Float> { let summary = "Compute the natural logarithm of x."; @@ -573,4 +634,112 @@ }]; } +// ----- + +def SPV_CLSMaxOp : SPV_CLBinaryArithmeticOp<"s_max", 156, SPV_Integer> { + let summary = "Return maximum of two signed integer operands"; + + let description = [{ + Returns y if x < y, otherwise it returns x, where x and y are treated as signed integers. + + Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values. + + All of the operands, including the Result Type operand, must be of the same type. + + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + smax-op ::= ssa-id `=` `spv.CL.s_max` ssa-use `:` + integer-scalar-vector-type + ``` + #### Example: + ```mlir + %2 = spv.CL.s_max %0, %1 : i32 + %3 = spv.CL.s_max %0, %1 : vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_CLUMaxOp : SPV_CLBinaryArithmeticOp<"u_max", 157, SPV_Integer> { + let summary = "Return maximum of two unsigned integer operands"; + + let description = [{ + Returns y if x < y, otherwise it returns x, where x and y are treated as unsigned integers. + + Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values. + + All of the operands, including the Result Type operand, must be of the same type. 
+ + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + umax-op ::= ssa-id `=` `spv.CL.u_max` ssa-use `:` + integer-scalar-vector-type + ``` + #### Example: + ```mlir + %2 = spv.CL.u_max %0, %1 : i32 + %3 = spv.CL.u_max %0, %1 : vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_CLSMinOp : SPV_CLBinaryArithmeticOp<"s_min", 158, SPV_Integer> { + let summary = "Return minimum of two signed integer operands"; + + let description = [{ + Returns y if y < x, otherwise it returns x, where x and y are treated as signed integers. + + Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values. + + All of the operands, including the Result Type operand, must be of the same type. + + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + smin-op ::= ssa-id `=` `spv.CL.s_min` ssa-use `:` + integer-scalar-vector-type + ``` + #### Example: + ```mlir + %2 = spv.CL.s_min %0, %1 : i32 + %3 = spv.CL.s_min %0, %1 : vector<3xi16> + ``` + }]; +} + +// ----- + +def SPV_CLUMinOp : SPV_CLBinaryArithmeticOp<"u_min", 159, SPV_Integer> { + let summary = "Return minimum of two unsigned integer operands"; + + let description = [{ + Returns y if y < x, otherwise it returns x, where x and y are treated as unsigned integers. + + Result Type, x and y must be integer or vector(2,3,4,8,16) of integer values. + + All of the operands, including the Result Type operand, must be of the same type. 
+ + + ``` + integer-scalar-vector-type ::= integer-type | + `vector<` integer-literal `x` integer-type `>` + umin-op ::= ssa-id `=` `spv.CL.u_min` ssa-use `:` + integer-scalar-vector-type + ``` + #### Example: + ```mlir + %2 = spv.CL.u_min %0, %1 : i32 + %3 = spv.CL.u_min %0, %1 : vector<3xi16> + ``` + }]; +} + #endif // MLIR_DIALECT_SPIRV_IR_CL_OPS diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -934,7 +934,14 @@ spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern + spirv::ElementwiseOpPattern, + + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern >(typeConverter, patterns.getContext()); // clang-format on diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -970,12 +970,58 @@ // ----- -// Check OpenCL lowering of arith.remsi +// Check various lowerings for OpenCL. module attributes { spv.target_env = #spv.target_env< #spv.vce, #spv.resource_limits<>> } { +// Check integer operation conversions. 
+// CHECK-LABEL: @int32_scalar +func.func @int32_scalar(%lhs: i32, %rhs: i32) { + // CHECK: spv.IAdd %{{.*}}, %{{.*}}: i32 + %0 = arith.addi %lhs, %rhs: i32 + // CHECK: spv.ISub %{{.*}}, %{{.*}}: i32 + %1 = arith.subi %lhs, %rhs: i32 + // CHECK: spv.IMul %{{.*}}, %{{.*}}: i32 + %2 = arith.muli %lhs, %rhs: i32 + // CHECK: spv.SDiv %{{.*}}, %{{.*}}: i32 + %3 = arith.divsi %lhs, %rhs: i32 + // CHECK: spv.UDiv %{{.*}}, %{{.*}}: i32 + %4 = arith.divui %lhs, %rhs: i32 + // CHECK: spv.UMod %{{.*}}, %{{.*}}: i32 + %5 = arith.remui %lhs, %rhs: i32 + // CHECK: spv.CL.s_max %{{.*}}, %{{.*}}: i32 + %6 = arith.maxsi %lhs, %rhs : i32 + // CHECK: spv.CL.u_max %{{.*}}, %{{.*}}: i32 + %7 = arith.maxui %lhs, %rhs : i32 + // CHECK: spv.CL.s_min %{{.*}}, %{{.*}}: i32 + %8 = arith.minsi %lhs, %rhs : i32 + // CHECK: spv.CL.u_min %{{.*}}, %{{.*}}: i32 + %9 = arith.minui %lhs, %rhs : i32 + return +} + +// Check float binary operation conversions. +// CHECK-LABEL: @float32_binary_scalar +func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) { + // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32 + %0 = arith.addf %lhs, %rhs: f32 + // CHECK: spv.FSub %{{.*}}, %{{.*}}: f32 + %1 = arith.subf %lhs, %rhs: f32 + // CHECK: spv.FMul %{{.*}}, %{{.*}}: f32 + %2 = arith.mulf %lhs, %rhs: f32 + // CHECK: spv.FDiv %{{.*}}, %{{.*}}: f32 + %3 = arith.divf %lhs, %rhs: f32 + // CHECK: spv.FRem %{{.*}}, %{{.*}}: f32 + %4 = arith.remf %lhs, %rhs: f32 + // CHECK: spv.CL.fmax %{{.*}}, %{{.*}}: f32 + %5 = arith.maxf %lhs, %rhs: f32 + // CHECK: spv.CL.fmin %{{.*}}, %{{.*}}: f32 + %6 = arith.minf %lhs, %rhs: f32 + return +} + // CHECK-LABEL: @scalar_srem // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) func.func @scalar_srem(%lhs: i32, %rhs: i32) { diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -185,3 +185,45 @@ %2 = spv.CL.fma %a, %b, %c : vector<3xf32> return } + +// ----- + 
+//===----------------------------------------------------------------------===// +// spv.CL.{F|S|U}{Max|Min} +//===----------------------------------------------------------------------===// + +func.func @fmaxmin(%arg0 : f32, %arg1 : f32) { + // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : f32 + %1 = spv.CL.fmax %arg0, %arg1 : f32 + // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : f32 + %2 = spv.CL.fmin %arg0, %arg1 : f32 + return +} + +func.func @fmaxminvec(%arg0 : vector<3xf16>, %arg1 : vector<3xf16>) { + // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : vector<3xf16> + %1 = spv.CL.fmax %arg0, %arg1 : vector<3xf16> + // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : vector<3xf16> + %2 = spv.CL.fmin %arg0, %arg1 : vector<3xf16> + return +} + +func.func @fmaxminf64(%arg0 : f64, %arg1 : f64) { + // CHECK: spv.CL.fmax {{%.*}}, {{%.*}} : f64 + %1 = spv.CL.fmax %arg0, %arg1 : f64 + // CHECK: spv.CL.fmin {{%.*}}, {{%.*}} : f64 + %2 = spv.CL.fmin %arg0, %arg1 : f64 + return +} + +func.func @iminmax(%arg0: i32, %arg1: i32) { + // CHECK: spv.CL.s_max {{%.*}}, {{%.*}} : i32 + %1 = spv.CL.s_max %arg0, %arg1 : i32 + // CHECK: spv.CL.u_max {{%.*}}, {{%.*}} : i32 + %2 = spv.CL.u_max %arg0, %arg1 : i32 + // CHECK: spv.CL.s_min {{%.*}}, {{%.*}} : i32 + %3 = spv.CL.s_min %arg0, %arg1 : i32 + // CHECK: spv.CL.u_min {{%.*}}, {{%.*}} : i32 + %4 = spv.CL.u_min %arg0, %arg1 : i32 + return +} diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir --- a/mlir/test/Target/SPIRV/ocl-ops.mlir +++ b/mlir/test/Target/SPIRV/ocl-ops.mlir @@ -44,4 +44,21 @@ %13 = spv.CL.fma %arg0, %arg1, %arg2 : f32 spv.Return } + + spv.func @maxmin(%arg0 : f32, %arg1 : f32, %arg2 : i32, %arg3 : i32) "None" { + // CHECK: {{%.*}} = spv.CL.fmax {{%.*}}, {{%.*}} : f32 + %1 = spv.CL.fmax %arg0, %arg1 : f32 + // CHECK: {{%.*}} = spv.CL.s_max {{%.*}}, {{%.*}} : i32 + %2 = spv.CL.s_max %arg2, %arg3 : i32 + // CHECK: {{%.*}} = spv.CL.u_max {{%.*}}, {{%.*}} : i32 + %3 = spv.CL.u_max %arg2, %arg3 : i32 + + // CHECK: 
{{%.*}} = spv.CL.fmin {{%.*}}, {{%.*}} : f32 + %4 = spv.CL.fmin %arg0, %arg1 : f32 + // CHECK: {{%.*}} = spv.CL.s_min {{%.*}}, {{%.*}} : i32 + %5 = spv.CL.s_min %arg2, %arg3 : i32 + // CHECK: {{%.*}} = spv.CL.u_min {{%.*}}, {{%.*}} : i32 + %6 = spv.CL.u_min %arg2, %arg3 : i32 + spv.Return + } }