// Patch summary (leading text before the first `diff --git` header).
//
// SPIRVOCLOps.td
//   * Removes the SPV_OCLFmaOp definition from its old position (hunk at old
//     line 110) and re-adds the same definition immediately before
//     SPV_OCLLogOp (hunk at new line 301).
//   * Adds SPV_OCLRoundOp, declared as
//     SPV_OCLUnaryArithmeticOp<"round", 55, SPV_Float>, immediately before
//     SPV_OCLRsqrtOp. Its summary describes rounding halfway cases away from
//     zero regardless of the current rounding direction.
//
// MathToSPIRV.cpp
//   * Adds the include "mlir/IR/TypeUtilities.h" (presumably for
//     getElementTypeOrSelf, which the new pattern calls -- confirm).
//   * Adds RoundOpPattern, documented as converting math.round to GLSL SPIRV
//     extended ops. It builds zero / one / 0.5 constants of the operand type
//     (a DenseElementsAttr splat when the type is a VectorType), then emits
//     abs -> floor -> (abs - floor) -> compare against 0.5 -> select(one, zero)
//     -> floor + select, and finally replaces the op via replaceOpWithNewOp
//     with (add, operand). The GLSL test added below pins the resulting op
//     sequence: FAbs, Floor, FSub, FOrdGreaterThanEqual, Select, FAdd, Bitcast.
//   * Registers RoundOpPattern in the list under "// GLSL patterns" and adds
//     one more spirv::ElementwiseOpPattern entry to a later pattern list
//     (presumably the OpenCL list, given the spv.OCL.round test -- confirm).
//
// Tests
//   * math-to-glsl-spirv.mlir: adds @round_scalar (f32) and @round_vector
//     (vector<4xf32>).
//   * math-to-opencl-spirv.mlir: adds a check that math.round on f32 becomes
//     spv.OCL.round.
//
// NOTE(review): in this copy of the patch the angle-bracket template arguments
//   appear to be missing (`OpConversionPattern`, `rewriter.create(`,
//   `ty.dyn_cast()`, `replaceOpWithNewOp(`, `patterns .add,`, and the
//   `spirv::ElementwiseOpPattern,` entries). This looks like extraction
//   damage; verify against the original patch before applying.
// NOTE(review): in the SPV_OCLFmaOp description the "```mlir" fence follows the
//   grammar block and the example sits under a bare "```", whereas the
//   SPV_OCLRoundOp description uses "#### Example:" followed by "```mlir".
//   The fma fences look misplaced; since this hunk is a verbatim move, fixing
//   them would be a separate change.
// NOTE(review): RoundOpPattern receives `adaptor` but reads its operand from
//   roundOp.getOperand(); conversion patterns usually take operands from the
//   adaptor -- confirm this is intentional.
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td @@ -110,36 +110,6 @@ } -// ----- - -def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> { - let summary = [{ - Compute the correctly rounded floating-point representation of the sum - of c with the infinitely precise product of a and b. Rounding of - intermediate products shall not occur. Edge case results are per the - IEEE 754-2008 standard. - }]; - - let description = [{ - Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of - floating-point values. - - All of the operands, including the Result Type operand, must be of the - same type. - - - - ``` - fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:` - float-scalar-vector-type - ```mlir - - ``` - %0 = spv.OCL.fma %a, %b, %c : f32 - %1 = spv.OCL.fma %a, %b, %c : vector<3xf16> - ``` - }]; -} // ----- @@ -331,6 +301,37 @@ // ----- +def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> { + let summary = [{ + Compute the correctly rounded floating-point representation of the sum + of c with the infinitely precise product of a and b. Rounding of + intermediate products shall not occur. Edge case results are per the + IEEE 754-2008 standard. + }]; + + let description = [{ + Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. 
+ + + + ``` + fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:` + float-scalar-vector-type + ```mlir + + ``` + %0 = spv.OCL.fma %a, %b, %c : f32 + %1 = spv.OCL.fma %a, %b, %c : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> { let summary = "Compute the natural logarithm of x."; @@ -392,6 +393,38 @@ // ----- +def SPV_OCLRoundOp : SPV_OCLUnaryArithmeticOp<"round", 55, SPV_Float> { + let summary = [{ + Return the integral value nearest to x rounding halfway cases away from + zero, regardless of the current rounding direction. + }]; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + round-op ::= ssa-id `=` `spv.OCL.round` ssa-use `:` + float-scalar-vector-type + ``` + #### Example: + + ```mlir + %2 = spv.OCL.round %0 : f32 + %3 = spv.OCL.round %0 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_OCLRsqrtOp : SPV_OCLUnaryArithmeticOp<"rsqrt", 56, SPV_Float> { let summary = "Compute inverse square root of x."; diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Debug.h" @@ -233,6 +234,43 @@ } }; +/// Converts math.round to GLSL SPIRV extended ops. 
+struct RoundOpPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = roundOp.getLoc(); + auto operand = roundOp.getOperand(); + auto ty = operand.getType(); + auto ety = getElementTypeOrSelf(ty); + + auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter); + auto one = spirv::ConstantOp::getOne(ty, loc, rewriter); + Value half; + if (VectorType vty = ty.dyn_cast()) { + half = rewriter.create( + loc, vty, + DenseElementsAttr::get(vty, + rewriter.getFloatAttr(ety, 0.5).getValue())); + } else { + half = rewriter.create( + loc, ty, rewriter.getFloatAttr(ety, 0.5)); + } + + auto abs = rewriter.create(loc, operand); + auto floor = rewriter.create(loc, abs); + auto sub = rewriter.create(loc, abs, floor); + auto greater = + rewriter.create(loc, sub, half); + auto select = rewriter.create(loc, greater, one, zero); + auto add = rewriter.create(loc, floor, select); + rewriter.replaceOpWithNewOp(roundOp, add, operand); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -248,7 +286,7 @@ // GLSL patterns patterns .add, - ExpM1OpPattern, PowFOpPattern, + ExpM1OpPattern, PowFOpPattern, RoundOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, @@ -273,6 +311,7 @@ spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, + spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir @@ -145,6 +145,38 @@ return %0: vector<4xf32> } +// 
CHECK-LABEL: @round_scalar +func.func @round_scalar(%x: f32) -> f32 { + // CHECK: %[[ZERO:.+]] = spv.Constant 0.000000e+00 + // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 + // CHECK: %[[HALF:.+]] = spv.Constant 5.000000e-01 + // CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %arg0 + // CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[ABS]] + // CHECK: %[[SUB:.+]] = spv.FSub %[[ABS]], %[[FLOOR]] + // CHECK: %[[GE:.+]] = spv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]] + // CHECK: %[[SEL:.+]] = spv.Select %[[GE]], %[[ONE]], %[[ZERO]] + // CHECK: %[[ADD:.+]] = spv.FAdd %[[FLOOR]], %[[SEL]] + // CHECK: %[[BITCAST:.+]] = spv.Bitcast %[[ADD]] + %0 = math.round %x : f32 + return %0: f32 +} + +// CHECK-LABEL: @round_vector +func.func @round_vector(%x: vector<4xf32>) -> vector<4xf32> { + // CHECK: %[[ZERO:.+]] = spv.Constant dense<0.000000e+00> + // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> + // CHECK: %[[HALF:.+]] = spv.Constant dense<5.000000e-01> + // CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %arg0 + // CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[ABS]] + // CHECK: %[[SUB:.+]] = spv.FSub %[[ABS]], %[[FLOOR]] + // CHECK: %[[GE:.+]] = spv.FOrdGreaterThanEqual %[[SUB]], %[[HALF]] + // CHECK: %[[SEL:.+]] = spv.Select %[[GE]], %[[ONE]], %[[ZERO]] + // CHECK: %[[ADD:.+]] = spv.FAdd %[[FLOOR]], %[[SEL]] + // CHECK: %[[BITCAST:.+]] = spv.Bitcast %[[ADD]] + %0 = math.round %x : vector<4xf32> + return %0: vector<4xf32> +} + } // end module // ----- diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir @@ -34,6 +34,8 @@ %11 = math.floor %arg0 : f32 // CHECK: spv.OCL.erf %{{.*}}: f32 %12 = math.erf %arg0 : f32 + // CHECK: spv.OCL.round %{{.*}}: f32 + %13 = math.round %arg0 : f32 return }