diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -413,6 +413,55 @@
 };
 } // namespace
 
+//----------------------------------------------------------------------------//
+// Log1p approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::Log1pOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Approximate log(1+x) for f32 scalars/vectors; emits math.log.
+LogicalResult
+Log1pApproximation::matchAndRewrite(math::Log1pOp op,
+                                    PatternRewriter &rewriter) const {
+  auto width = vectorWidth(op.operand().getType(), isF32);
+  if (!width.hasValue())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, *width);
+  };
+
+  // Approximate log(1+x) using the following, due to W. Kahan:
+  //   u = x + 1.0;
+  //   if (u == 1.0 || u == inf) return x;
+  //   return x * log(u) / (u - 1.0);
+  //          ^^^^^^^^^^^^^^^^^^^^^^
+  //             "logLarge" below.
+  Value cstOne = bcast(f32Cst(builder, 1.0f));
+  Value x = op.operand();
+  Value u = builder.create<AddFOp>(x, cstOne);
+  Value uSmall = builder.create<CmpFOp>(CmpFPredicate::OEQ, u, cstOne);
+  Value logU = builder.create<math::LogOp>(u);
+  // log(u) == u only for u == +inf: log(u) < u for every finite u > 0, and
+  // log(u) is -inf or NaN for u <= 0 or NaN, failing the ordered compare.
+  Value uInf = builder.create<CmpFOp>(CmpFPredicate::OEQ, u, logU);
+  Value logLarge = builder.create<MulFOp>(
+      x, builder.create<DivFOp>(logU, builder.create<SubFOp>(u, cstOne)));
+  Value approximation = builder.create<SelectOp>(
+      builder.create<OrOp>(uSmall, uInf), x, logLarge);
+  rewriter.replaceOp(op, approximation);
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 // Exp approximation.
 //----------------------------------------------------------------------------//
@@ -534,5 +583,5 @@
 void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
-               ExpApproximation>(patterns.getContext());
+               Log1pApproximation, ExpApproximation>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -10,7 +10,8 @@
   // CHECK-NOT: log
   %1 = math.log %0 : f32
   %2 = math.log2 %1 : f32
-  return %2 : f32
+  %3 = math.log1p %2 : f32
+  return %3 : f32
 }
 
 // CHECK-LABEL: @vector
@@ -20,7 +21,8 @@
   // CHECK-NOT: log
   %1 = math.log %0 : vector<8xf32>
   %2 = math.log2 %1 : vector<8xf32>
-  return %2 : vector<8xf32>
+  %3 = math.log1p %2 : vector<8xf32>
+  return %3 : vector<8xf32>
 }
 
 // CHECK-LABEL: @exp_scalar
diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -110,6 +110,45 @@
   return
 }
 
+func @log1p() {
+  // CHECK: 0.00995033
+  %0 = constant 0.01 : f32
+  %1 = math.log1p %0 : f32
+  vector.print %1 : f32
+
+  // CHECK: -4.60517, -0.693147, 0, 1.38629
+  %2 = constant dense<[-0.99, -0.5, 0.0, 3.0]> : vector<4xf32>
+  %3 = math.log1p %2 : vector<4xf32>
+  vector.print %3 : vector<4xf32>
+
+  // CHECK: 0.0953102, 0.182322, 0.262364, 0.336472, 0.405465, 0.470004, 0.530628, 0.587787
+  %4 = constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
+  %5 = math.log1p %4 : vector<8xf32>
+  vector.print %5 : vector<8xf32>
+
+  // CHECK: -inf
+  %neg_one = constant -1.0 : f32
+  %log_neg_one = math.log1p %neg_one : f32
+  vector.print %log_neg_one : f32
+
+  // CHECK: nan
+  %neg_two = constant -2.0 : f32
+  %log_neg_two = math.log1p %neg_two : f32
+  vector.print %log_neg_two : f32
+
+  // CHECK: inf
+  %inf = constant 0x7f800000 : f32
+  %log_inf = math.log1p %inf : f32
+  vector.print %log_inf : f32
+
+  // CHECK: -inf, nan, inf, 9.99995e-06
+  %special_vec = constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32>
+  %log_special_vec = math.log1p %special_vec : vector<4xf32>
+  vector.print %log_special_vec : vector<4xf32>
+
+  return
+}
+
 // -------------------------------------------------------------------------- //
 // Exp.
 // -------------------------------------------------------------------------- //
@@ -151,6 +190,7 @@
   call @tanh(): () -> ()
   call @log(): () -> ()
   call @log2(): () -> ()
+  call @log1p(): () -> ()
   call @exp(): () -> ()
   return
 }