diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -19,6 +19,7 @@
 void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
+void populateExpandExp2FPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -158,6 +158,24 @@
   return success();
 }
 
+// Expands math.exp2 via the identity 2^x = e^(x * ln(2)):
+//   exp2(x) -> exp(x * ln(2))
+// ln(2) is created with the operand's own type, so shaped operands (e.g.
+// tensor<1xf32>) get a splat constant. The rounding error of x * ln(2) is
+// amplified by exp, so accuracy degrades as |x| grows.
+static LogicalResult convertExp2fOp(math::Exp2Op op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+  Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
+  Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
+  // `b` is an ImplicitLocOpBuilder: it supplies op's location itself.
+  Value exp = b.create<math::ExpOp>(mult);
+  rewriter.replaceOp(op, exp);
+  return success();
+}
+
 // Converts math.ctlz to scf and arith operations. This is done
 // by performing a binary search on the bits.
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -222,6 +240,10 @@
   patterns.add(convertCeilOp);
 }
 
+void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
+  patterns.add(convertExp2fOp);
+}
+
 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
   patterns.add(convertFloorOp);
 }
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -165,3 +165,29 @@
   %ret = math.ceil %a : f64
   return %ret : f64
 }
+
+// -----
+
+// CHECK-LABEL: func @exp2f_func
+// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
+func.func @exp2f_func(%a: f64) -> f64 {
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.69314718055994529
+  // CHECK: [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
+  // CHECK: [[EXP:%.+]] = math.exp [[MULF]]
+  // CHECK: return [[EXP]]
+  %ret = math.exp2 %a : f64
+  return %ret : f64
+}
+
+// -----
+
+// CHECK-LABEL: func @exp2f_func_tensor
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<1xf32>) -> tensor<1xf32>
+func.func @exp2f_func_tensor(%a: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK-DAG: [[CST:%.+]] = arith.constant dense<0.693147182>
+  // CHECK: [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
+  // CHECK: [[EXP:%.+]] = math.exp [[MULF]]
+  // CHECK: return [[EXP]]
+  %ret = math.exp2 %a : tensor<1xf32>
+  return %ret : tensor<1xf32>
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -37,6 +37,7 @@
 void TestExpandMathPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   populateExpandCtlzPattern(patterns);
+  populateExpandExp2FPattern(patterns);
   populateExpandTanPattern(patterns);
   populateExpandTanhPattern(patterns);
   populateExpandFmaFPattern(patterns);
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math,convert-arith-to-llvm),convert-vector-to-llvm,func.func(convert-math-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN:     | mlir-cpu-runner \
+// RUN:     -e main -entry-point-result=void -O0 \
+// RUN:     -shared-libs=%mlir_c_runner_utils \
+// RUN:     -shared-libs=%mlir_runner_utils \
+// RUN:     | FileCheck %s
+
+// -------------------------------------------------------------------------- //
+// exp2f.
+// -------------------------------------------------------------------------- //
+func.func @func_exp2f(%a : f64) {
+  %r = math.exp2 %a : f64
+  vector.print %r : f64
+  return
+}
+
+func.func @exp2f() {
+  // CHECK: 2
+  %a = arith.constant 1.0 : f64
+  call @func_exp2f(%a) : (f64) -> ()
+
+  // CHECK: 4
+  %b = arith.constant 2.0 : f64
+  call @func_exp2f(%b) : (f64) -> ()
+
+  // CHECK: 5.65685
+  %c = arith.constant 2.5 : f64
+  call @func_exp2f(%c) : (f64) -> ()
+
+  // CHECK: 0.29730
+  %d = arith.constant -1.75 : f64
+  call @func_exp2f(%d) : (f64) -> ()
+
+  // CHECK: 1.09581
+  %e = arith.constant 0.132 : f64
+  call @func_exp2f(%e) : (f64) -> ()
+
+  // CHECK: inf
+  %f1 = arith.constant 0.00 : f64
+  %f2 = arith.constant 1.00 : f64
+  %f = arith.divf %f2, %f1 : f64
+  call @func_exp2f(%f) : (f64) -> ()
+
+  // CHECK: inf
+  %g = arith.constant 5038939.0 : f64
+  call @func_exp2f(%g) : (f64) -> ()
+
+  // exp2(-inf) = 0. 0xFFF0000000000000 is the f64 encoding of -inf.
+  // CHECK: 0
+  %neg_inf = arith.constant 0xFFF0000000000000 : f64
+  call @func_exp2f(%neg_inf) : (f64) -> ()
+
+  // exp2(NaN) = NaN. 0x7FF8000000000000 is the f64 quiet-NaN encoding.
+  // CHECK: nan
+  %nan = arith.constant 0x7FF8000000000000 : f64
+  call @func_exp2f(%nan) : (f64) -> ()
+  return
+}
+
+func.func @main() {
+  call @exp2f() : () -> ()
+  return
+}