[mlir][spirv] Propagate the sign of negative bases in math.powf lowering

GL Pow is undefined for x < 0, so the pattern computes pow(|x|, y). It used to
negate that result for every negative base. It now negates only when x < 0 and
y, converted to i32 with spirv.ConvertFToS, is odd ((y & 1) == 1). Non-integer
exponents are still not handled; see the TODO in the pattern.

diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -305,6 +305,24 @@
     if (!dstType)
       return failure();
 
+    // Get the scalar float type.
+    FloatType scalarFloatType;
+    if (auto scalarType = powfOp.getType().dyn_cast<FloatType>()) {
+      scalarFloatType = scalarType;
+    } else if (auto vectorType = powfOp.getType().dyn_cast<VectorType>()) {
+      scalarFloatType = vectorType.getElementType().cast<FloatType>();
+    } else {
+      return failure();
+    }
+
+    // Get an i32 type with the same shape as the exponent.
+    Type scalarIntType = rewriter.getIntegerType(32);
+    Type intType = scalarIntType;
+    if (auto vectorType = adaptor.getRhs().getType().dyn_cast<VectorType>()) {
+      assert(vectorType.getRank() == 1 && "SPIR-V vectors must be 1-D");
+      intType = VectorType::get(vectorType.getShape(), scalarIntType);
+    }
+
     // Per GL Pow extended instruction spec:
     // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
     Location loc = powfOp.getLoc();
@@ -313,9 +331,27 @@
     Value lessThan =
         rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
     Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
+
+    // TODO: The following just forcefully casts y into an integer value in
+    // order to properly propagate the sign, assuming integer y cases. It
+    // doesn't cover other cases and should be fixed.
+
+    // Cast the exponent to an integer and compute isOdd = (y & 1) == 1.
+    Value intRhs =
+        rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
+    Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
+    Value bitwiseAndOne =
+        rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
+    Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
+
+    // Calculate pow based on abs(lhs)^rhs.
     Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
     Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
-    rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, lessThan, negate, pow);
+    // If the exponent is odd and lhs < 0, negate the result.
+    Value shouldNegate =
+        rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
+                                                 pow);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -137,9 +137,14 @@
   // CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32
   // CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32
   // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32
+  // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS %[[RHS]] : f32 to i32
+  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
+  // CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]], %[[ONE]] : i32
+  // CHECK: %[[ODD:.+]] = spirv.IEqual %[[REM]], %[[ONE]] : i32
   // CHECK: %[[POW:.+]] = spirv.GL.Pow %[[ABS]], %[[RHS]] : f32
   // CHECK: %[[NEG:.+]] = spirv.FNegate %[[POW]] : f32
-  // CHECK: %[[SEL:.+]] = spirv.Select %[[LT]], %[[NEG]], %[[POW]] : i1, f32
+  // CHECK: %[[SNEG:.+]] = spirv.LogicalAnd %[[LT]], %[[ODD]] : i1
+  // CHECK: %[[SEL:.+]] = spirv.Select %[[SNEG]], %[[NEG]], %[[POW]] : i1, f32
   %0 = math.powf %lhs, %rhs : f32
   // CHECK: return %[[SEL]]
   return %0: f32