diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -426,6 +426,12 @@
   /// and(x, 0) -> 0
   if (matchPattern(rhs(), m_Zero()))
     return rhs();
+  /// and(x, allOnes) -> x
+  /// (A constant 1 is only the identity for i1; all-ones is correct at any width.)
+  APInt intValue;
+  if (matchPattern(rhs(), m_ConstantInt(&intValue)) &&
+      intValue.isAllOnesValue())
+    return lhs();
   /// and(x,x) -> x
   if (lhs() == rhs())
     return rhs();
@@ -2245,6 +2251,9 @@
   // subi(x,x) -> 0
   if (getOperand(0) == getOperand(1))
     return Builder(getContext()).getZeroAttr(getType());
+  // subi(x,0) -> x
+  if (matchPattern(rhs(), m_Zero()))
+    return lhs();
 
   return constFoldBinaryOp(operands,
                            [](APInt a, APInt b) { return a - b; });
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -92,6 +92,30 @@
 
 // -----
 
+// CHECK: func @simple_and([[ARG0:%[a-zA-Z0-9]+]]: i1)
+func @simple_and(%arg0 : i1) -> i1 {
+  %0 = constant 1 : i1
+  %1 = and %arg0, %0 : i1
+  // CHECK: return [[ARG0]]
+  return %1 : i1
+}
+
+// -----
+
+// CHECK: func @and_all_ones([[ARG0:%[a-zA-Z0-9]+]]: i32)
+func @and_all_ones(%arg0 : i32) -> (i32, i32) {
+  %0 = constant -1 : i32
+  %1 = constant 1 : i32
+  // CHECK: [[AND:%.+]] = and [[ARG0]]
+  %2 = and %arg0, %0 : i32
+  // and(x, 1) is x & 1 for wider types and must not fold to x.
+  %3 = and %arg0, %1 : i32
+  // CHECK: return [[ARG0]], [[AND]]
+  return %2, %3 : i32, i32
+}
+
+// -----
+
 // CHECK-LABEL: func @addi_splat_vector
 func @addi_splat_vector() -> vector<8xi32> {
   %0 = constant dense<1> : vector<8xi32>
@@ -134,16 +158,19 @@
 
 // -----
 
-// CHECK-LABEL: func @simple_subi
-func @simple_subi() -> i32 {
+// CHECK: func @simple_subi
+// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
+func @simple_subi(%arg0 : i32) -> (i32, i32) {
   %0 = constant 4 : i32
   %1 = constant 1 : i32
+  %2 = constant 0 : i32
 
   // CHECK-NEXT:[[C3:%.+]] = constant 3 : i32
-  %2 = subi %0, %1 : i32
+  %3 = subi %0, %1 : i32
+  %4 = subi %arg0, %2 : i32
 
-  // CHECK-NEXT: return [[C3]]
-  return %2 : i32
+  // CHECK-NEXT: return [[C3]], [[ARG0]]
+  return %3, %4 : i32, i32
 }
 
 // -----