diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -198,14 +198,14 @@ return getLhs(); // addi(subi(a, b), b) -> a - if (auto sub = getLhs().getDefiningOp()) - if (getRhs() == sub.getRhs()) - return sub.getLhs(); + if (auto sub = getLhs().getDefiningOp(); + sub && getRhs() == sub.getRhs()) + return sub.getLhs(); // addi(b, subi(a, b)) -> a - if (auto sub = getRhs().getDefiningOp()) - if (getLhs() == sub.getRhs()) - return sub.getLhs(); + if (auto sub = getRhs().getDefiningOp(); + sub && getLhs() == sub.getRhs()) + return sub.getLhs(); return constFoldBinaryOp( operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); @@ -607,9 +607,9 @@ if (matchPattern(getRhs(), m_Zero())) return getLhs(); /// or(x, ) -> - if (auto rhsAttr = operands[1].dyn_cast_or_null()) - if (rhsAttr.getValue().isAllOnes()) - return rhsAttr; + if (auto rhsAttr = operands[1].dyn_cast_or_null(); + rhsAttr && rhsAttr.getValue().isAllOnes()) + return rhsAttr; return constFoldBinaryOp( operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); @@ -627,9 +627,9 @@ if (getLhs() == getRhs()) return Builder(getContext()).getZeroAttr(getType()); /// xor(xor(x, a), a) -> x - if (arith::XOrIOp prev = getLhs().getDefiningOp()) - if (prev.getRhs() == getRhs()) - return prev.getLhs(); + if (arith::XOrIOp prev = getLhs().getDefiningOp(); + prev && prev.getRhs() == getRhs()) + return prev.getLhs(); return constFoldBinaryOp( operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); @@ -1414,18 +1414,17 @@ } if (matchPattern(getRhs(), m_Zero())) { - if (auto extOp = getLhs().getDefiningOp()) { - // extsi(%x : i1 -> iN) != 0 -> %x - if (extOp.getOperand().getType().cast().getWidth() == 1 && - getPredicate() == arith::CmpIPredicate::ne) - return extOp.getOperand(); - } - if (auto extOp = getLhs().getDefiningOp()) { - // extui(%x : i1 -> iN) != 0 -> %x - if (extOp.getOperand().getType().cast().getWidth() == 1 && - getPredicate() == arith::CmpIPredicate::ne) - return extOp.getOperand(); - } + // extsi(%x : i1 -> iN) != 0 -> %x + if (auto extOp = getLhs().getDefiningOp(); + extOp && extOp.getOperand().getType().cast().getWidth() == 1 && + getPredicate() == arith::CmpIPredicate::ne) + return extOp.getOperand(); + + // extui(%x : i1 -> iN) != 0 -> %x + if (auto extOp = getLhs().getDefiningOp(); + extOp && extOp.getOperand().getType().cast().getWidth() == 1 && + getPredicate() == arith::CmpIPredicate::ne) + return extOp.getOperand(); } // Move constant to the right side. @@ -1616,10 +1615,9 @@ // to distinguish it from one less than that value. if ((int)intWidth > mantissaWidth) { // Conversion would lose accuracy. Check if loss can impact comparison. - int exponent = ilogb(rhs); - if (exponent == APFloat::IEK_Inf) { - int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics())); - if (maxExponent < (int)valueBits) { + if (int exponent = ilogb(rhs); exponent == APFloat::IEK_Inf) { + if (int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics())); + maxExponent < (int)valueBits) { // Conversion could create infinity. return failure(); } @@ -1925,8 +1923,8 @@ return condition; if (auto cmp = dyn_cast_or_null(condition.getDefiningOp())) { - auto pred = cmp.getPredicate(); - if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { + if (auto pred = cmp.getPredicate(); + pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { auto cmpLhs = cmp.getLhs(); auto cmpRhs = cmp.getRhs();