Patch review notes (free text placed before the first "diff --git" header;
git apply and GNU patch skip leading text like this).

Purpose: add a constant folder for the std `powf` operation (PowFOp).
* Ops.td: adds `let hasFolder = 1;` to the op whose description example is
  `%x = powf %y, %z : tensor<4x?xbf16>`, so a PowFOp::fold hook is declared.
* Ops.cpp: PowFOp::fold only folds when both operands are constant
  attributes of identical type and the element type is f16, f32 or f64.
  The folding lambda calls pow() in double precision on the values from
  FloatAttr::getValueAsDouble. It then converts the result back to the first
  operand's float semantics with APFloat::rmNearestTiesToEven and discards
  the "loses info" flag. In every other case it returns an empty
  OpFoldResult, i.e. no fold.
* constant-fold.mlir: adds three cases.
  * A scalar f32 case: 4.5 powf 2.0 is expected to fold to 2.025e+01.
  * The same computation on splat tensor<4xf32> constants.
  * An f128 case that is expected to stay unfolded.

NOTE(review): this copy of the patch has its line breaks collapsed. It also
appears to have lost text inside angle brackets:
* `ArrayRef operands` is presumably `ArrayRef<Attribute> operands`.
* `constFoldBinaryOp(operands, ...)` is presumably
  `constFoldBinaryOp<FloatAttr>(operands, ...)`.
* The constant-fold.mlir context lines `memref to memref` / `memref` are
  missing their shapes.
Restore these from the upstream patch before applying. TODO confirm.
NOTE(review): f32 (and f16) results are computed in double and then rounded.
They may not match a target's native powf bit-for-bit in every case; this is
not verified here.
NOTE(review): bf16 and f128 are deliberately not folded. The in-code comment
attributes this to following LLVM's behaviour.
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -2339,6 +2339,7 @@ %x = powf %y, %z : tensor<4x?xbf16> ``` }]; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2263,6 +2263,33 @@ [](APInt a, APInt b) { return a | b; }); } +//===----------------------------------------------------------------------===// +// PowFOp +//===----------------------------------------------------------------------===// + +OpFoldResult PowFOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "binary op takes two operands"); + if (operands[0] && operands[1] && + operands[0].getType() == operands[1].getType()) { + // make sure that types match (similar check performed in constFoldBinaryOp) + Type ty = getElementTypeOrSelf(operands[0]); + if (ty.isF16() || ty.isF32() || ty.isF64()) { + // following LLVM, only constant fold the pow of halfs, floats, and + // doubles + return constFoldBinaryOp(operands, [](APFloat a, APFloat b) { + bool unused; + // assume a and b are the same floating point type (i.e. 
share the same + // semantics) + APFloat res = APFloat(pow(FloatAttr::getValueAsDouble(a), + FloatAttr::getValueAsDouble(b))); + res.convert(a.getSemantics(), APFloat::rmNearestTiesToEven, &unused); + return res; + }); + } + } + return {}; +} + //===----------------------------------------------------------------------===// // PrefetchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -797,3 +797,47 @@ %c = subview %arg0[] [] [] : memref to memref return %c : memref } + +// ----- + +// CHECK-LABEL: func @simple_powf +func @simple_powf() -> f32 { + %0 = constant 4.5 : f32 + %1 = constant 2.0 : f32 + + // CHECK-NEXT: [[C:%.+]] = constant 2.025{{0*}}e+01 : f32 + %2 = powf %0, %1 : f32 + + // CHECK-NEXT: return [[C]] + return %2 : f32 +} + +// ----- + +// CHECK-LABEL: func @powf_splat_tensor +func @powf_splat_tensor() -> tensor<4xf32> { + %0 = constant dense<4.5> : tensor<4xf32> + %1 = constant dense<2.0> : tensor<4xf32> + + // CHECK-NEXT: [[C:%.+]] = constant dense<2.025{{0*}}e+01> : tensor<4xf32> + %2 = powf %0, %1 : tensor<4xf32> + + // CHECK-NEXT: return [[C]] + return %2 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: func @powf_dont_fold_f128 +func @powf_dont_fold_f128() -> f128 { + // CHECK-NEXT: [[A:%.+]] = constant 4.5{{0*}}e+00 : f128 + %0 = constant 4.5 : f128 + // CHECK-NEXT: [[B:%.+]] = constant 2.0{{0*}}e+00 : f128 + %1 = constant 2.0 : f128 + + // CHECK-NEXT: %0 = powf [[A]], [[B]] : f128 + %2 = powf %0, %1 : f128 + + // CHECK-NEXT: return %0 + return %2 : f128 +}