diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -20,6 +20,7 @@
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
 void populateExpandExp2FPattern(RewritePatternSet &patterns);
+void populateExpandPowFPattern(RewritePatternSet &patterns);
 void populateExpandRoundFPattern(RewritePatternSet &patterns);
 
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -157,6 +157,61 @@
   rewriter.replaceOp(op, ret);
   return success();
 }
+
+// Converts powf(float a, float b) (meaning a^b) to exp(b * ln(|a|)) and then
+// corrects the cases where that identity alone disagrees with libm `pow`:
+//  * a < 0 and b is a finite non-integer: there is no real result. ln(a) is
+//    used instead of ln(|a|) so that, as math.log of a negative value is NaN,
+//    the result is NaN.
+//  * a < 0 and b is an odd integer: the result is negative, i.e. -(|a|^b).
+//  * b == 0 or a == 1: the result is 1 even when the other operand is NaN, or
+//    when exp(b * ln(|a|)) would evaluate 0 * inf.
+// Not handled: the sign of a zero base, pow(-1, +-inf) (NaN instead of 1) and
+// pow(-inf, b) for non-integer b (NaN instead of +0 or +inf).
+static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0);
+  Value operandB = op.getOperand(1);
+  Type opType = operandA.getType();
+  Value zero = createFloatConst(op->getLoc(), opType, 0.0, rewriter);
+  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+  Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
+
+  // fmod(b, 1) is non-zero exactly for finite non-integers and fmod(b, 2) is
+  // non-zero for odd integers (and finite non-integers). Both are NaN for an
+  // infinite or NaN b, which the ordered comparisons treat as "even integer".
+  Value remOne = b.create<arith::RemFOp>(opType, operandB, one);
+  Value fracB =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remOne, zero);
+  Value remTwo = b.create<arith::RemFOp>(opType, operandB, two);
+  Value oddB =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remTwo, zero);
+  Value negA =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+
+  // |a|^b = exp(b * ln(|a|)); ln(a) is used for non-integer b (see above).
+  Value absA = b.create<math::AbsFOp>(opType, operandA);
+  Value logInput = b.create<arith::SelectOp>(fracB, operandA, absA);
+  Value logA = b.create<math::LogOp>(opType, logInput);
+  Value mult = b.create<arith::MulFOp>(opType, logA, operandB);
+  Value expResult = b.create<math::ExpOp>(opType, mult);
+
+  // A negative base raised to an odd integer power gives a negative result.
+  Value negResult = b.create<arith::AndIOp>(negA, oddB);
+  Value negExpResult = b.create<arith::NegFOp>(opType, expResult);
+  Value signedResult =
+      b.create<arith::SelectOp>(negResult, negExpResult, expResult);
+
+  // pow(x, 0) == pow(1, y) == 1 for every x and y, including NaN.
+  Value zeroB =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
+  Value oneA =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandA, one);
+  Value unitResult = b.create<arith::OrIOp>(zeroB, oneA);
+  Value result = b.create<arith::SelectOp>(unitResult, one, signedResult);
+  rewriter.replaceOp(op, result);
+  return success();
+}
 
 // exp2f(float x) -> exp(x * ln(2))
 // Proof: Let's say 2^x = y
@@ -264,6 +319,10 @@
   patterns.add(convertExp2fOp);
 }
 
+void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertPowfOp);
+}
+
 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
   patterns.add(convertRoundOp);
 }
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -207,3 +207,33 @@
   %ret = math.round %a : f64
   return %ret : f64
 }
