diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1411,6 +1411,40 @@
     }
   }
 
+  // Split a non-wrapping unsigned average into per-operand halves:
+  //   lshr (add nuw (add nuw A, B), 1), 1
+  //     --> (A >> 1) + (B >> 1) + ((A | B) & 1)
+  //   lshr (add nuw (add nuw A, B), 0), 1
+  //     --> (A >> 1) + (B >> 1) + ((A & B) & 1)
+  // Both adds must be 'nuw'. The halved form yields the exact average, but
+  // a wrapping sum drops its carry before the shift. For i8 A = B = 255:
+  // ((255 + 255 + 1) mod 256) >> 1 is 127, while the halved form gives 255.
+  // NOTE(review): an 'add X, 0' is presumably folded before this point, so
+  // the second form may be unreachable; it is kept for parity - confirm.
+  // NOTE(review): this emits more instructions than it removes - confirm
+  // the expanded form is the desired canonical one.
+  if (match(Op1, m_APInt(C)) && *C == 1) {
+    Value *A = nullptr, *B = nullptr;
+    const APInt *AddC = nullptr;
+    // One-use checks sit on the two adds, the instructions this fold makes
+    // dead; A and B are only read and may have other users. AddC is compared
+    // as an APInt: getZExtValue() asserts when a value needs over 64 bits.
+    if (match(Op0, m_OneUse(m_NUWAdd(
+                       m_OneUse(m_NUWAdd(m_Value(A), m_Value(B))),
+                       m_APInt(AddC)))) &&
+        AddC->ule(1)) {
+      Value *HalfA = Builder.CreateLShr(A, Op1);
+      Value *HalfB = Builder.CreateLShr(B, Op1);
+      // Bit 0 of (A & B) / (A | B) is the unit the two shifts discarded
+      // when rounding down / up respectively.
+      Value *LowBits =
+          *AddC == 0 ? Builder.CreateAnd(A, B) : Builder.CreateOr(A, B);
+      Value *Carry = Builder.CreateAnd(LowBits, ConstantInt::get(Ty, 1));
+      Value *HalfSum = Builder.CreateAdd(HalfA, HalfB);
+      return BinaryOperator::CreateAdd(HalfSum, Carry);
+    }
+  }
+
   // Transform (x << y) >> y to x & (-1 >> y)
   if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
     Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);