diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -1296,20 +1296,16 @@ if (matchPattern(getRhs(), m_Zero())) { if (auto extOp = getLhs().getDefiningOp()) { - if (extOp.getOperand().getType().cast().getWidth() == 1) { - // extsi(%x : i1 -> iN) != 0 -> %x - if (getPredicate() == arith::CmpIPredicate::ne) { - return extOp.getOperand(); - } - } + // extsi(%x : i1 -> iN) != 0 -> %x + if (extOp.getOperand().getType().cast().getWidth() == 1 && + getPredicate() == arith::CmpIPredicate::ne) + return extOp.getOperand(); } if (auto extOp = getLhs().getDefiningOp()) { - if (extOp.getOperand().getType().cast().getWidth() == 1) { - // extui(%x : i1 -> iN) != 0 -> %x - if (getPredicate() == arith::CmpIPredicate::ne) { - return extOp.getOperand(); - } - } + // extui(%x : i1 -> iN) != 0 -> %x + if (extOp.getOperand().getType().cast().getWidth() == 1 && + getPredicate() == arith::CmpIPredicate::ne) + return extOp.getOperand(); } } @@ -1733,24 +1729,24 @@ return failure(); // select %x, c1, %c0 => extui %arg - if (matchPattern(op.getTrueValue(), m_One())) - if (matchPattern(op.getFalseValue(), m_Zero())) { - rewriter.replaceOpWithNewOp(op, op.getType(), - op.getCondition()); - return success(); - } + if (matchPattern(op.getTrueValue(), m_One()) && + matchPattern(op.getFalseValue(), m_Zero())) { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getCondition()); + return success(); + } // select %x, c0, %c1 => extui (xor %arg, true) - if (matchPattern(op.getTrueValue(), m_Zero())) - if (matchPattern(op.getFalseValue(), m_One())) { - rewriter.replaceOpWithNewOp( - op, op.getType(), - rewriter.create( - op.getLoc(), op.getCondition(), - rewriter.create( - op.getLoc(), 1, op.getCondition().getType()))); - return success(); - } + if (matchPattern(op.getTrueValue(), m_Zero()) && + matchPattern(op.getFalseValue(), m_One())) { + rewriter.replaceOpWithNewOp( + op, op.getType(), + rewriter.create( + op.getLoc(), op.getCondition(), + rewriter.create( + op.getLoc(), 1, op.getCondition().getType()))); + return success(); + } return failure(); } @@ -1778,10 +1774,9 @@ return falseVal; // select %x, true, false => %x - if (getType().isInteger(1)) - if (matchPattern(getTrueValue(), m_One())) - if (matchPattern(getFalseValue(), m_Zero())) - return condition; + if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) && + matchPattern(getFalseValue(), m_Zero())) + return condition; if (auto cmp = dyn_cast_or_null(condition.getDefiningOp())) { auto pred = cmp.getPredicate();