diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -684,6 +684,7 @@
     %x = math.powf %y, %z : tensor<4x?xbf16>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -63,7 +63,45 @@
     return FloatAttr::get(getType(), log2(apf.convertToDouble()));
 
   if (ft.getWidth() == 32)
-    return FloatAttr::get(getType(), log2f(apf.convertToDouble()));
+    return FloatAttr::get(getType(), log2f(apf.convertToFloat()));
+
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// PowFOp folder
+//===----------------------------------------------------------------------===//
+
+/// Constant-folds `math.powf`. When the op type is a 32- or 64-bit float and
+/// both operands are constant float attributes, returns the power computed
+/// with the host `powf`/`pow`; otherwise returns a null result (no fold).
+OpFoldResult math::PowFOp::fold(ArrayRef<Attribute> operands) {
+  auto ft = getType().dyn_cast<FloatType>();
+  if (!ft)
+    return {};
+
+  // Placeholder values in the op type's semantics; overwritten in the loop.
+  APFloat vals[2]{APFloat(ft.getFloatSemantics()),
+                  APFloat(ft.getFloatSemantics())};
+  for (int i = 0; i < 2; ++i) {
+    if (!operands[i])
+      return {};
+
+    auto attr = operands[i].dyn_cast<FloatAttr>();
+    if (!attr)
+      return {};
+
+    vals[i] = attr.getValue();
+  }
+
+  // Evaluate at the precision matching the float width.
+  if (ft.getWidth() == 64)
+    return FloatAttr::get(
+        getType(), pow(vals[0].convertToDouble(), vals[1].convertToDouble()));
+
+  if (ft.getWidth() == 32)
+    return FloatAttr::get(
+        getType(), powf(vals[0].convertToFloat(), vals[1].convertToFloat()));
 
   return {};
 }
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -73,3 +73,12 @@
   %r = math.log2 %c : f64
   return %r : f64
 }
+
+// CHECK-LABEL: @powf_fold
+// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @powf_fold() -> f32 {
+  %c = arith.constant 2.0 : f32
+  %r = math.powf %c, %c : f32
+  return %r : f32
+}