diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1150,6 +1150,21 @@
     return getBoolAttribute(getType(), getContext(), val);
   }
 
+  // cmpi ne, ext(%x : i1 -> iN), 0  ->  %x
+  // extsi of an i1 yields -1 or 0 and extui yields 1 or 0, so either result is
+  // nonzero exactly when %x is true. Test the element type (not cast<>) so that
+  // vector/tensor-of-i1 operands fold too instead of asserting.
+  if (getPredicate() == arith::CmpIPredicate::ne &&
+      matchPattern(getRhs(), m_Zero())) {
+    Value extSrc;
+    if (auto sextOp = getLhs().getDefiningOp<ExtSIOp>())
+      extSrc = sextOp.getOperand();
+    else if (auto zextOp = getLhs().getDefiningOp<ExtUIOp>())
+      extSrc = zextOp.getOperand();
+    if (extSrc && getElementTypeOrSelf(extSrc.getType()).isInteger(1))
+      return extSrc;
+  }
+
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
   if (!lhs || !rhs)
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -50,6 +50,45 @@
 
 // -----
 
+// CHECK-LABEL: @cmpOfExtSI
+// CHECK-NEXT: return %arg0
+func @cmpOfExtSI(%arg0: i1) -> i1 {
+  %ext = arith.extsi %arg0 : i1 to i64
+  %c0 = arith.constant 0 : i64
+  %res = arith.cmpi ne, %ext, %c0 : i64
+  return %res : i1
+}
+
+// CHECK-LABEL: @cmpOfExtUI
+// CHECK-NEXT: return %arg0
+func @cmpOfExtUI(%arg0: i1) -> i1 {
+  %ext = arith.extui %arg0 : i1 to i64
+  %c0 = arith.constant 0 : i64
+  %res = arith.cmpi ne, %ext, %c0 : i64
+  return %res : i1
+}
+
+// CHECK-LABEL: @cmpOfExtSIVector
+// CHECK-NEXT: return %arg0
+func @cmpOfExtSIVector(%arg0: vector<4xi1>) -> vector<4xi1> {
+  %ext = arith.extsi %arg0 : vector<4xi1> to vector<4xi64>
+  %c0 = arith.constant dense<0> : vector<4xi64>
+  %res = arith.cmpi ne, %ext, %c0 : vector<4xi64>
+  return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpOfExtSINotI1
+// CHECK: arith.extsi
+// CHECK: arith.cmpi ne
+func @cmpOfExtSINotI1(%arg0: i8) -> i1 {
+  %ext = arith.extsi %arg0 : i8 to i64
+  %c0 = arith.constant 0 : i64
+  %res = arith.cmpi ne, %ext, %c0 : i64
+  return %res : i1
+}
+
+// -----
+
 // CHECK-LABEL: @indexCastOfSignExtend
 // CHECK: %[[res:.+]] = arith.index_cast %arg0 : i8 to index
 // CHECK: return %[[res]]