diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -305,6 +305,26 @@ if (!dstType) return failure(); + // Determine the float element type of the result (scalar or vector) + FloatType floatType; + if (auto scalarType = powfOp.getType().dyn_cast()) { + floatType = scalarType; + } else if (auto vectorType = powfOp.getType().dyn_cast()) { + floatType = vectorType.getElementType().cast(); + } else { + return failure(); + } + + // Use a fixed 32-bit integer type, shaped like the exponent operand + int intWidth = 32; + Type intType = rewriter.getIntegerType(intWidth); + Type scalarIntType = rewriter.getIntegerType(intWidth); + if (auto vectorType = adaptor.getRhs().getType().dyn_cast()) { + assert(vectorType.getRank() == 1); + int count = vectorType.getNumElements(); + intType = VectorType::get(count, intType); + } + // Per GL Pow extended instruction spec: // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0." Location loc = powfOp.getLoc(); @@ -313,9 +333,33 @@ Value lessThan = rewriter.create(loc, adaptor.getLhs(), zero); Value abs = rewriter.create(loc, adaptor.getLhs()); + + // Cast the exponent to an integer and compute whether it is odd (exponent % 2 != 0). 
+ Value intRhs = + rewriter.create(loc, intType, adaptor.getRhs()); + Value two; + if (auto vectorType = intType.dyn_cast()) { + Type elemType = vectorType.getElementType(); + two = rewriter.create( + loc, intType, + DenseElementsAttr::get(vectorType, + IntegerAttr::get(elemType, 2).getValue())); + } else { + two = rewriter.create( + loc, intType, rewriter.getIntegerAttr(intType, APInt(intWidth, 2))); + } + Value intZero = spirv::ConstantOp::getZero(intType, loc, rewriter); + Value intMod = rewriter.create(loc, intRhs, two); + Value isOdd = rewriter.create(loc, intMod, intZero); + + // calculate pow based on abs(lhs)^rhs Value pow = rewriter.create(loc, abs, adaptor.getRhs()); Value negate = rewriter.create(loc, pow); - rewriter.replaceOpWithNewOp(powfOp, lessThan, negate, pow); + // if the exponent is odd and lhs < 0, negate the result + Value shouldNegate = + rewriter.create(loc, lessThan, isOdd); + rewriter.replaceOpWithNewOp(powfOp, shouldNegate, negate, + pow); return success(); } }; diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir @@ -136,10 +136,14 @@ func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 { // CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32 // CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32 + // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS %[[RHS]] : f32 to i32 + // CHECK: %[[REM:.+]] = spirv.SRem %[[IRHS]], %cst2_i32 : i32 + // CHECK: %[[ODD:.+]] = spirv.INotEqual %[[REM]], %cst0_i32 : i32 // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32 // CHECK: %[[POW:.+]] = spirv.GL.Pow %[[ABS]], %[[RHS]] : f32 // CHECK: %[[NEG:.+]] = spirv.FNegate %[[POW]] : f32 - // CHECK: %[[SEL:.+]] = spirv.Select %[[LT]], %[[NEG]], %[[POW]] : i1, f32 + // CHECK: %[[SNEG:.+]] = spirv.LogicalAnd %[[LT]], %[[ODD]] : i1 + // CHECK: %[[SEL:.+]] = spirv.Select 
%[[SNEG]], %[[NEG]], %[[POW]] : i1, f32 %0 = math.powf %lhs, %rhs : f32 // CHECK: return %[[SEL]] return %0: f32