diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2339,6 +2339,7 @@
     %x = powf %y, %z : tensor<4x?xbf16>
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2263,6 +2263,47 @@
                                         [](APInt a, APInt b) { return a | b; });
 }
 
+//===----------------------------------------------------------------------===//
+// PowFOp
+//===----------------------------------------------------------------------===//
+
+/// Constant-folds `powf` when both operands are constants of the same type.
+///
+/// Folding is only attempted for f16, f32 and f64 element types, mirroring
+/// LLVM, which only constant folds pow for halves, floats and doubles; other
+/// floating-point types (e.g. bf16) are left untouched. The value is computed
+/// with the host `pow` in double precision and then rounded back to the
+/// operands' floating-point semantics.
+OpFoldResult PowFOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary op takes two operands");
+
+  // Both operands must be constants of the same type. constFoldBinaryOp
+  // re-checks this, but testing it first guarantees that the element type
+  // queried below describes both operands.
+  if (!operands[0] || !operands[1] ||
+      operands[0].getType() != operands[1].getType())
+    return {};
+
+  Type elementType = getElementTypeOrSelf(operands[0]);
+  if (!elementType.isF16() && !elementType.isF32() && !elementType.isF64())
+    return {};
+
+  // NOTE(review): unlike LLVM's constant folder, this does not decline to
+  // fold when the host `pow` signals a domain/range error (e.g. a negative
+  // base with a non-integer exponent folds to NaN); confirm that is intended.
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [](const APFloat &base, const APFloat &exponent) {
+        APFloat result(pow(FloatAttr::getValueAsDouble(base),
+                           FloatAttr::getValueAsDouble(exponent)));
+        // `result` carries IEEE double semantics; round it back to the
+        // operands' format (both share it, given the type check above).
+        bool losesInfo = false;
+        result.convert(base.getSemantics(), APFloat::rmNearestTiesToEven,
+                       &losesInfo);
+        return result;
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // PrefetchOp
 //===----------------------------------------------------------------------===//