diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -122,9 +122,19 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
-  return foldBinaryOpUnchecked(
-      adaptor.getOperands(),
-      [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
+  // Try the generic binary fold first; a null result means it did not apply.
+  if (OpFoldResult result = foldBinaryOpUnchecked(
+          adaptor.getOperands(),
+          [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))
+    return result;
+
+  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
+    // Fold `add(x, 0) -> x`.
+    if (rhs.getValue().isZero())
+      return getLhs();
+  }
+
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
@@ -132,9 +142,19 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
-  return foldBinaryOpUnchecked(
-      adaptor.getOperands(),
-      [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
+  // Try the generic binary fold first; a null result means it did not apply.
+  if (OpFoldResult result = foldBinaryOpUnchecked(
+          adaptor.getOperands(),
+          [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))
+    return result;
+
+  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
+    // Fold `sub(x, 0) -> x`.
+    if (rhs.getValue().isZero())
+      return getLhs();
+  }
+
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
@@ -144,8 +164,7 @@
 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   if (OpFoldResult result = foldBinaryOpUnchecked(
           adaptor.getOperands(),
-          [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
-      !result.isNull())
+          [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))
     return result;
 
   if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -531,3 +531,19 @@
   // CHECK: return %idx0, %arg0
   return %0, %1 : index, index
 }
+
+// CHECK-LABEL: @add_identity
+func.func @add_identity(%arg0: index) -> index {
+  %idx0 = index.constant 0
+  %0 = index.add %arg0, %idx0
+  // CHECK-NEXT: return %arg0
+  return %0 : index
+}
+
+// CHECK-LABEL: @sub_identity
+func.func @sub_identity(%arg0: index) -> index {
+  %idx0 = index.constant 0
+  %0 = index.sub %arg0, %idx0
+  // CHECK-NEXT: return %arg0
+  return %0 : index
+}