[mlir][math] Add a polynomial approximation for math.log2.

LogApproximation becomes the class template LogApproximationBase<Op>, whose
logMatchAndRewrite() takes a `base2` flag: it either keeps the natural
logarithm (e * ln(2) + x) or rescales it to base 2 (x * log2(e) + e).
LogApproximation (math.log) and Log2Approximation (math.log2) are thin
subclasses that select the base, and both are registered by
populateMathPolynomialApproximationPatterns(). The LN2E_VALUE constant,
which holds log2(e), is renamed to LOG2E_VALUE and its use in the Exp
approximation is updated. New tests cover math.log2 for scalars and
vectors, including 0, negative and +inf inputs.

diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -258,29 +258,30 @@
 
 #define LN2_VALUE                                                              \
   0.693147180559945309417232121458176568075500134360255254120680009493393621L
-#define LN2E_VALUE                                                             \
+#define LOG2E_VALUE                                                            \
   1.442695040888963407359924681001892137426645954152985934135449406931109219L
 
 //----------------------------------------------------------------------------//
-// LogOp approximation.
+// LogOp and Log2Op approximation.
 //----------------------------------------------------------------------------//
 
 namespace {
+template <typename Op>
+struct LogApproximationBase : public OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
 
-// This approximations comes from the Julien Pommier's SSE math library.
-// Link: http://gruntthepeon.free.fr/ssemath
-struct LogApproximation : public OpRewritePattern<math::LogOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(math::LogOp op,
-                                PatternRewriter &rewriter) const final;
+  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
+  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
+                                   bool base2) const;
 };
 } // namespace
 
+// This approximation comes from Julien Pommier's SSE math library.
+// Link: http://gruntthepeon.free.fr/ssemath
+template <typename Op>
 LogicalResult
-LogApproximation::matchAndRewrite(math::LogOp op,
-                                  PatternRewriter &rewriter) const {
+LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
+                                             bool base2) const {
   auto width = vectorWidth(op.operand().getType(), isF32);
   if (!width.hasValue())
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
@@ -356,8 +357,15 @@
   y0 = builder.create<FmaFOp>(cstNegHalf, x2, y0);
   x = builder.create<AddFOp>(x, y0);
 
-  Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
-  x = builder.create<FmaFOp>(e, cstLn2, x);
+  // At this point ln(operand) = e * ln(2) + x. For log2, divide through by
+  // ln(2): log2(operand) = x * log2(e) + e.
+  if (base2) {
+    Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
+    x = builder.create<FmaFOp>(x, cstLog2e, e);
+  } else {
+    Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
+    x = builder.create<FmaFOp>(e, cstLn2, x);
+  }
 
   Value invalidMask =
       builder.create<CmpFOp>(CmpFPredicate::ULT, op.operand(), cstZero);
@@ -381,6 +389,28 @@
   return success();
 }
 
+namespace {
+struct LogApproximation : public LogApproximationBase<math::LogOp> {
+  using LogApproximationBase::LogApproximationBase;
+
+  LogicalResult matchAndRewrite(math::LogOp op,
+                                PatternRewriter &rewriter) const final {
+    return logMatchAndRewrite(op, rewriter, /*base2=*/false);
+  }
+};
+} // namespace
+
+namespace {
+struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
+  using LogApproximationBase::LogApproximationBase;
+
+  LogicalResult matchAndRewrite(math::Log2Op op,
+                                PatternRewriter &rewriter) const final {
+    return logMatchAndRewrite(op, rewriter, /*base2=*/true);
+  }
+};
+} // namespace
+
 //----------------------------------------------------------------------------//
 // Exp approximation.
 //----------------------------------------------------------------------------//
@@ -424,7 +454,7 @@
   auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
 
   Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
-  Value cstLN2E = bcast(f32Cst(builder, static_cast<float>(LN2E_VALUE)));
+  Value cstLog2E = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
 
   // Polynomial coefficients.
   Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0));
@@ -437,7 +467,7 @@
   Value x = op.operand();
 
   // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2)
-  Value xL2Inv = mul(x, cstLN2E);
+  Value xL2Inv = mul(x, cstLog2E);
   Value kF32 = floor(xL2Inv);
   Value kLn2 = mul(kF32, cstLn2);
   Value y = sub(x, kLn2);
@@ -501,5 +531,6 @@
 
 void mlir::populateMathPolynomialApproximationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<TanhApproximation, LogApproximation, ExpApproximation>(ctx);
+  patterns.insert<TanhApproximation, LogApproximation, Log2Approximation,
+                  ExpApproximation>(ctx);
 }
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -9,7 +9,8 @@
   %0 = math.tanh %arg0 : f32
   // CHECK-NOT: log
   %1 = math.log %0 : f32
-  return %1 : f32
+  %2 = math.log2 %1 : f32
+  return %2 : f32
 }
 
 // CHECK-LABEL: @vector
@@ -18,7 +19,8 @@
   %0 = math.tanh %arg0 : vector<8xf32>
   // CHECK-NOT: log
   %1 = math.log %0 : vector<8xf32>
-  return %1 : vector<8xf32>
+  %2 = math.log2 %1 : vector<8xf32>
+  return %2 : vector<8xf32>
 }
 
 // CHECK-LABEL: @exp_scalar
diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -71,8 +71,47 @@
   return
 }
 
+func @log2() {
+  // CHECK: 3.81887
+  %0 = constant 14.112233 : f32
+  %1 = math.log2 %0 : f32
+  vector.print %1 : f32
+
+  // CHECK: -2, -0.415037, 0, 0.321928
+  %2 = constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
+  %3 = math.log2 %2 : vector<4xf32>
+  vector.print %3 : vector<4xf32>
+
+  // CHECK: -3.32193, -2.32193, -1.73697, -1.32193, -1, -0.736966, -0.514573, -0.321928
+  %4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
+  %5 = math.log2 %4 : vector<8xf32>
+  vector.print %5 : vector<8xf32>
+
+  // CHECK: -inf
+  %zero = constant 0.0 : f32
+  %log_zero = math.log2 %zero : f32
+  vector.print %log_zero : f32
+
+  // CHECK: nan
+  %neg_one = constant -1.0 : f32
+  %log_neg_one = math.log2 %neg_one : f32
+  vector.print %log_neg_one : f32
+
+  // CHECK: inf
+  %inf = constant 0x7f800000 : f32
+  %log_inf = math.log2 %inf : f32
+  vector.print %log_inf : f32
+
+  // CHECK: -inf, nan, inf, 1.58496
+  %special_vec = constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32>
+  %log_special_vec = math.log2 %special_vec : vector<4xf32>
+  vector.print %log_special_vec : vector<4xf32>
+
+  return
+}
+
 // -------------------------------------------------------------------------- //
-// Log.
+// Exp.
 // -------------------------------------------------------------------------- //
 func @exp() {
   // CHECK: 2.71828
@@ -111,6 +150,7 @@
 func @main() {
   call @tanh(): () -> ()
   call @log(): () -> ()
+  call @log2(): () -> ()
   call @exp(): () -> ()
   return
 }