diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -142,9 +142,27 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
-  return foldBinaryOpUnchecked(
-      adaptor.getOperands(),
-      [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
+  // Constant-fold when both operands are known: the callback needs two APInts.
+  if (OpFoldResult result = foldBinaryOpUnchecked(
+          adaptor.getOperands(),
+          [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
+      !result.isNull())
+    return result;
+
+  // Only the RHS is inspected for identities. NOTE(review): this assumes the
+  // op is Commutative so constants are canonicalized to the RHS -- confirm.
+  // The adaptor's RHS attribute may be null (non-constant operand), hence
+  // `dyn_cast_or_null` with an explicit target type.
+  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
+    // Fold `mul(x, 1) -> x`.
+    if (rhs.getValue().isOne())
+      return getLhs();
+    // Fold `mul(x, 0) -> 0`, reusing the zero attribute as the result.
+    if (rhs.getValue().isZero())
+      return rhs;
+  }
+
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -521,3 +521,13 @@
   // CHECK: return %true, %false
   return %1, %2 : i1, i1
 }
+
+// CHECK-LABEL: @mul_identity
+func.func @mul_identity(%arg0: index) -> (index, index) {
+  %idx0 = index.constant 0
+  %idx1 = index.constant 1
+  %0 = index.mul %arg0, %idx0
+  %1 = index.mul %arg0, %idx1
+  // CHECK: return %idx0, %arg0
+  return %0, %1 : index, index
+}