diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -354,8 +354,15 @@
                                     PatternRewriter &rewriter) const {
   auto y = op.getOperand(0);
   auto x = op.getOperand(1);
-  if (!getElementTypeOrSelf(x).isF32())
-    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+  // The expansion below is written for f32. For lower-precision floats we can
+  // extend the operands to f32, run the f32 approximation, and truncate back.
+  if (!getElementTypeOrSelf(x).isF32()) {
+    if (getElementTypeOrSelf(x).getIntOrFloatBitWidth() >= 32)
+      return rewriter.notifyMatchFailure(op, "unsupported operand type");
+    auto f32Type = x.getType().cast<ShapedType>().clone(rewriter.getF32Type());
+    x = rewriter.create<arith::ExtFOp>(op.getLoc(), f32Type, x);
+    y = rewriter.create<arith::ExtFOp>(op.getLoc(), f32Type, y);
+  }
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
   ArrayRef<int64_t> shape = vectorShape(op.getResult());
@@ -400,6 +407,13 @@
   Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
   result = builder.create<arith::SelectOp>(isNan, cstNan, result);
 
+  // Truncate the f32 result back to the original element type if needed.
+  if (!getElementTypeOrSelf(op.getResult()).isF32()) {
+    Type resultType = op.getResult().getType();
+    result =
+        rewriter.create<arith::TruncFOp>(op.getLoc(), resultType, result);
+  }
+
   rewriter.replaceOp(op, result);
   return success();
 }
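
For illustration only (not part of the patch): a minimal sketch of the IR shape this rewrite is expected to produce for an f16 atan2, assuming the pattern is driven by the usual math polynomial-approximation pass. Value names are made up and the f32 approximation body is elided; the op names correspond to the arith::ExtFOp / arith::TruncFOp used above.

  // Before:
  %r = math.atan2 %y, %x : vector<8xf16>

  // After (conceptually): promote operands to f32, approximate, truncate back.
  %y32 = arith.extf %y : vector<8xf16> to vector<8xf32>
  %x32 = arith.extf %x : vector<8xf16> to vector<8xf32>
  // ... f32 polynomial approximation producing %approx : vector<8xf32> ...
  %r = arith.truncf %approx : vector<8xf32> to vector<8xf16>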