diff --git a/mlir/include/mlir/Dialect/Math/IR/MathBase.td b/mlir/include/mlir/Dialect/Math/IR/MathBase.td --- a/mlir/include/mlir/Dialect/Math/IR/MathBase.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathBase.td @@ -30,5 +30,6 @@ ``` }]; let hasConstantMaterializer = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // MATH_BASE diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -25,8 +25,8 @@ // AbsFOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::AbsFOp::fold(ArrayRef operands) { - return constFoldUnaryOp(operands, +OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) { + return constFoldUnaryOp(adaptor.getOperands(), [](const APFloat &a) { return abs(a); }); } @@ -34,8 +34,8 @@ // AbsIOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::AbsIOp::fold(ArrayRef operands) { - return constFoldUnaryOp(operands, +OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) { + return constFoldUnaryOp(adaptor.getOperands(), [](const APInt &a) { return a.abs(); }); } @@ -43,9 +43,9 @@ // AtanOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::AtanOp::fold(ArrayRef operands) { +OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(atan(a.convertToDouble())); @@ -61,9 +61,10 @@ // Atan2Op folder //===----------------------------------------------------------------------===// -OpFoldResult math::Atan2Op::fold(ArrayRef operands) { +OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) { return constFoldBinaryOpConditional( - operands, [](const APFloat &a, const APFloat &b) -> Optional { + adaptor.getOperands(), + [](const APFloat &a, const APFloat &b) -> Optional { if (a.isZero() && b.isZero()) return llvm::APFloat::getNaN(a.getSemantics()); @@ -83,20 +84,21 @@ // CeilOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::CeilOp::fold(ArrayRef operands) { - return constFoldUnaryOp(operands, [](const APFloat &a) { - APFloat result(a); - result.roundToIntegral(llvm::RoundingMode::TowardPositive); - return result; - }); +OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) { + return constFoldUnaryOp( + adaptor.getOperands(), [](const APFloat &a) { + APFloat result(a); + result.roundToIntegral(llvm::RoundingMode::TowardPositive); + return result; + }); } //===----------------------------------------------------------------------===// // CopySignOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::CopySignOp::fold(ArrayRef operands) { - return constFoldBinaryOp(operands, +OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) { + return constFoldBinaryOp(adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { APFloat result(a); result.copySign(b); @@ -108,9 +110,9 @@ // CosOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::CosOp::fold(ArrayRef operands) { +OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(cos(a.convertToDouble())); @@ -126,9 +128,9 @@ // SinOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::SinOp::fold(ArrayRef operands) { +OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(sin(a.convertToDouble())); @@ -144,39 +146,42 @@ // CountLeadingZerosOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef operands) { - return constFoldUnaryOp(operands, [](const APInt &a) { - return APInt(a.getBitWidth(), a.countLeadingZeros()); - }); +OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) { + return constFoldUnaryOp( + adaptor.getOperands(), [](const APInt &a) { + return APInt(a.getBitWidth(), a.countLeadingZeros()); + }); } //===----------------------------------------------------------------------===// // CountTrailingZerosOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef operands) { - return constFoldUnaryOp(operands, [](const APInt &a) { - return APInt(a.getBitWidth(), a.countTrailingZeros()); - }); +OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) { + return constFoldUnaryOp( + adaptor.getOperands(), [](const APInt &a) { + return APInt(a.getBitWidth(), a.countTrailingZeros()); + }); } //===----------------------------------------------------------------------===// // CtPopOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::CtPopOp::fold(ArrayRef operands) { - return constFoldUnaryOp(operands, [](const APInt &a) { - return APInt(a.getBitWidth(), a.countPopulation()); - }); +OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) { + return constFoldUnaryOp( + adaptor.getOperands(), [](const APInt &a) { + return APInt(a.getBitWidth(), a.countPopulation()); + }); } //===----------------------------------------------------------------------===// // ErfOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::ErfOp::fold(ArrayRef operands) { +OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(erf(a.convertToDouble())); @@ -192,9 +197,10 @@ // IPowIOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::IPowIOp::fold(ArrayRef operands) { +OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) { return constFoldBinaryOpConditional( - operands, [](const APInt &base, const APInt &power) -> Optional { + adaptor.getOperands(), + [](const APInt &base, const APInt &power) -> Optional { unsigned width = base.getBitWidth(); auto zeroValue = APInt::getZero(width); APInt oneValue{width, 1ULL, /*isSigned=*/true}; @@ -242,9 +248,9 @@ // LogOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::LogOp::fold(ArrayRef operands) { +OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { if (a.isNegative()) return {}; @@ -262,9 +268,9 @@ // Log2Op folder //===----------------------------------------------------------------------===// -OpFoldResult math::Log2Op::fold(ArrayRef operands) { +OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { if (a.isNegative()) return {}; @@ -282,9 +288,9 @@ // Log10Op folder //===----------------------------------------------------------------------===// -OpFoldResult math::Log10Op::fold(ArrayRef operands) { +OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { if (a.isNegative()) return {}; @@ -303,9 +309,9 @@ // Log1pOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::Log1pOp::fold(ArrayRef operands) { +OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: if ((a + APFloat(1.0)).isNegative()) @@ -325,9 +331,10 @@ // PowFOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::PowFOp::fold(ArrayRef operands) { +OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) { return constFoldBinaryOpConditional( - operands, [](const APFloat &a, const APFloat &b) -> Optional { + adaptor.getOperands(), + [](const APFloat &a, const APFloat &b) -> Optional { if (a.getSizeInBits(a.getSemantics()) == 64 && b.getSizeInBits(b.getSemantics()) == 64) return APFloat(pow(a.convertToDouble(), b.convertToDouble())); @@ -344,9 +351,9 @@ // SqrtOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::SqrtOp::fold(ArrayRef operands) { +OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { if (a.isNegative()) return {}; @@ -365,9 +372,9 @@ // ExpOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::ExpOp::fold(ArrayRef operands) { +OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(exp(a.convertToDouble())); @@ -383,9 +390,9 @@ // Exp2Op folder //===----------------------------------------------------------------------===// -OpFoldResult math::Exp2Op::fold(ArrayRef operands) { +OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(exp2(a.convertToDouble())); @@ -401,9 +408,9 @@ // ExpM1Op folder //===----------------------------------------------------------------------===// -OpFoldResult math::ExpM1Op::fold(ArrayRef operands) { +OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(expm1(a.convertToDouble())); @@ -419,9 +426,9 @@ // TanOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::TanOp::fold(ArrayRef operands) { +OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(tan(a.convertToDouble())); @@ -437,9 +444,9 @@ // TanhOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::TanhOp::fold(ArrayRef operands) { +OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(tanh(a.convertToDouble())); @@ -455,33 +462,35 @@ // RoundEvenOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::RoundEvenOp::fold(ArrayRef operands) { - return constFoldUnaryOp(operands, [](const APFloat &a) { - APFloat result(a); - result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven); - return result; - }); +OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) { + return constFoldUnaryOp( + adaptor.getOperands(), [](const APFloat &a) { + APFloat result(a); + result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven); + return result; + }); } //===----------------------------------------------------------------------===// // FloorOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::FloorOp::fold(ArrayRef operands) { - return constFoldUnaryOp(operands, [](const APFloat &a) { - APFloat result(a); - result.roundToIntegral(llvm::RoundingMode::TowardNegative); - return result; - }); +OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) { + return constFoldUnaryOp( + adaptor.getOperands(), [](const APFloat &a) { + APFloat result(a); + result.roundToIntegral(llvm::RoundingMode::TowardNegative); + return result; + }); } //===----------------------------------------------------------------------===// // RoundOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::RoundOp::fold(ArrayRef operands) { +OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(round(a.convertToDouble())); @@ -497,9 +506,9 @@ // TruncOp folder //===----------------------------------------------------------------------===// -OpFoldResult math::TruncOp::fold(ArrayRef operands) { +OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) { return constFoldUnaryOpConditional( - operands, [](const APFloat &a) -> Optional { + adaptor.getOperands(), [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(trunc(a.convertToDouble()));