diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -262,6 +262,10 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::DivUIOp::fold(ArrayRef operands) { + // divui (x, 1) -> x. + if (matchPattern(getRhs(), m_One())) + return getLhs(); + // Don't fold if it would require a division by zero. bool div0 = false; auto result = @@ -273,15 +277,6 @@ return a.udiv(b); }); - // Fold out division by one. Assumes all tensors of all ones are splats. - if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getValue() == 1) - return getLhs(); - } else if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getSplatValue().getValue() == 1) - return getLhs(); - } - return div0 ? Attribute() : result; } @@ -290,6 +285,10 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::DivSIOp::fold(ArrayRef operands) { + // divsi (x, 1) -> x. + if (matchPattern(getRhs(), m_One())) + return getLhs(); + // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = @@ -301,15 +300,6 @@ return a.sdiv_ov(b, overflowOrDiv0); }); - // Fold out division by one. Assumes all tensors of all ones are splats. - if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getValue() == 1) - return getLhs(); - } else if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getSplatValue().getValue() == 1) - return getLhs(); - } - return overflowOrDiv0 ? Attribute() : result; } @@ -330,6 +320,10 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::CeilDivUIOp::fold(ArrayRef operands) { + // ceildivui (x, 1) -> x. + if (matchPattern(getRhs(), m_One())) + return getLhs(); + bool overflowOrDiv0 = false; auto result = constFoldBinaryOp(operands, [&](APInt a, const APInt &b) { @@ -343,15 +337,6 @@ APInt one(a.getBitWidth(), 1, true); return quotient.uadd_ov(one, overflowOrDiv0); }); - // Fold out ceil division by one. Assumes all tensors of all ones are - // splats. - if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getValue() == 1) - return getLhs(); - } else if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getSplatValue().getValue() == 1) - return getLhs(); - } return overflowOrDiv0 ? Attribute() : result; } @@ -361,6 +346,10 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::CeilDivSIOp::fold(ArrayRef operands) { + // ceildivsi (x, 1) -> x. + if (matchPattern(getRhs(), m_One())) + return getLhs(); + // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = @@ -398,16 +387,6 @@ return zero.ssub_ov(div, overflowOrDiv0); }); - // Fold out ceil division by one. Assumes all tensors of all ones are - // splats. - if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getValue() == 1) - return getLhs(); - } else if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getSplatValue().getValue() == 1) - return getLhs(); - } - return overflowOrDiv0 ? Attribute() : result; } @@ -416,6 +395,10 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::FloorDivSIOp::fold(ArrayRef operands) { + // floordivsi (x, 1) -> x. + if (matchPattern(getRhs(), m_One())) + return getLhs(); + // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = @@ -453,16 +436,6 @@ return zero.ssub_ov(ceil, overflowOrDiv0); }); - // Fold out floor division by one. Assumes all tensors of all ones are - // splats. - if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getValue() == 1) - return getLhs(); - } else if (auto rhs = operands[1].dyn_cast_or_null()) { - if (rhs.getSplatValue().getValue() == 1) - return getLhs(); - } - return overflowOrDiv0 ? Attribute() : result; }