[mlir][math] Fold math ops through CommonFolders helpers

Replace the hand-written constant folders of math.abs, math.ceil,
math.copysign, math.ctlz, math.cttz and math.ctpop with calls to
constFoldUnaryOp / constFoldBinaryOp from mlir/Dialect/CommonFolders.h.
Each op now supplies only the per-element APFloat / APInt computation.

Behavior notes visible in this diff:
- AbsOp previously folded only 32- and 64-bit floats (the removed
  getWidth() checks); the new lambda uses APFloat abs() with no width check.
- The integer counting ops now build the result as an APInt of the operand's
  bit width instead of calling IntegerAttr::get(getType(), ...) directly.

diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/IR/Builders.h"
 
@@ -25,25 +26,10 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<FloatAttr>();
-  if (!attr)
-    return {};
-
-  auto ft = getType().cast<FloatType>();
-
-  APFloat apf = attr.getValue();
-
-  if (ft.getWidth() == 64)
-    return FloatAttr::get(getType(), fabs(apf.convertToDouble()));
-
-  if (ft.getWidth() == 32)
-    return FloatAttr::get(getType(), fabsf(apf.convertToFloat()));
-
-  return {};
+  return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
+    APFloat result(a);
+    return abs(result);
+  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -51,18 +37,11 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<FloatAttr>();
-  if (!attr)
-    return {};
-
-  APFloat sourceVal = attr.getValue();
-  sourceVal.roundToIntegral(llvm::RoundingMode::TowardPositive);
-
-  return FloatAttr::get(getType(), sourceVal);
+  return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
+    APFloat result(a);
+    result.roundToIntegral(llvm::RoundingMode::TowardPositive);
+    return result;
+  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -70,26 +49,12 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
-  auto ft = getType().dyn_cast<FloatType>();
-  if (!ft)
-    return {};
-
-  APFloat vals[2]{APFloat(ft.getFloatSemantics()),
-                  APFloat(ft.getFloatSemantics())};
-  for (int i = 0; i < 2; ++i) {
-    if (!operands[i])
-      return {};
-
-    auto attr = operands[i].dyn_cast<FloatAttr>();
-    if (!attr)
-      return {};
-
-    vals[i] = attr.getValue();
-  }
-
-  vals[0].copySign(vals[1]);
-
-  return FloatAttr::get(getType(), vals[0]);
+  return constFoldBinaryOp<FloatAttr>(operands,
+                                      [](const APFloat &a, const APFloat &b) {
+                                        APFloat result(a);
+                                        result.copySign(b);
+                                        return result;
+                                      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -97,15 +62,9 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<IntegerAttr>();
-  if (!attr)
-    return {};
-
-  return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros());
+  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
+    return APInt(a.getBitWidth(), a.countLeadingZeros());
+  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -113,15 +72,9 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<IntegerAttr>();
-  if (!attr)
-    return {};
-
-  return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros());
+  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
+    return APInt(a.getBitWidth(), a.countTrailingZeros());
+  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -129,15 +82,9 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
-  if (!constOperand)
-    return {};
-
-  auto attr = constOperand.dyn_cast<IntegerAttr>();
-  if (!attr)
-    return {};
-
-  return IntegerAttr::get(getType(), attr.getValue().countPopulation());
+  return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
+    return APInt(a.getBitWidth(), a.countPopulation());
+  });
 }
 
 //===----------------------------------------------------------------------===//