diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1615,6 +1615,21 @@
         I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
                                    {Builder.CreateOr(A, B)}));
 
+  // (A * A) + (((A << 1) + B) * B) --> (A + B) * (A + B)
+  // i.e. A^2 + 2*A*B + B^2 == (A + B)^2. The identity holds in wrapping
+  // integer arithmetic, so no nsw/nuw flags are required on the inputs and
+  // none are placed on the new instructions. Both multiplies must have a
+  // single use so the fold never increases the instruction count. The add
+  // and the outer multiply are matched commutatively.
+  if (match(&I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_Deferred(A))),
+                        m_OneUse(m_c_Mul(
+                            m_c_Add(m_Shl(m_Deferred(A), m_SpecificInt(1)),
+                                    m_Value(B)),
+                            m_Deferred(B)))))) {
+    Value *AB = Builder.CreateAdd(A, B);
+    return BinaryOperator::CreateMul(AB, AB);
+  }
+
   if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
     return Res;
 
@@ -1846,6 +1861,22 @@
     return Result;
   }
 
+  // (X * X) + (((X * 2.0) + Y) * Y) --> (X + Y) * (X + Y)
+  // i.e. X^2 + 2*X*Y + Y^2 == (X + Y)^2. In floating point this rewrite
+  // reassociates and can change the sign of a zero result, so it is only
+  // legal when the fadd carries both 'reassoc' and 'nsz'. The new fadd and
+  // fmul inherit the fast-math flags of I.
+  if (I.hasAllowReassoc() && I.hasNoSignedZeros() &&
+      match(&I, m_c_FAdd(m_OneUse(m_FMul(m_Value(X), m_Deferred(X))),
+                         m_OneUse(m_c_FMul(
+                             m_c_FAdd(m_FMul(m_Deferred(X),
+                                             m_SpecificFP(2.0)),
+                                      m_Value(Y)),
+                             m_Deferred(Y)))))) {
+    Value *XY = Builder.CreateFAddFMF(X, Y, &I);
+    return BinaryOperator::CreateFMulFMF(XY, XY, &I);
+  }
+
   return nullptr;
 }
 