diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -305,6 +305,17 @@
     if (!dstType)
       return failure();
 
+    // Integer type for the exponent parity test: i32 with the same shape as the
+    // converted exponent. i32 needs no optional SPIR-V capability, whereas a
+    // width-matched type (i16 for f16, i64 for f64) would need Int16/Int64.
+    // Bail out before creating any op if the exponent has an unexpected type.
+    Type intType = rewriter.getIntegerType(32);
+    Type operandType = adaptor.getRhs().getType();
+    if (auto vectorType = operandType.dyn_cast<VectorType>())
+      intType = VectorType::get(vectorType.getShape(), intType);
+    else if (!operandType.isa<FloatType>())
+      return failure();
+
     // Per GL Pow extended instruction spec:
     // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
     Location loc = powfOp.getLoc();
@@ -313,9 +324,27 @@
     Value lessThan =
         rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
     Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
+
+    // Convert the exponent to an integer and test whether it is odd with
+    // (y & 1) == 1, which also holds for negative odd y in two's complement.
+    // TODO: this assumes an integral exponent that fits in i32. For x < 0 and
+    // a non-integral y the result should be NaN; here y is truncated and only
+    // its parity picks the sign of abs(x)^y.
+    Value intRhs =
+        rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
+    Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
+    Value bitwiseAndOne =
+        rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
+    Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
+
+    // Calculate pow based on abs(lhs)^rhs.
     Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
     Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
-    rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, lessThan, negate, pow);
+    // If the exponent is odd and lhs < 0, negate the result.
+    Value shouldNegate =
+        rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
+                                                 pow);
     return success();
   }
 };