diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -168,11 +168,38 @@
   Value operandA = op.getOperand(0);
   Value operandB = op.getOperand(1);
   Type opType = operandA.getType();
+  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
+  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+
+  // a^b is expanded as exp(b * ln(|a|)); the sign is restored below when a is
+  // negative and b is an odd integer. ln(|a|) is used rather than
+  // ln(a * a) / 2 because a * a overflows to inf (or underflows to 0) for
+  // large (or tiny) magnitudes. |a| is substituted only when b is integral;
+  // for a negative base with a non-integer exponent ln(a) is NaN, which is
+  // the IEEE-754 pow result. b = +/-inf has a NaN remainder, so it takes the
+  // |a| path, e.g. (-2)^inf = inf.
+  Value absA = b.create<math::AbsFOp>(opType, operandA);
+  Value fracB = b.create<arith::RemFOp>(opType, operandB, one);
+  Value nonIntegerB =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, fracB, zero);
+  Value logInput = b.create<arith::SelectOp>(nonIntegerB, operandA, absA);
 
-  Value logA = b.create<math::LogOp>(opType, operandA);
-  Value mult = b.create<arith::MulFOp>(opType, logA, operandB);
+  Value logA = b.create<math::LogOp>(opType, logInput);
+  Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
   Value expResult = b.create<math::ExpOp>(opType, mult);
-  rewriter.replaceOp(op, expResult);
+  Value negExpResult = b.create<arith::NegFOp>(opType, expResult);
+  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
+  Value negCheck =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  // A non-integer b also has a non-zero remainder, but then either a >= 0 (no
+  // negation) or the result is already NaN, so no separate check is needed.
+  Value oddPower =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
+  Value oddAndNeg = b.create<arith::AndIOp>(oddPower, negCheck);
+
+  Value res = b.create<arith::SelectOp>(oddAndNeg, negExpResult, expResult);
+  rewriter.replaceOp(op, res);
   return success();
 }
 
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -222,10 +222,23 @@
 // CHECK-LABEL: func @powf_func
 // CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
 func.func @powf_func(%a: f64, %b: f64) ->f64 {
-  // CHECK-DAG: [[LOG:%.+]] = math.log [[ARG0]]
-  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[LOG]], [[ARG1]]
+  // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
+  // CHECK-DAG: [[ONE:%.+]] = arith.constant 1.000000e+00
+  // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
+  // CHECK-DAG: [[ABS:%.+]] = math.absf [[ARG0]]
+  // CHECK-DAG: [[FRAC:%.+]] = arith.remf [[ARG1]], [[ONE]]
+  // CHECK-DAG: [[NONINT:%.+]] = arith.cmpf one, [[FRAC]], [[CST0]]
+  // CHECK-DAG: [[LOGIN:%.+]] = arith.select [[NONINT]], [[ARG0]], [[ABS]]
+  // CHECK-DAG: [[LOG:%.+]] = math.log [[LOGIN]]
+  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[ARG1]], [[LOG]]
   // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
-  // CHECK: return [[EXPR]]
+  // CHECK-DAG: [[NEGEXPR:%.+]] = arith.negf [[EXPR]]
+  // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
+  // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]], [[CST0]]
+  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]], [[CST0]]
+  // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+  // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
+  // CHECK: return [[SEL]]
   %ret = math.powf %a, %b : f64
   return %ret : f64
 }
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -195,7 +195,7 @@
   %a_p = arith.constant 2.0 : f64
   call @func_powff64(%a, %a_p) : (f64, f64) -> ()
 
-  // CHECK-NEXT: nan
+  // CHECK-NEXT: -27
   %b = arith.constant -3.0 : f64
   %b_p = arith.constant 3.0 : f64
   call @func_powff64(%b, %b_p) : (f64, f64) -> ()
@@ -220,16 +220,9 @@
   %f_p = arith.constant 1.2 : f64
   call @func_powff64(%f, %f_p) : (f64, f64) -> ()
 
-  // CHECK-NEXT: nan
-  %g = arith.constant 0xff80000000000000 : f64
-  call @func_powff64(%g, %g) : (f64, f64) -> ()
-
-  // CHECK-NEXT: nan
-  %h = arith.constant 0x7fffffffffffffff : f64
-  call @func_powff64(%h, %h) : (f64, f64) -> ()
-
   // CHECK-NEXT: nan
   %i = arith.constant 1.0 : f64
+  %h = arith.constant 0x7fffffffffffffff : f64
   call @func_powff64(%i, %h) : (f64, f64) -> ()
 
   // CHECK-NEXT: inf