diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -89,6 +89,7 @@
     %x = math.abs %y : tensor<4x?xf8>
     ```
   }];
+  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -230,6 +231,7 @@
     %x = math.copysign %y, %z : tensor<4x?xf8>
     ```
   }];
+  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -320,6 +322,7 @@
     %x = math.ctlz %y : tensor<4x?xi8>
     ```
   }];
+  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -344,6 +347,7 @@
     %x = math.cttz %y : tensor<4x?xi8>
     ```
   }];
+  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -368,6 +372,7 @@
     %x = math.ctpop %y : tensor<4x?xi8>
     ```
   }];
+  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -20,6 +20,32 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Math/IR/MathOps.cpp.inc"

+//===----------------------------------------------------------------------===//
+// AbsOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
+  // Fold only when the operand is a constant scalar float. Non-constant
+  // operands and other constant attributes (e.g. elements attributes for
+  // vector/tensor operands) are left unfolded.
+  auto constOperand = operands.front();
+  if (!constOperand)
+    return {};
+
+  auto attr = constOperand.dyn_cast<FloatAttr>();
+  if (!attr)
+    return {};
+
+  // Clear the sign bit on the APFloat directly instead of converting to a
+  // host `double`/`float` and calling fabs/fabsf. This folds every float
+  // semantics (f16, bf16, f80, f128, ...) exactly and keeps the operand's
+  // semantics, rather than being limited to 32- and 64-bit types. Like
+  // fabs, it also clears the sign of -0.0, -inf and negative NaNs.
+  APFloat result = attr.getValue();
+  result.clearSign();
+  return FloatAttr::get(getType(), result);
+}
+
//===----------------------------------------------------------------------===// // CeilOp folder //===----------------------------------------------------------------------===// @@ -39,6 +65,81 @@ return FloatAttr::get(getType(), sourceVal); } +//===----------------------------------------------------------------------===// +// CopySignOp folder +//===----------------------------------------------------------------------===// + +OpFoldResult math::CopySignOp::fold(ArrayRef operands) { + auto ft = getType().dyn_cast(); + if (!ft) + return {}; + + APFloat vals[2]{APFloat(ft.getFloatSemantics()), + APFloat(ft.getFloatSemantics())}; + for (int i = 0; i < 2; ++i) { + if (!operands[i]) + return {}; + + auto attr = operands[i].dyn_cast(); + if (!attr) + return {}; + + vals[i] = attr.getValue(); + } + + vals[0].copySign(vals[1]); + + return FloatAttr::get(getType(), vals[0]); +} + +//===----------------------------------------------------------------------===// +// CountLeadingZerosOp folder +//===----------------------------------------------------------------------===// + +OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef operands) { + auto constOperand = operands.front(); + if (!constOperand) + return {}; + + auto attr = constOperand.dyn_cast(); + if (!attr) + return {}; + + return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros()); +} + +//===----------------------------------------------------------------------===// +// CountTrailingZerosOp folder +//===----------------------------------------------------------------------===// + +OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef operands) { + auto constOperand = operands.front(); + if (!constOperand) + return {}; + + auto attr = constOperand.dyn_cast(); + if (!attr) + return {}; + + return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros()); +} + +//===----------------------------------------------------------------------===// +// CtPopOp folder 
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
+  auto constOperand = operands.front();
+  if (!constOperand)
+    return {};
+
+  auto attr = constOperand.dyn_cast<IntegerAttr>();
+  if (!attr)
+    return {};
+
+  return IntegerAttr::get(getType(), attr.getValue().countPopulation());
+}
+
 //===----------------------------------------------------------------------===//
 // Log2Op folder
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -91,3 +91,58 @@
   %r = math.sqrt %c : f32
   return %r : f32
 }
+
+// CHECK-LABEL: @abs_fold
+// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @abs_fold() -> f32 {
+  %c = arith.constant -4.0 : f32
+  %r = math.abs %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @copysign_fold
+// CHECK: %[[cst:.+]] = arith.constant -4.000000e+00 : f32
+// CHECK: return %[[cst]]
+func @copysign_fold() -> f32 {
+  %c1 = arith.constant 4.0 : f32
+  %c2 = arith.constant -9.0 : f32
+  %r = math.copysign %c1, %c2 : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @ctlz_fold1
+// CHECK: %[[cst:.+]] = arith.constant 31 : i32
+// CHECK: return %[[cst]]
+func @ctlz_fold1() -> i32 {
+  %c = arith.constant 1 : i32
+  %r = math.ctlz %c : i32
+  return %r : i32
+}
+
+// CHECK-LABEL: @ctlz_fold2
+// CHECK: %[[cst:.+]] = arith.constant 7 : i8
+// CHECK: return %[[cst]]
+func @ctlz_fold2() -> i8 {
+  %c = arith.constant 1 : i8
+  %r = math.ctlz %c : i8
+  return %r : i8
+}
+
+// CHECK-LABEL: @cttz_fold
+// CHECK: %[[cst:.+]] = arith.constant 8 : i32
+// CHECK: return %[[cst]]
+func @cttz_fold() -> i32 {
+  %c = arith.constant 256 : i32
+  %r = math.cttz %c : i32
+  return %r : i32
+}
+
+// CHECK-LABEL: @ctpop_fold
+// CHECK: %[[cst:.+]] = arith.constant 13 : i32
+// CHECK: return %[[cst]]
+func @ctpop_fold() -> i32 {
+  %c = arith.constant 1151756898 : i32
+  %r = math.ctpop %c : i32
+  return %r : i32
+}