diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -305,6 +305,25 @@ if (!dstType) return failure(); + // Get the scalar float type. + FloatType scalarFloatType; + if (auto scalarType = powfOp.getType().dyn_cast()) { + scalarFloatType = scalarType; + } else if (auto vectorType = powfOp.getType().dyn_cast()) { + scalarFloatType = vectorType.getElementType().cast(); + } else { + return failure(); + } + + // Get int type of the same shape as the float type. + Type scalarIntType = rewriter.getIntegerType(32); + Type intType = scalarIntType; + if (auto vectorType = adaptor.getRhs().getType().dyn_cast()) { + assert(vectorType.getRank() == 1); + int count = vectorType.getNumElements(); + intType = VectorType::get(count, scalarIntType); + } + // Per GL Pow extended instruction spec: // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0." Location loc = powfOp.getLoc(); @@ -313,9 +332,24 @@ Value lessThan = rewriter.create(loc, adaptor.getLhs(), zero); Value abs = rewriter.create(loc, adaptor.getLhs()); + + // TODO: The following just forcefully casts y into an integer value in order to properly propagate the sign, assuming integer y cases. It doesn't cover other cases and should be fixed. + // Cast exponent to integer and calculate exponent % 2!=0. 
+ Value intRhs = + rewriter.create(loc, intType, adaptor.getRhs()); + Value intOne = + spirv::ConstantOp::getOne(intType, loc, rewriter); + Value bitwiseAndOne = rewriter.create(loc, intRhs, intOne); + Value isOdd = rewriter.create(loc, bitwiseAndOne, intOne); + + // calculate pow based on abs(lhs)^rhs Value pow = rewriter.create(loc, abs, adaptor.getRhs()); Value negate = rewriter.create(loc, pow); - rewriter.replaceOpWithNewOp(powfOp, lessThan, negate, pow); + // if the exponent is odd and lhs < 0, negate the result + Value shouldNegate = + rewriter.create(loc, lessThan, isOdd); + rewriter.replaceOpWithNewOp(powfOp, shouldNegate, negate, + pow); return success(); } }; diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir @@ -134,12 +134,18 @@ // CHECK-LABEL: @powf_scalar // CHECK-SAME: (%[[LHS:.+]]: f32, %[[RHS:.+]]: f32) func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 { // CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32 // CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32 // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32 + // The pow result is negated only when lhs < 0 and the integer-cast rhs is odd. + // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS %[[RHS]] : f32 to i32 + // CHECK: %[[I1:.+]] = spirv.Constant 1 : i32 + // CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]], %[[I1]] : i32 + // CHECK: %[[ODD:.+]] = spirv.IEqual %[[REM]], %[[I1]] : i32 // CHECK: %[[POW:.+]] = spirv.GL.Pow %[[ABS]], %[[RHS]] : f32 // CHECK: %[[NEG:.+]] = spirv.FNegate 
%[[POW]] : f32 - // CHECK: %[[SEL:.+]] = spirv.Select %[[LT]], %[[NEG]], %[[POW]] : i1, f32 + // CHECK: %[[SNEG:.+]] = spirv.LogicalAnd %[[LT]], %[[ODD]] : i1 + // CHECK: %[[SEL:.+]] = spirv.Select %[[SNEG]], %[[NEG]], %[[POW]] : i1, f32 %0 = math.powf %lhs, %rhs : f32 // CHECK: return %[[SEL]] return %0: f32