diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td @@ -82,15 +82,46 @@ let assemblyFormat = "operands attr-dict `:` type($result)"; } +// Base class for OpenCL ternary ops. +class SPV_OCLTernaryOp traits = []> : + SPV_OCLOp { + + let arguments = (ins + SPV_ScalarOrVectorOf:$x, + SPV_ScalarOrVectorOf:$y, + SPV_ScalarOrVectorOf:$z + ); + + let results = (outs + SPV_ScalarOrVectorOf:$result + ); + + let hasVerifier = 0; +} + +// Base class for OpenCL ternary arithmetic ops where operand types and +// return type match. +class SPV_OCLTernaryArithmeticOp traits = []> : + SPV_OCLTernaryOp { + let assemblyFormat = "operands attr-dict `:` type($result)"; +} + + // ----- -def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> { +def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> { let summary = [{ - Error function of x encountered in integrating the normal distribution. + Compute the correctly rounded floating-point representation of the sum + of c with the infinitely precise product of a and b. Rounding of + intermediate products shall not occur. Edge case results are per the + IEEE 754-2008 standard. }]; let description = [{ - Result Type and x must be floating-point or vector(2,3,4,8,16) of + Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of floating-point values. 
All of the operands, including the Result Type operand, must be of the @@ -99,17 +130,13 @@ ``` - float-scalar-vector-type ::= float-type | - `vector<` integer-literal `x` float-type `>` - erf-op ::= ssa-id `=` `spv.OCL.erf` ssa-use `:` + fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:` float-scalar-vector-type ```mlir - #### Example: - ``` - %2 = spv.OCL.erf %0 : f32 - %3 = spv.OCL.erf %1 : vector<3xf16> + %0 = spv.OCL.fma %a, %b, %c : f32 + %1 = spv.OCL.fma %a, %b, %c : vector<3xf16> ``` }]; } @@ -179,6 +206,38 @@ // ----- +def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> { + let summary = [{ + Error function of x encountered in integrating the normal distribution. + }]; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + erf-op ::= ssa-id `=` `spv.OCL.erf` ssa-use `:` + float-scalar-vector-type + ```mlir + + #### Example: + + ``` + %2 = spv.OCL.erf %0 : f32 + %3 = spv.OCL.erf %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_OCLExpOp : SPV_OCLUnaryArithmeticOp<"exp", 19, SPV_Float> { let summary = "Exponentiation of Operand 1"; diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -92,13 +92,13 @@ spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern>( + spirv::ElementwiseOpPattern>( typeConverter, patterns.getContext()); // OpenCL 
patterns @@ -109,6 +109,7 @@ spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir @@ -76,7 +76,7 @@ return } - // CHECK-LABEL: @float32_ternary_scalar +// CHECK-LABEL: @float32_ternary_scalar func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) { // CHECK: spv.GLSL.Fma %{{.*}}: f32 %0 = math.fma %a, %b, %c : f32 diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir @@ -78,4 +78,19 @@ return } +// CHECK-LABEL: @float32_ternary_scalar +func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) { + // CHECK: spv.OCL.fma %{{.*}}: f32 + %0 = math.fma %a, %b, %c : f32 + return +} + +// CHECK-LABEL: @float32_ternary_vector +func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>, + %c: vector<4xf32>) { + // CHECK: spv.OCL.fma %{{.*}}: vector<4xf32> + %0 = math.fma %a, %b, %c : vector<4xf32> + return +} + } // end module diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -166,3 +166,22 @@ return } +// ----- + +//===----------------------------------------------------------------------===// +// spv.OCL.fma +//===----------------------------------------------------------------------===// + +func @fma(%a : f32, %b : f32, %c : f32) -> () { + // CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32 + %2 = spv.OCL.fma %a, 
%b, %c : f32 + return +} + +// ----- + +func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () { + // CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32> + %2 = spv.OCL.fma %a, %b, %c : vector<3xf32> + return +} diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir --- a/mlir/test/Target/SPIRV/ocl-ops.mlir +++ b/mlir/test/Target/SPIRV/ocl-ops.mlir @@ -38,4 +38,10 @@ %0 = spv.OCL.fabs %arg0 : vector<16xf32> spv.Return } + + spv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" { + // CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32 + %13 = spv.OCL.fma %arg0, %arg1, %arg2 : f32 + spv.Return + } } diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -51,7 +51,7 @@ doc = {} if settings.gen_ocl_ops: - section_anchor = spirv.find('h2', {'id': '_a_id_binary_a_binary_form'}) + section_anchor = spirv.find('h2', {'id': '_binary_form'}) for section in section_anchor.parent.find_all('div', {'class': 'sect2'}): for table in section.find_all('table'): inst_html = table.tbody.tr.td