+
+// -----
+
+// CHECK-LABEL: func @powf_func
+// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
+func.func @powf_func(%a: f64, %b: f64) -> f64 {
+  // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
+  // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.000000e+00
+  // CHECK-DAG: [[CST2:%.+]] = arith.constant 2.000000e+00
+  // CHECK-DAG: [[REM1:%.+]] = arith.remf [[ARG1]], [[CST1]]
+  // CHECK-DAG: [[FRAC:%.+]] = arith.cmpf one, [[REM1]], [[CST0]]
+  // CHECK-DAG: [[REM2:%.+]] = arith.remf [[ARG1]], [[CST2]]
+  // CHECK-DAG: [[ODD:%.+]] = arith.cmpf one, [[REM2]], [[CST0]]
+  // CHECK-DAG: [[NEG:%.+]] = arith.cmpf olt, [[ARG0]], [[CST0]]
+  // CHECK-DAG: [[ABS:%.+]] = math.absf [[ARG0]]
+  // CHECK-DAG: [[LOGIN:%.+]] = arith.select [[FRAC]], [[ARG0]], [[ABS]]
+  // CHECK-DAG: [[LOG:%.+]] = math.log [[LOGIN]]
+  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[LOG]], [[ARG1]]
+  // CHECK-DAG: [[EXP:%.+]] = math.exp [[MULT]]
+  // CHECK-DAG: [[NEGRES:%.+]] = arith.andi [[NEG]], [[ODD]]
+  // CHECK-DAG: [[NEGEXP:%.+]] = arith.negf [[EXP]]
+  // CHECK-DAG: [[SIGNED:%.+]] = arith.select [[NEGRES]], [[NEGEXP]], [[EXP]]
+  // CHECK-DAG: [[ZEROB:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+  // CHECK-DAG: [[ONEA:%.+]] = arith.cmpf oeq, [[ARG0]], [[CST1]]
+  // CHECK-DAG: [[UNIT:%.+]] = arith.ori [[ZEROB]], [[ONEA]]
+  // CHECK-DAG: [[RES:%.+]] = arith.select [[UNIT]], [[CST1]], [[SIGNED]]
+  // CHECK: return [[RES]]
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -43,6 +43,7 @@
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
+  populateExpandPowFPattern(patterns);
   populateExpandRoundFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -100,9 +100,83 @@
   return
 }
 
+// -------------------------------------------------------------------------- //
+// pow.
+// -------------------------------------------------------------------------- //
+func.func @func_powff64(%a : f64, %b : f64) {
+  %r = math.powf %a, %b : f64
+  vector.print %r : f64
+  return
+}
+
+func.func @powf() {
+  // CHECK: 16
+  %a = arith.constant 4.0 : f64
+  %a_p = arith.constant 2.0 : f64
+  call @func_powff64(%a, %a_p) : (f64, f64) -> ()
+
+  // Negative base with an odd integer exponent.
+  // CHECK: -27
+  %b = arith.constant -3.0 : f64
+  %b_p = arith.constant 3.0 : f64
+  call @func_powff64(%b, %b_p) : (f64, f64) -> ()
+
+  // CHECK: 2.343
+  %c = arith.constant 2.343 : f64
+  %c_p = arith.constant 1.000 : f64
+  call @func_powff64(%c, %c_p) : (f64, f64) -> ()
+
+  // CHECK: 0.176171
+  %d = arith.constant 4.25 : f64
+  %d_p = arith.constant -1.2 : f64
+  call @func_powff64(%d, %d_p) : (f64, f64) -> ()
+
+  // CHECK: 1
+  %e = arith.constant 4.385 : f64
+  %e_p = arith.constant 0.00 : f64
+  call @func_powff64(%e, %e_p) : (f64, f64) -> ()
+
+  // CHECK: 6.62637
+  %f = arith.constant 4.835 : f64
+  %f_p = arith.constant 1.2 : f64
+  call @func_powff64(%f, %f_p) : (f64, f64) -> ()
+
+  // 0xff80000000000000 is the f64 value -2^1017 (not -inf); raising it to
+  // itself is an even integer power whose magnitude underflows to +0.
+  // CHECK: 0
+  %g = arith.constant 0xff80000000000000 : f64
+  call @func_powff64(%g, %g) : (f64, f64) -> ()
+
+  // CHECK: nan
+  %h = arith.constant 0x7fffffffffffffff : f64
+  call @func_powff64(%h, %h) : (f64, f64) -> ()
+
+  // pow(1, y) == 1 even for a NaN y.
+  // CHECK: 1
+  %i = arith.constant 1.0 : f64
+  call @func_powff64(%i, %h) : (f64, f64) -> ()
+
+  // CHECK: inf
+  %j = arith.constant 29385.0 : f64
+  %j_p = arith.constant 23598.0 : f64
+  call @func_powff64(%j, %j_p) : (f64, f64) -> ()
+
+  // Negative base with a non-integer exponent has no real result.
+  // CHECK: nan
+  %k = arith.constant -2.0 : f64
+  %k_p = arith.constant 0.5 : f64
+  call @func_powff64(%k, %k_p) : (f64, f64) -> ()
+
+  // pow(x, 0) == 1 even for x == 0.
+  // CHECK: 1
+  %l = arith.constant 0.0 : f64
+  call @func_powff64(%l, %l) : (f64, f64) -> ()
+  return
+}
 
 func.func @main() {
   call @exp2f() : () -> ()
   call @roundf() : () -> ()
+  call @powf() : () -> ()
   return
 }