diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
@@ -110,36 +110,6 @@
 
 }
 
-// -----
-
-def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
-  let summary = [{
-    Compute the correctly rounded floating-point representation of the sum
-    of c with the infinitely precise product of a and b. Rounding of
-    intermediate products shall not occur. Edge case results are per the
-    IEEE 754-2008 standard.
-  }];
-
-  let description = [{
-    Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
-    floating-point values.
-
-    All of the operands, including the Result Type operand, must be of the
-    same type.
-
-
-
-    ```
-    fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
-               float-scalar-vector-type
-    ```mlir
-
-    ```
-    %0 = spv.OCL.fma %a, %b, %c : f32
-    %1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
-    ```
-  }];
-}
 
 // -----
 
@@ -331,6 +301,38 @@
 
 // -----
 
+def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
+  let summary = [{
+    Compute the correctly rounded floating-point representation of the sum
+    of c with the infinitely precise product of a and b. Rounding of
+    intermediate products shall not occur. Edge case results are per the
+    IEEE 754-2008 standard.
+  }];
+
+  let description = [{
+    Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+
+
+    ```
+    fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use, ssa-use, ssa-use `:`
+               float-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %0 = spv.OCL.fma %a, %b, %c : f32
+    %1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
   let summary = "Compute the natural logarithm of x.";
 
@@ -392,6 +394,38 @@
 
 // -----
 
+def SPV_OCLRoundOp : SPV_OCLUnaryArithmeticOp<"round", 55, SPV_Float> {
+  let summary = [{
+    Return the integral value nearest to x rounding halfway cases away from
+    zero, regardless of the current rounding direction.
+  }];
+
+  let description = [{
+    Result Type and x must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    round-op ::= ssa-id `=` `spv.OCL.round` ssa-use `:`
+                 float-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %2 = spv.OCL.round %0 : f32
+    %3 = spv.OCL.round %0 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_OCLRsqrtOp : SPV_OCLUnaryArithmeticOp<"rsqrt", 56, SPV_Float> {
   let summary = "Compute inverse square root of x.";
 
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/Debug.h"
 
@@ -233,6 +234,55 @@
   }
 };
 
+/// Converts math.round to GLSL SPIR-V ops.
+///
+/// GLSL.std.450 `Round` leaves the direction of halfway cases up to the
+/// implementation, while math.round rounds them away from zero, so the op is
+/// expanded as:
+///   r = floor(|x|)
+///   r = r + ((|x| - r) >= 0.5 ? 1 : 0)
+///   result = copysign(r, x)
+/// `|x| - floor(|x|)` is exact and no intermediate can overflow (scaling |x|
+/// by 2 overflows to infinity for huge finite inputs). NaN and infinities
+/// propagate, since the ordered compare is false for NaN.
+struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = roundOp.getLoc();
+    // Conversion patterns must build on the already-converted operand.
+    Value operand = adaptor.getOperand();
+    Type ty = operand.getType();
+    Type ety = getElementTypeOrSelf(ty);
+    if (!ety.isa<FloatType>())
+      return rewriter.notifyMatchFailure(roundOp, "expected float operand");
+
+    // Builds a scalar or splat-vector constant of the operand type.
+    auto createConstant = [&](double value) -> Value {
+      FloatAttr scalar = rewriter.getFloatAttr(ety, value);
+      if (auto vty = ty.dyn_cast<VectorType>())
+        return rewriter.create<spirv::ConstantOp>(
+            loc, vty, DenseElementsAttr::get(vty, scalar.getValue()));
+      return rewriter.create<spirv::ConstantOp>(loc, ty, scalar);
+    };
+    Value zero = createConstant(0.0);
+    Value half = createConstant(0.5);
+    Value one = createConstant(1.0);
+
+    Value abs = rewriter.create<spirv::GLSLFAbsOp>(loc, operand);
+    Value floor = rewriter.create<spirv::GLSLFloorOp>(loc, abs);
+    Value frac = rewriter.create<spirv::FSubOp>(loc, abs, floor);
+    Value roundUp =
+        rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, frac, half);
+    Value incr = rewriter.create<spirv::SelectOp>(loc, roundUp, one, zero);
+    Value magnitude = rewriter.create<spirv::FAddOp>(loc, floor, incr);
+    rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, magnitude, operand);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -248,7 +298,7 @@
   // GLSL patterns
   patterns
       .add,
-           ExpM1OpPattern, PowFOpPattern,
+           ExpM1OpPattern, PowFOpPattern, RoundOpPattern,
            spirv::ElementwiseOpPattern,
            spirv::ElementwiseOpPattern,
            spirv::ElementwiseOpPattern,
@@ -273,6 +323,7 @@
                spirv::ElementwiseOpPattern,
                spirv::ElementwiseOpPattern,
                spirv::ElementwiseOpPattern,
+               spirv::ElementwiseOpPattern<math::RoundOp, spirv::OCLRoundOp>,
                spirv::ElementwiseOpPattern,
                spirv::ElementwiseOpPattern,
                spirv::ElementwiseOpPattern,
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -54,6 +54,51 @@
   return success();
 }
 
+/// Expands math.round, which rounds halfway cases away from zero, into
+///   r = floor(|x|)
+///   r = r + ((|x| - r) >= 0.5 ? 1 : 0)
+///   result = copysign(r, x)
+/// The shorter `trunc(x + copysign(0.5, x))` is inexact: 0.49999997f + 0.5
+/// rounds to 1.0f and 8388609.0f + 0.5 rounds to 8388610.0f. Here
+/// `|x| - floor(|x|)` is exact and `r + 1` is only needed for non-integral
+/// |x|, so nothing is rounded. NaN and infinities propagate unchanged.
+static LogicalResult convertRoundOp(math::RoundOp op,
+                                    PatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+  Value operand = op.getOperand();
+  Type ty = operand.getType();
+
+  // Constants must have the operand's (possibly vector/tensor) type, so they
+  // are splatted for shaped operands, which requires a static shape.
+  Type elementTy = ty;
+  auto shapedTy = ty.dyn_cast<ShapedType>();
+  if (shapedTy) {
+    if (!shapedTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(op, "operand has a dynamic shape");
+    elementTy = shapedTy.getElementType();
+  }
+  auto createConstant = [&](double value) -> Value {
+    FloatAttr scalar = rewriter.getFloatAttr(elementTy, value);
+    if (shapedTy)
+      return rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(shapedTy, scalar.getValue()));
+    return rewriter.create<arith::ConstantOp>(loc, scalar);
+  };
+  Value zero = createConstant(0.0);
+  Value half = createConstant(0.5);
+  Value one = createConstant(1.0);
+
+  Value abs = rewriter.create<math::AbsOp>(loc, operand);
+  Value floor = rewriter.create<math::FloorOp>(loc, abs);
+  Value frac = rewriter.create<arith::SubFOp>(loc, abs, floor);
+  Value roundUp = rewriter.create<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OGE, frac, half);
+  Value incr = rewriter.create<arith::SelectOp>(loc, roundUp, one, zero);
+  Value magnitude = rewriter.create<arith::AddFOp>(loc, floor, incr);
+  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, magnitude, operand);
+  return success();
+}
+
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
                                    PatternRewriter &rewriter) {
   auto operand = op.getOperand();
@@ -115,6 +160,10 @@
   patterns.add(convertCtlzOp);
 }
 
+void mlir::populateExpandRoundPattern(RewritePatternSet &patterns) {
+  patterns.add(convertRoundOp);
+}
+
 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
   patterns.add(convertTanhOp);
 }
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -145,6 +145,36 @@
   return %0: vector<4xf32>
 }
 
+// CHECK-LABEL: @round_scalar
+func.func @round_scalar(%x: f32) -> f32 {
+  // CHECK-DAG: %[[ZERO:.+]] = spv.Constant 0.000000e+00
+  // CHECK-DAG: %[[HALF:.+]] = spv.Constant 5.000000e-01
+  // CHECK-DAG: %[[ONE:.+]] = spv.Constant 1.000000e+00
+  // CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %arg0
+  // CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[ABS]]
+  // CHECK: %[[FRAC:.+]] = spv.FSub %[[ABS]], %[[FLOOR]]
+  // CHECK: %[[GE:.+]] = spv.FOrdGreaterThanEqual %[[FRAC]], %[[HALF]]
+  // CHECK: %[[SEL:.+]] = spv.Select %[[GE]], %[[ONE]], %[[ZERO]]
+  // CHECK: %[[ADD:.+]] = spv.FAdd %[[FLOOR]], %[[SEL]]
+  %0 = math.round %x : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @round_vector
+func.func @round_vector(%x: vector<4xf32>) -> vector<4xf32> {
+  // CHECK-DAG: %[[ZERO:.+]] = spv.Constant dense<0.000000e+00>
+  // CHECK-DAG: %[[HALF:.+]] = spv.Constant dense<5.000000e-01>
+  // CHECK-DAG: %[[ONE:.+]] = spv.Constant dense<1.000000e+00>
+  // CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %arg0
+  // CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[ABS]]
+  // CHECK: %[[FRAC:.+]] = spv.FSub %[[ABS]], %[[FLOOR]]
+  // CHECK: %[[GE:.+]] = spv.FOrdGreaterThanEqual %[[FRAC]], %[[HALF]]
+  // CHECK: %[[SEL:.+]] = spv.Select %[[GE]], %[[ONE]], %[[ZERO]]
+  // CHECK: %[[ADD:.+]] = spv.FAdd %[[FLOOR]], %[[SEL]]
+  %0 = math.round %x : vector<4xf32>
+  return %0: vector<4xf32>
+}
+
 } // end module
 
 // -----
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -34,6 +34,8 @@
   %11 = math.floor %arg0 : f32
   // CHECK: spv.OCL.erf %{{.*}}: f32
   %12 = math.erf %arg0 : f32
+  // CHECK: spv.OCL.round %{{.*}}: f32
+  %13 = math.round %arg0 : f32
   return
 }
 