diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td @@ -86,8 +86,10 @@ // ----- -def SPV_OCLTanhOp : SPV_OCLUnaryArithmeticOp<"tanh", 63, SPV_Float> { - let summary = "Compute hyperbolic tangent of x radians."; +def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> { + let summary = [{ + Error function of x encountered in integrating the normal distribution. + }]; let description = [{ Result Type and x must be floating-point or vector(2,3,4,8,16) of @@ -101,15 +103,15 @@ ``` float-scalar-vector-type ::= float-type | `vector<` integer-literal `x` float-type `>` - tanh-op ::= ssa-id `=` `spv.OCL.tanh` ssa-use `:` + erf-op ::= ssa-id `=` `spv.OCL.erf` ssa-use `:` float-scalar-vector-type ```mlir #### Example: ``` - %2 = spv.OCL.tanh %0 : f32 - %3 = spv.OCL.tanh %1 : vector<3xf16> + %2 = spv.OCL.erf %0 : f32 + %3 = spv.OCL.erf %1 : vector<3xf16> ``` }]; } @@ -423,6 +425,36 @@ // ----- +def SPV_OCLTanhOp : SPV_OCLUnaryArithmeticOp<"tanh", 63, SPV_Float> { + let summary = "Compute hyperbolic tangent of x radians."; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + tanh-op ::= ssa-id `=` `spv.OCL.tanh` ssa-use `:` + float-scalar-vector-type + ```mlir + + #### Example: + + ``` + %2 = spv.OCL.tanh %0 : f32 + %3 = spv.OCL.tanh %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_OCLSAbsOp : SPV_OCLUnaryArithmeticOp<"s_abs", 141, SPV_Integer> { let summary = "Absolute value of operand"; diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -84,6 +84,7 @@ spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir @@ -28,6 +28,8 @@ %9 = math.ceil %arg0 : f32 // CHECK: spv.OCL.floor %{{.*}}: f32 %10 = math.floor %arg0 : f32 + // CHECK: spv.OCL.erf %{{.*}}: f32 + %11 = math.erf %arg0 : f32 return } diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir --- a/mlir/test/Target/SPIRV/ocl-ops.mlir +++ b/mlir/test/Target/SPIRV/ocl-ops.mlir @@ -22,6 +22,8 @@ %8 = spv.OCL.pow %arg0, %arg0 : f32 // CHECK: {{%.*}} = spv.OCL.rsqrt {{%.*}} : f32 %9 = spv.OCL.rsqrt %arg0 : f32 + // CHECK: {{%.*}} = spv.OCL.erf {{%.*}} : f32 + %10 = spv.OCL.erf %arg0 : f32 spv.Return }