diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -165,6 +165,16 @@
 // Helper functions to create constants.
 //----------------------------------------------------------------------------//
 
+/// Creates an `arith.constant` op holding `value` as a scalar of float type
+/// `elementType`, which must be f16 or f32.
+static Value floatCst(ImplicitLocOpBuilder &builder, float value,
+                      Type elementType) {
+  assert((elementType.isF16() || elementType.isF32()) &&
+         "elementType must be f16 or f32");
+  return builder.create<arith::ConstantOp>(
+      builder.getFloatAttr(elementType, value));
+}
+
 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
   return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
 }
@@ -270,11 +280,13 @@
 namespace {
 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
                                 llvm::ArrayRef<Value> coeffs, Value x) {
-  assert(getElementTypeOrSelf(x).isF32() && "x must be f32 type");
+  Type elementType = getElementTypeOrSelf(x);
+  assert((elementType.isF32() || elementType.isF16()) &&
+         "x must be f32 or f16 type");
   ArrayRef<int64_t> shape = vectorShape(x);
 
   if (coeffs.empty())
-    return broadcast(builder, f32Cst(builder, 0.0f), shape);
+    return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
 
   if (coeffs.size() == 1)
     return coeffs[0];
@@ -771,10 +783,13 @@
 LogicalResult
 ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
                                             PatternRewriter &rewriter) const {
-  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+  Value operand = op.getOperand();
+  Type elementType = getElementTypeOrSelf(operand);
+  // Only f16 and f32 are handled; report a match failure for other types.
+  if (!(elementType.isF32() || elementType.isF16()))
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
-  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+  ArrayRef<int64_t> shape = vectorShape(operand);
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
 
   auto bcast = [&](Value value) -> Value {
@@ -784,57 +799,56 @@
   const int intervalsCount = 3;
   const int polyDegree = 4;
 
-  Value zero = bcast(f32Cst(builder, 0));
-  Value one = bcast(f32Cst(builder, 1));
+  Value zero = bcast(floatCst(builder, 0, elementType));
+  Value one = bcast(floatCst(builder, 1, elementType));
   Value pp[intervalsCount][polyDegree + 1];
-  pp[0][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
-  pp[0][1] = bcast(f32Cst(builder, +1.12837916222975858e+00f));
-  pp[0][2] = bcast(f32Cst(builder, -5.23018562988006470e-01f));
-  pp[0][3] = bcast(f32Cst(builder, +2.09741709609267072e-01f));
-  pp[0][4] = bcast(f32Cst(builder, +2.58146801602987875e-02f));
-  pp[1][0] = bcast(f32Cst(builder, +0.00000000000000000e+00f));
-  pp[1][1] = bcast(f32Cst(builder, +1.12750687816789140e+00f));
-  pp[1][2] = bcast(f32Cst(builder, -3.64721408487825775e-01f));
-  pp[1][3] = bcast(f32Cst(builder, +1.18407396425136952e-01f));
-  pp[1][4] = bcast(f32Cst(builder, +3.70645533056476558e-02f));
-  pp[2][0] = bcast(f32Cst(builder, -3.30093071049483172e-03f));
-  pp[2][1] = bcast(f32Cst(builder, +3.51961938357697011e-03f));
-  pp[2][2] = bcast(f32Cst(builder, -1.41373622814988039e-03f));
-  pp[2][3] = bcast(f32Cst(builder, +2.53447094961941348e-04f));
-  pp[2][4] = bcast(f32Cst(builder, -1.71048029455037401e-05f));
+  pp[0][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
+  pp[0][1] = bcast(floatCst(builder, +1.12837916222975858e+00f, elementType));
+  pp[0][2] = bcast(floatCst(builder, -5.23018562988006470e-01f, elementType));
+  pp[0][3] = bcast(floatCst(builder, +2.09741709609267072e-01f, elementType));
+  pp[0][4] = bcast(floatCst(builder, +2.58146801602987875e-02f, elementType));
+  pp[1][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
+  pp[1][1] = bcast(floatCst(builder, +1.12750687816789140e+00f, elementType));
+  pp[1][2] = bcast(floatCst(builder, -3.64721408487825775e-01f, elementType));
+  pp[1][3] = bcast(floatCst(builder, +1.18407396425136952e-01f, elementType));
+  pp[1][4] = bcast(floatCst(builder, +3.70645533056476558e-02f, elementType));
+  pp[2][0] = bcast(floatCst(builder, -3.30093071049483172e-03f, elementType));
+  pp[2][1] = bcast(floatCst(builder, +3.51961938357697011e-03f, elementType));
+  pp[2][2] = bcast(floatCst(builder, -1.41373622814988039e-03f, elementType));
+  pp[2][3] = bcast(floatCst(builder, +2.53447094961941348e-04f, elementType));
+  pp[2][4] = bcast(floatCst(builder, -1.71048029455037401e-05f, elementType));
 
   Value qq[intervalsCount][polyDegree + 1];
-  qq[0][0] = bcast(f32Cst(builder, +1.000000000000000000e+00f));
-  qq[0][1] = bcast(f32Cst(builder, -4.635138185962547255e-01f));
-  qq[0][2] = bcast(f32Cst(builder, +5.192301327279782447e-01f));
-  qq[0][3] = bcast(f32Cst(builder, -1.318089722204810087e-01f));
-  qq[0][4] = bcast(f32Cst(builder, +7.397964654672315005e-02f));
-  qq[1][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
-  qq[1][1] = bcast(f32Cst(builder, -3.27607011824493086e-01f));
-  qq[1][2] = bcast(f32Cst(builder, +4.48369090658821977e-01f));
-  qq[1][3] = bcast(f32Cst(builder, -8.83462621207857930e-02f));
-  qq[1][4] = bcast(f32Cst(builder, +5.72442770283176093e-02f));
-  qq[2][0] = bcast(f32Cst(builder, +1.00000000000000000e+00f));
-  qq[2][1] = bcast(f32Cst(builder, -2.06069165953913769e+00f));
-  qq[2][2] = bcast(f32Cst(builder, +1.62705939945477759e+00f));
-  qq[2][3] = bcast(f32Cst(builder, -5.83389859211130017e-01f));
-  qq[2][4] = bcast(f32Cst(builder, +8.21908939856640930e-02f));
+  qq[0][0] = bcast(floatCst(builder, +1.000000000000000000e+00f, elementType));
+  qq[0][1] = bcast(floatCst(builder, -4.635138185962547255e-01f, elementType));
+  qq[0][2] = bcast(floatCst(builder, +5.192301327279782447e-01f, elementType));
+  qq[0][3] = bcast(floatCst(builder, -1.318089722204810087e-01f, elementType));
+  qq[0][4] = bcast(floatCst(builder, +7.397964654672315005e-02f, elementType));
+  qq[1][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
+  qq[1][1] = bcast(floatCst(builder, -3.27607011824493086e-01f, elementType));
+  qq[1][2] = bcast(floatCst(builder, +4.48369090658821977e-01f, elementType));
+  qq[1][3] = bcast(floatCst(builder, -8.83462621207857930e-02f, elementType));
+  qq[1][4] = bcast(floatCst(builder, +5.72442770283176093e-02f, elementType));
+  qq[2][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
+  qq[2][1] = bcast(floatCst(builder, -2.06069165953913769e+00f, elementType));
+  qq[2][2] = bcast(floatCst(builder, +1.62705939945477759e+00f, elementType));
+  qq[2][3] = bcast(floatCst(builder, -5.83389859211130017e-01f, elementType));
+  qq[2][4] = bcast(floatCst(builder, +8.21908939856640930e-02f, elementType));
 
   Value offsets[intervalsCount];
-  offsets[0] = bcast(f32Cst(builder, 0.0f));
-  offsets[1] = bcast(f32Cst(builder, 0.0f));
-  offsets[2] = bcast(f32Cst(builder, 1.0f));
+  offsets[0] = bcast(floatCst(builder, 0.0f, elementType));
+  offsets[1] = bcast(floatCst(builder, 0.0f, elementType));
+  offsets[2] = bcast(floatCst(builder, 1.0f, elementType));
 
   Value bounds[intervalsCount];
-  bounds[0] = bcast(f32Cst(builder, 0.8f));
-  bounds[1] = bcast(f32Cst(builder, 2.0f));
-  bounds[2] = bcast(f32Cst(builder, 3.75f));
-
-  Value isNegativeArg = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
-                                                      op.getOperand(), zero);
-  Value negArg = builder.create<arith::NegFOp>(op.getOperand());
-  Value x =
-      builder.create<arith::SelectOp>(isNegativeArg, negArg, op.getOperand());
+  bounds[0] = bcast(floatCst(builder, 0.8f, elementType));
+  bounds[1] = bcast(floatCst(builder, 2.0f, elementType));
+  bounds[2] = bcast(floatCst(builder, 3.75f, elementType));
+
+  Value isNegativeArg =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
+  Value negArg = builder.create<arith::NegFOp>(operand);
+  Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);
 
   Value offset = offsets[0];
   Value p[polyDegree + 1];