diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1615,6 +1615,22 @@
         I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
                                    {Builder.CreateOr(A, B)}));
 
+  // (a * a) + ((a << 1) + b) * b -> (a + b)^2
+  // which is:
+  // (a * a) + (2 * a * b) + (b * b) -> (a + b)^2
+  //
+  // add and mul are commutative, so accept the operands in either order, and
+  // require both multiplies to be single-use so that the fold never leaves
+  // more instructions behind than it started with.
+  if (match(&I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_Deferred(A))),
+                        m_OneUse(m_c_Mul(
+                            m_c_Add(m_Shl(m_Deferred(A), m_SpecificInt(1)),
+                                    m_Value(B)),
+                            m_Deferred(B)))))) {
+    Value *AB = Builder.CreateAdd(A, B);
+    return BinaryOperator::CreateMul(AB, AB);
+  }
+
   if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
     return Res;
 