diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -305,6 +305,24 @@ if (!dstType) return failure(); + // Get the scalar float type. + FloatType scalarFloatType; + if (auto scalarType = powfOp.getType().dyn_cast()) { + scalarFloatType = scalarType; + } else if (auto vectorType = powfOp.getType().dyn_cast()) { + scalarFloatType = vectorType.getElementType().cast(); + } else { + return failure(); + } + + // Get int type of the same shape as the float type. + Type scalarIntType = rewriter.getIntegerType(32); + Type intType = scalarIntType; + if (auto vectorType = adaptor.getRhs().getType().dyn_cast()) { + auto shape = vectorType.getShape(); + intType = VectorType::get(shape, scalarIntType); + } + // Per GL Pow extended instruction spec: // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0." Location loc = powfOp.getLoc(); @@ -313,9 +331,27 @@ Value lessThan = rewriter.create(loc, adaptor.getLhs(), zero); Value abs = rewriter.create(loc, adaptor.getLhs()); + + // TODO: The following just forcefully casts y into an integer value in + // order to properly propagate the sign, assuming integer y cases. It + // doesn't cover other cases and should be fixed. + + // Cast exponent to integer and calculate exponent % 2 != 0. + Value intRhs = + rewriter.create(loc, intType, adaptor.getRhs()); + Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter); + Value bitwiseAndOne = + rewriter.create(loc, intRhs, intOne); + Value isOdd = rewriter.create(loc, bitwiseAndOne, intOne); + + // calculate pow based on abs(lhs)^rhs. Value pow = rewriter.create(loc, abs, adaptor.getRhs()); Value negate = rewriter.create(loc, pow); - rewriter.replaceOpWithNewOp(powfOp, lessThan, negate, pow); + // if the exponent is odd and lhs < 0, negate the result. + Value shouldNegate = + rewriter.create(loc, lessThan, isOdd); + rewriter.replaceOpWithNewOp(powfOp, shouldNegate, negate, + pow); return success(); } }; diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir @@ -137,9 +137,14 @@ // CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32 // CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32 // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32 + // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS + // CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32 + // CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]] + // CHECK: %[[ODD:.+]] = spirv.IEqual %[[REM]], %[[CST1]] : i32 // CHECK: %[[POW:.+]] = spirv.GL.Pow %[[ABS]], %[[RHS]] : f32 // CHECK: %[[NEG:.+]] = spirv.FNegate %[[POW]] : f32 - // CHECK: %[[SEL:.+]] = spirv.Select %[[LT]], %[[NEG]], %[[POW]] : i1, f32 + // CHECK: %[[SNEG:.+]] = spirv.LogicalAnd %[[LT]], %[[ODD]] : i1 + // CHECK: %[[SEL:.+]] = spirv.Select %[[SNEG]], %[[NEG]], %[[POW]] : i1, f32 %0 = math.powf %lhs, %rhs : f32 // CHECK: return %[[SEL]] return %0: f32 @@ -149,6 +154,8 @@ func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> { // CHECK: spirv.FOrdLessThan // CHECK: spirv.GL.FAbs + // CHECK: spirv.BitwiseAnd %{{.*}} : vector<4xi32> + // CHECK: spirv.IEqual %{{.*}} : vector<4xi32> // CHECK: spirv.GL.Pow %{{.*}}: vector<4xf32> // CHECK: spirv.FNegate // CHECK: spirv.Select