diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -10,7 +10,6 @@
 // that do not rely on any of the library functions.
 //
 //===----------------------------------------------------------------------===//
-
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/Math/IR/Math.h"
@@ -20,6 +19,7 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <limits>
 
 using namespace mlir;
 using namespace mlir::vector;
@@ -28,6 +28,8 @@
 
 static bool isF32(Type type) { return type.isF32(); }
 
+static bool isI32(Type type) { return type.isInteger(32); }
+
 // Returns vector width if the element type is matching the predicate (scalars
 // that do match the predicate have width equal to `1`).
 static Optional<int> vectorWidth(Type type, TypePredicate pred) {
@@ -153,6 +155,30 @@
   return {normalizedFraction, exponent};
 }
 
+// Computes 2^arg (as f32) for i32 `arg`; only valid for arg in [-126, 127].
+static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
+  assert(isI32(elementType(arg.getType())) && "argument must be i32 type");
+
+  int width = vectorWidth(arg.getType());
+
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, width);
+  };
+
+  auto f32Vec = broadcast(builder.getF32Type(), width);
+  // The f32 exponent field starts at bit 23.
+  auto exponentBitLocation = bcast(i32Cst(builder, 23));
+  // Add the f32 exponent bias: 2^arg has biased exponent arg + 127.
+  auto bias = bcast(i32Cst(builder, 127));
+
+  Value biasedArg = builder.create<AddIOp>(arg, bias);
+  Value exp2ValueInt =
+      builder.create<ShiftLeftOp>(biasedArg, exponentBitLocation);
+  Value exp2ValueF32 = builder.create<LLVM::BitcastOp>(f32Vec, exp2ValueInt);
+
+  return exp2ValueF32;
+}
+
 //----------------------------------------------------------------------------//
 // TanhOp approximation.
 //----------------------------------------------------------------------------//
@@ -230,6 +256,11 @@
   return success();
 }
 
+#define LN2_VALUE                                                              \
+  0.693147180559945309417232121458176568075500134360255254120680009493393621L
+#define LN2E_VALUE                                                             \
+  1.442695040888963407359924681001892137426645954152985934135449406931109219L
+
 //----------------------------------------------------------------------------//
 // LogOp approximation.
 //----------------------------------------------------------------------------//
@@ -247,9 +278,6 @@
 };
 } // namespace
 
-#define LN2_VALUE                                                              \
-  0.693147180559945309417232121458176568075500134360255254120680009493393621L
-
 LogicalResult
 LogApproximation::matchAndRewrite(math::LogOp op,
                                   PatternRewriter &rewriter) const {
@@ -353,9 +381,125 @@
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// Exp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::ExpOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Approximates exp(x) by range reduction: with k = floor(x / ln(2)) and
+// y = x - k * ln(2), y is in [0, ln(2)] and exp(x) = exp(y) * 2^k, where
+// exp(y) is a polynomial in y and 2^k is assembled from f32 exponent bits.
+LogicalResult
+ExpApproximation::matchAndRewrite(math::ExpOp op,
+                                  PatternRewriter &rewriter) const {
+  auto width = vectorWidth(op.operand().getType(), isF32);
+  if (!width.hasValue())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+
+  // TODO: Consider a common pattern rewriter with all methods below to
+  // write the approximations.
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, *width);
+  };
+  auto f32Fmla = [&](Value a, Value b, Value c) {
+    return builder.create<FmaFOp>(a, b, c);
+  };
+  auto f32Mul = [&](Value a, Value b) -> Value {
+    return builder.create<MulFOp>(a, b);
+  };
+  auto f32Sub = [&](Value a, Value b) -> Value {
+    return builder.create<SubFOp>(a, b);
+  };
+  auto f32Floor = [&](Value a) { return builder.create<FloorFOp>(a); };
+
+  Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
+  Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LN2E_VALUE)));
+
+  // Polynomial coefficients c0..c5 of P(y) ~= exp(y) on [0, ln(2)].
+  Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
+  Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0));
+  Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f));
+  Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f));
+  Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f));
+  Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f));
+
+  Value x = op.operand();
+
+  // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
+  Value xL2Inv = f32Mul(x, cstLog2E);
+  Value kF32 = f32Floor(xL2Inv);
+  Value kLn2 = f32Mul(kF32, cstLn2);
+  Value y = f32Sub(x, kLn2);
+
+  // Use Estrin's evaluation scheme with 3 independent parts:
+  // P(y) = (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4
+  Value y2 = f32Mul(y, y);
+  Value y4 = f32Mul(y2, y2);
+
+  Value q0 = f32Fmla(cstCephesExpP1, y, cstCephesExpP0);
+  Value q1 = f32Fmla(cstCephesExpP3, y, cstCephesExpP2);
+  Value q2 = f32Fmla(cstCephesExpP5, y, cstCephesExpP4);
+  Value expY = f32Fmla(q1, y2, q0);
+  expY = f32Fmla(q2, y4, expY);
+
+  auto i32Vec = broadcast(builder.getI32Type(), *width);
+
+  // exp2(k)
+  Value k = builder.create<FPToSIOp>(kF32, i32Vec);
+  Value exp2KValue = exp2I32(builder, k);
+
+  // exp(x) = exp(y) * exp2(k)
+  expY = f32Mul(expY, exp2KValue);
+
+  // Handle overflow, underflow, infinities and NaN. 2^k is a normal f32 only
+  // for k in [-126, 127]; the range is checked on kF32 so that NaN and inputs
+  // too large for the f32->i32 conversion are never treated as computable.
+  // Outside that range: exp(x) = inf for x > 0, 0 for x = -inf, underflow
+  // (min_float) for other negative x, and NaN propagates unchanged.
+  Value zerof32Const = bcast(f32Cst(builder, 0));
+  auto constPosInfinity =
+      bcast(f32Cst(builder, std::numeric_limits<float>::infinity()));
+  auto constNegInfinity =
+      bcast(f32Cst(builder, -std::numeric_limits<float>::infinity()));
+  auto underflow = bcast(f32Cst(builder, std::numeric_limits<float>::min()));
+
+  Value kMaxConst = bcast(f32Cst(builder, 127));
+  Value kMinConst = bcast(f32Cst(builder, -126));
+  Value belowMax = builder.create<CmpFOp>(CmpFPredicate::OLE, kF32, kMaxConst);
+  Value aboveMin = builder.create<CmpFOp>(CmpFPredicate::OGE, kF32, kMinConst);
+
+  Value isNaNX = builder.create<CmpFOp>(CmpFPredicate::UNO, x, x);
+  Value isNegInfinityX =
+      builder.create<CmpFOp>(CmpFPredicate::OEQ, x, constNegInfinity);
+  Value isPositiveX =
+      builder.create<CmpFOp>(CmpFPredicate::OGT, x, zerof32Const);
+  Value isComputable = builder.create<AndOp>(belowMax, aboveMin);
+
+  expY = builder.create<SelectOp>(
+      isComputable, expY,
+      builder.create<SelectOp>(
+          isPositiveX, constPosInfinity,
+          builder.create<SelectOp>(isNegInfinityX, zerof32Const, underflow)));
+  expY = builder.create<SelectOp>(isNaNX, x, expY);
+
+  rewriter.replaceOp(op, expY);
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 
 void mlir::populateMathPolynomialApproximationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<TanhApproximation, LogApproximation>(ctx);
