diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -576,10 +576,69 @@
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// ExpM1 approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+/// Pattern that replaces `math.expm1` with an expansion built from exp, log,
+/// arithmetic and select ops; see matchAndRewrite for the formula.
+struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::ExpM1Op op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
+                                    PatternRewriter &rewriter) const {
+  auto width = vectorWidth(op.operand().getType(), isF32);
+  if (!width.hasValue())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, *width);
+  };
+
+  // expm1(x) = exp(x) - 1 = u - 1.
+  // We have to handle it carefully when x is near 0, i.e. u ~= 1,
+  // and when the input is ~= -inf, i.e. u - 1 ~= -1.
+  Value cstOne = bcast(f32Cst(builder, 1.0f));
+  Value cstNegOne = bcast(f32Cst(builder, -1.0f));
+  Value x = op.operand();
+  Value u = builder.create<math::ExpOp>(x);
+  Value uEqOne = builder.create<CmpFOp>(CmpFPredicate::OEQ, u, cstOne);
+  Value uMinusOne = builder.create<SubFOp>(u, cstOne);
+  Value uMinusOneEqNegOne =
+      builder.create<CmpFOp>(CmpFPredicate::OEQ, uMinusOne, cstNegOne);
+  // logU = log(u) ~= x
+  Value logU = builder.create<math::LogOp>(u);
+
+  // Detect exp(x) = +inf; written this way to avoid having to form +inf.
+  Value isInf = builder.create<CmpFOp>(CmpFPredicate::OEQ, logU, u);
+
+  // (u - 1) * (x / log(u)), where log(u) ~= x.
+  Value expm1 =
+      builder.create<MulFOp>(uMinusOne, builder.create<DivFOp>(x, logU));
+  expm1 = builder.create<SelectOp>(isInf, u, expm1);
+  // Final selection: x when u == 1, -1 when u - 1 == -1, and the value
+  // computed above otherwise.
+  Value approximation = builder.create<SelectOp>(
+      uEqOne, x, builder.create<SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
+  rewriter.replaceOp(op, approximation);
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 
 void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
-               Log1pApproximation, ExpApproximation>(patterns.getContext());
+               Log1pApproximation, ExpApproximation, ExpM1Approximation>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -11,7 +11,10 @@
   %1 = math.log %0 : f32
   %2 = math.log2 %1 : f32
   %3 = math.log1p %2 : f32
-  return %3 : f32
+  // CHECK-NOT: exp
+  %4 = math.exp %3 : f32
+  %5 = math.expm1 %4 : f32
+  return %5 : f32
 }
 
 // CHECK-LABEL: @vector
@@ -22,18 +25,8 @@
   %1 = math.log %0 : vector<8xf32>
   %2 = math.log2 %1 : vector<8xf32>
   %3 = math.log1p %2 : vector<8xf32>
-  return %3 : vector<8xf32>
-}
-
-// CHECK-LABEL: @exp_scalar
-func @exp_scalar(%arg0: f32) -> f32 {
-  %0 = math.exp %arg0 : f32
-  return %0 : f32
-}
-
-// CHECK-LABEL: @exp_vector
-func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
-  // CHECK-NOT: math.exp
-  %0 = math.exp %arg0 : vector<8xf32>
-  return %0 : vector<8xf32>
+  // CHECK-NOT: exp
+  %4 = math.exp %3 : vector<8xf32>
+  %5 = math.expm1 %4 : vector<8xf32>
+  return %5 : vector<8xf32>
 }
diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -186,11 +186,46 @@
   return
 }
 
+func @expm1() {
+  // CHECK: 1e-10
+  %0 = constant 1.0e-10 : f32
+  %1 = math.expm1 %0 : f32
+  vector.print %1 : f32
+
+  // CHECK: -0.00995016, 0.0100502, 0.648721, 6.38905
+  %2 = constant dense<[-0.01, 0.01, 0.5, 2.0]> : vector<4xf32>
+  %3 = math.expm1 %2 : vector<4xf32>
+  vector.print %3 : vector<4xf32>
+
+  // CHECK: -0.181269, 0, 0.221403, 0.491825, 0.822119, 1.22554, 1.71828, 2.32012
+  %4 = constant dense<[-0.2, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2]> : vector<8xf32>
+  %5 = math.expm1 %4 : vector<8xf32>
+  vector.print %5 : vector<8xf32>
+
+  // CHECK: -1
+  %neg_inf = constant 0xff800000 : f32
+  %expm1_neg_inf = math.expm1 %neg_inf : f32
+  vector.print %expm1_neg_inf : f32
+
+  // CHECK: inf
+  %inf = constant 0x7f800000 : f32
+  %expm1_inf = math.expm1 %inf : f32
+  vector.print %expm1_inf : f32
+
+  // CHECK: -1, inf, 1e-10
+  %special_vec = constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
+  %expm1_special_vec = math.expm1 %special_vec : vector<3xf32>
+  vector.print %expm1_special_vec : vector<3xf32>
+
+  return
+}
+
 func @main() {
   call @tanh(): () -> ()
   call @log(): () -> ()
   call @log2(): () -> ()
   call @log1p(): () -> ()
   call @exp(): () -> ()
+  call @expm1(): () -> ()
   return
 }