diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -263,7 +263,12 @@
 OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
       adaptor.getOperands(),
-      [](const APInt &lhs, const APInt &rhs) { return lhs.srem(rhs); });
+      [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
+        // Don't fold a signed remainder by zero; return no value instead.
+        if (rhs.isZero())
+          return std::nullopt;
+        return lhs.srem(rhs);
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -273,7 +278,12 @@
 OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
       adaptor.getOperands(),
-      [](const APInt &lhs, const APInt &rhs) { return lhs.urem(rhs); });
+      [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
+        // Don't fold an unsigned remainder by zero; return no value instead.
+        if (rhs.isZero())
+          return std::nullopt;
+        return lhs.urem(rhs);
+      });
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -198,6 +198,24 @@
   return %0 : index
 }
 
+// CHECK-LABEL: @rems_zerodiv_nofold
+func.func @rems_zerodiv_nofold() -> index {
+  %lhs = index.constant 2
+  %rhs = index.constant 0
+  // CHECK: index.rems
+  %0 = index.rems %lhs, %rhs
+  return %0 : index
+}
+
+// CHECK-LABEL: @remu_zerodiv_nofold
+func.func @remu_zerodiv_nofold() -> index {
+  %lhs = index.constant 2
+  %rhs = index.constant 0
+  // CHECK: index.remu
+  %0 = index.remu %lhs, %rhs
+  return %0 : index
+}
+
 // CHECK-LABEL: @rems
 func.func @rems() -> index {
   %lhs = index.constant -5