+  patterns.insert<TanhApproximation, LogApproximation, ExpApproximation>(ctx);
 }
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -20,3 +20,17 @@
   %1 = math.log %0 : vector<8xf32>
   return %1 : vector<8xf32>
 }
+
+// CHECK-LABEL: @exp_scalar
+func @exp_scalar(%arg0: f32) -> f32 {
+  // CHECK-NOT: math.exp
+  %0 = math.exp %arg0 : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @exp_vector
+func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+  // CHECK-NOT: math.exp
+  %0 = math.exp %arg0 : vector<8xf32>
+  return %0 : vector<8xf32>
+}
diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -71,8 +71,57 @@
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Exp.
+// -------------------------------------------------------------------------- //
+func @exp() {
+  // CHECK: 2.71828
+  %0 = constant 1.0 : f32
+  %1 = math.exp %0 : f32
+  vector.print %1 : f32
+
+  // CHECK: 0.778802, 2.117, 2.71828, 3.85742
+  %2 = constant dense<[-0.25, 0.75, 1.0, 1.35]> : vector<4xf32>
+  %3 = math.exp %2 : vector<4xf32>
+  vector.print %3 : vector<4xf32>
+
+  // CHECK: 1
+  %zero = constant 0.0 : f32
+  %exp_zero = math.exp %zero : f32
+  vector.print %exp_zero : f32
+
+  // CHECK: 1.17549e-38, 1.38879e-11, 7.20049e+10, inf
+  %special_vec = constant dense<[-89.0, -25.0, 25.0, 89.0]> : vector<4xf32>
+  %exp_special_vec = math.exp %special_vec : vector<4xf32>
+  vector.print %exp_special_vec : vector<4xf32>
+
+  // CHECK: inf
+  %inf = constant 0x7f800000 : f32
+  %exp_inf = math.exp %inf : f32
+  vector.print %exp_inf : f32
+
+  // CHECK: 0
+  %negative_inf = constant 0xff800000 : f32
+  %exp_negative_inf = math.exp %negative_inf : f32
+  vector.print %exp_negative_inf : f32
+
+  // k = floor(x * log2(e)) = -127 must clamp to the underflow value, not 0.
+  // CHECK: 1.17549e-38
+  %near_underflow = constant -87.5 : f32
+  %exp_near_underflow = math.exp %near_underflow : f32
+  vector.print %exp_near_underflow : f32
+
+  // CHECK: nan
+  %nan = constant 0x7fc00000 : f32
+  %exp_nan = math.exp %nan : f32
+  vector.print %exp_nan : f32
+
+  return
+}
+
 func @main() {
   call @tanh(): () -> ()
   call @log(): () -> ()
+  call @exp(): () -> ()
   return
 }