diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -19,6 +19,7 @@
 void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
+void populateExpandPowFPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -157,6 +157,74 @@
   rewriter.replaceOp(op, ret);
   return success();
 }
+
+// Converts math.powf(a, b) (meaning a^b) to exp(b * ln(a)), with fix-ups so
+// that the result agrees with libm `pow` in its common special cases:
+//   * a < 0, b an integer:           |a|^b, negated when b is odd;
+//   * a < 0, b a finite non-integer: NaN (math.log of a negative value);
+//   * b == +-0 or a == 1:            exactly 1, even if the other is NaN.
+// NOTE: it still differs from libm for a == -0 with an odd b (the sign of the
+// resulting zero/infinity is lost), for a == -inf with a non-integer b (NaN
+// instead of +0/+inf), and for a == -1 with b == +-inf (NaN instead of 1).
+static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0);
+  Value operandB = op.getOperand(1);
+  Type opType = operandA.getType();
+
+  // Materializes a splat constant of `opType`, which is either a float type
+  // or a shaped type (e.g. vector) of floats.
+  auto shapedType = dyn_cast<ShapedType>(opType);
+  Type elementType = shapedType ? shapedType.getElementType() : opType;
+  auto createConst = [&](double value) -> Value {
+    FloatAttr attr = b.getFloatAttr(elementType, value);
+    if (shapedType)
+      return b.create<arith::ConstantOp>(
+          DenseElementsAttr::get(shapedType, attr));
+    return b.create<arith::ConstantOp>(attr);
+  };
+  Value zero = createConst(0.0);
+  Value one = createConst(1.0);
+  Value two = createConst(2.0);
+
+  // rem(b, 1) is non-zero exactly when b is a finite non-integer; for
+  // b == +-inf or NaN it is NaN and the ordered comparison yields false.
+  Value remOne = b.create<arith::RemFOp>(opType, operandB, one);
+  Value isNonIntExp =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remOne, zero);
+
+  // |a|^b = exp(b * ln(|a|)). For a finite non-integer b the log is taken of
+  // `a` itself instead, so that a negative base yields NaN.
+  Value absA = b.create<math::AbsFOp>(opType, operandA);
+  Value base = b.create<arith::SelectOp>(isNonIntExp, operandA, absA);
+  Value logBase = b.create<math::LogOp>(opType, base);
+  Value mult = b.create<arith::MulFOp>(opType, logBase, operandB);
+  Value magnitude = b.create<math::ExpOp>(opType, mult);
+
+  // A negative base raised to an odd integer power is negative. |rem(b, 2)|
+  // equals 1 exactly when b is an odd integer.
+  Value remTwo = b.create<arith::RemFOp>(opType, operandB, two);
+  Value absRemTwo = b.create<math::AbsFOp>(opType, remTwo);
+  Value isOddExp =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absRemTwo, one);
+  Value isNegBase =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  Value flipSign = b.create<arith::AndIOp>(isNegBase, isOddExp);
+  Value negMagnitude = b.create<arith::NegFOp>(opType, magnitude);
+  Value signedResult =
+      b.create<arith::SelectOp>(flipSign, negMagnitude, magnitude);
+
+  // pow(a, +-0) == 1 and pow(1, b) == 1 for every a and b, including NaN.
+  Value isZeroExp =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
+  Value isOneBase =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandA, one);
+  Value isUnitResult = b.create<arith::OrIOp>(isZeroExp, isOneBase);
+  Value result = b.create<arith::SelectOp>(isUnitResult, one, signedResult);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
 
 // Converts math.ctlz to scf and arith operations. This is done
 // by performing a binary search on the bits.
@@ -222,6 +290,10 @@
   patterns.add(convertCeilOp);
 }
 
+void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertPowfOp);
+}
+
 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
   patterns.add(convertFloorOp);
 }
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -165,3 +165,34 @@
   %ret = math.ceil %a : f64
   return %ret : f64
 }
