diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -201,6 +201,56 @@
     return success();
   }
 };
+
+/// Converts math.powf to SPIRV-Ops.
+///
+/// GLSL Pow is undefined for a negative base, so compute Pow(|x|, y) and
+/// negate it when x < 0 and y is odd. The parity of y comes from its i32
+/// conversion, so this matches C `pow` only for integral y that fits in 32
+/// bits; e.g. a negative base with a fractional exponent yields
+/// +/-Pow(|x|, y) here rather than NaN.
+struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstType = getTypeConverter()->convertType(powfOp.getType());
+    if (!dstType)
+      return failure();
+
+    // An i32 type with the same shape as the exponent, used for the parity
+    // test below.
+    Type intType = rewriter.getIntegerType(32);
+    if (auto vectorType = adaptor.getRhs().getType().dyn_cast<VectorType>())
+      intType = VectorType::get(vectorType.getShape(), intType);
+
+    // Per GLSL Pow extended instruction spec:
+    // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
+    Location loc = powfOp.getLoc();
+    Value zero =
+        spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter);
+    Value lessThan =
+        rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
+    Value abs = rewriter.create<spirv::GLSLFAbsOp>(loc, adaptor.getLhs());
+    Value pow = rewriter.create<spirv::GLSLPowOp>(loc, abs, adaptor.getRhs());
+    Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
+
+    // A negative base gives a negative result only for odd exponents, e.g.
+    // powf(-2, 3) == -8 but powf(-2, 2) == 4.
+    Value intRhs =
+        rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
+    Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
+    Value lowBit = rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
+    Value isOdd = rewriter.create<spirv::IEqualOp>(loc, lowBit, intOne);
+    Value shouldNegate =
+        rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
+                                                 pow);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -216,7 +266,7 @@
   // GLSL patterns
   patterns
       .add,
-      ExpM1OpPattern,
+      ExpM1OpPattern, PowFOpPattern,
       spirv::ElementwiseOpPattern,
       spirv::ElementwiseOpPattern,
       spirv::ElementwiseOpPattern,
@@ -224,7 +274,6 @@
       spirv::ElementwiseOpPattern,
       spirv::ElementwiseOpPattern,
       spirv::ElementwiseOpPattern,
-      spirv::ElementwiseOpPattern,
       spirv::ElementwiseOpPattern,
       spirv::ElementwiseOpPattern,
       spirv::ElementwiseOpPattern,
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -64,20 +64,6 @@
   return
 }
 
-// CHECK-LABEL: @float32_binary_scalar
-func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
-  // CHECK: spv.GLSL.Pow %{{.*}}: f32
-  %0 = math.powf %lhs, %rhs : f32
-  return
-}
-
-// CHECK-LABEL: @float32_binary_vector
-func.func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
-  // CHECK: spv.GLSL.Pow %{{.*}}: vector<4xf32>
-  %0 = math.powf %lhs, %rhs : vector<4xf32>
-  return
-}
-
 // CHECK-LABEL: @float32_ternary_scalar
 func.func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
   // CHECK: spv.GLSL.Fma %{{.*}}: f32
@@ -133,6 +119,40 @@
   return %0 : vector<2xi32>
 }
 
+// CHECK-LABEL: @powf_scalar
+// CHECK-SAME: (%[[LHS:.+]]: f32, %[[RHS:.+]]: f32)
+func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
+  // CHECK: %[[F0:.+]] = spv.Constant 0.000000e+00 : f32
+  // CHECK: %[[LT:.+]] = spv.FOrdLessThan %[[LHS]], %[[F0]] : f32
+  // CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %[[LHS]] : f32
+  // CHECK: %[[POW:.+]] = spv.GLSL.Pow %[[ABS]], %[[RHS]] : f32
+  // CHECK: %[[NEG:.+]] = spv.FNegate %[[POW]] : f32
+  // CHECK: %[[INT:.+]] = spv.ConvertFToS %[[RHS]] : f32 to i32
+  // CHECK: %[[I1:.+]] = spv.Constant 1 : i32
+  // CHECK: %[[BIT:.+]] = spv.BitwiseAnd %[[INT]], %[[I1]] : i32
+  // CHECK: %[[ODD:.+]] = spv.IEqual %[[BIT]], %[[I1]] : i32
+  // CHECK: %[[COND:.+]] = spv.LogicalAnd %[[LT]], %[[ODD]] : i1
+  // CHECK: %[[SEL:.+]] = spv.Select %[[COND]], %[[NEG]], %[[POW]] : i1, f32
+  %0 = math.powf %lhs, %rhs : f32
+  // CHECK: return %[[SEL]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @powf_vector
+func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: spv.FOrdLessThan
+  // CHECK: spv.GLSL.FAbs
+  // CHECK: spv.GLSL.Pow %{{.*}}: vector<4xf32>
+  // CHECK: spv.FNegate
+  // CHECK: spv.ConvertFToS %{{.*}} : vector<4xf32> to vector<4xi32>
+  // CHECK: spv.BitwiseAnd
+  // CHECK: spv.IEqual %{{.*}} : vector<4xi32>
+  // CHECK: spv.LogicalAnd %{{.*}} : vector<4xi1>
+  // CHECK: spv.Select %{{.*}} : vector<4xi1>, vector<4xf32>
+  %0 = math.powf %lhs, %rhs : vector<4xf32>
+  return %0: vector<4xf32>
+}
+
 } // end module
 
 // -----