+
+// -----
+
+// CHECK-LABEL: func @powf_func
+// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
+func.func @powf_func(%a: f64, %b: f64) -> f64 {
+  // CHECK-DAG: [[ZERO:%.+]] = arith.constant 0.000000e+00 : f64
+  // CHECK-DAG: [[ONE:%.+]] = arith.constant 1.000000e+00 : f64
+  // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00 : f64
+  // CHECK-DAG: [[REMONE:%.+]] = arith.remf [[ARG1]], [[ONE]] : f64
+  // CHECK-DAG: [[NONINT:%.+]] = arith.cmpf one, [[REMONE]], [[ZERO]] : f64
+  // CHECK-DAG: [[ABSA:%.+]] = math.absf [[ARG0]] : f64
+  // CHECK-DAG: [[BASE:%.+]] = arith.select [[NONINT]], [[ARG0]], [[ABSA]] : f64
+  // CHECK-DAG: [[LOG:%.+]] = math.log [[BASE]] : f64
+  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[LOG]], [[ARG1]] : f64
+  // CHECK-DAG: [[EXP:%.+]] = math.exp [[MULT]] : f64
+  // CHECK-DAG: [[REMTWO:%.+]] = arith.remf [[ARG1]], [[TWO]] : f64
+  // CHECK-DAG: [[ABSREM:%.+]] = math.absf [[REMTWO]] : f64
+  // CHECK-DAG: [[ODD:%.+]] = arith.cmpf oeq, [[ABSREM]], [[ONE]] : f64
+  // CHECK-DAG: [[NEG:%.+]] = arith.cmpf olt, [[ARG0]], [[ZERO]] : f64
+  // CHECK-DAG: [[FLIP:%.+]] = arith.andi [[NEG]], [[ODD]] : i1
+  // CHECK-DAG: [[NEGEXP:%.+]] = arith.negf [[EXP]] : f64
+  // CHECK-DAG: [[SIGNED:%.+]] = arith.select [[FLIP]], [[NEGEXP]], [[EXP]] : f64
+  // CHECK-DAG: [[ZEROEXP:%.+]] = arith.cmpf oeq, [[ARG1]], [[ZERO]] : f64
+  // CHECK-DAG: [[ONEBASE:%.+]] = arith.cmpf oeq, [[ARG0]], [[ONE]] : f64
+  // CHECK-DAG: [[UNIT:%.+]] = arith.ori [[ZEROEXP]], [[ONEBASE]] : i1
+  // CHECK-DAG: [[RESULT:%.+]] = arith.select [[UNIT]], [[ONE]], [[SIGNED]] : f64
+  // CHECK: return [[RESULT]]
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -42,6 +42,7 @@
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
+  populateExpandPowFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -684,6 +684,65 @@
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Pow.
+// -------------------------------------------------------------------------- //
+func.func @func_powff64(%a : f64, %b : f64) {
+  %r = math.powf %a, %b : f64
+  vector.print %r : f64
+  return
+}
+
+func.func @powf() {
+  // CHECK: 16
+  %a = arith.constant 4.0 : f64
+  %a_p = arith.constant 2.0 : f64
+  call @func_powff64(%a, %a_p) : (f64, f64) -> ()
+
+  // CHECK: -27
+  %b = arith.constant -3.0 : f64
+  %b_p = arith.constant 3.0 : f64
+  call @func_powff64(%b, %b_p) : (f64, f64) -> ()
+
+  // CHECK: 2.343
+  %c = arith.constant 2.343 : f64
+  %c_p = arith.constant 1.000 : f64
+  call @func_powff64(%c, %c_p) : (f64, f64) -> ()
+
+  // CHECK: 0.176171
+  %d = arith.constant 4.25 : f64
+  %d_p = arith.constant -1.2 : f64
+  call @func_powff64(%d, %d_p) : (f64, f64) -> ()
+
+  // CHECK: 1
+  %e = arith.constant 4.385 : f64
+  %e_p = arith.constant 0.00 : f64
+  call @func_powff64(%e, %e_p) : (f64, f64) -> ()
+
+  // CHECK: nan
+  %f = arith.constant -4.835 : f64
+  %f_p = arith.constant 1.2 : f64
+  call @func_powff64(%f, %f_p) : (f64, f64) -> ()
+
+  // CHECK: 0
+  %g = arith.constant 0xff80000000000000 : f64
+  call @func_powff64(%g, %g) : (f64, f64) -> ()
+
+  // CHECK: nan
+  %h = arith.constant 0x7fffffffffffffff : f64
+  call @func_powff64(%h, %h) : (f64, f64) -> ()
+
+  // CHECK: 1
+  %i = arith.constant 1.0 : f64
+  call @func_powff64(%i, %h) : (f64, f64) -> ()
+
+  // CHECK: inf
+  %j = arith.constant 29385.0 : f64
+  %j_p = arith.constant 23598.0 : f64
+  call @func_powff64(%j, %j_p) : (f64, f64) -> ()
+  return
+}
+
 func.func @main() {
   call @tanh(): () -> ()
   call @log(): () -> ()
@@ -699,6 +758,7 @@
   call @cbrt() : () -> ()
   call @floorf() : () -> ()
   call @ceilf() : () -> ()
+  call @powf() : () -> ()
   return
 }
 