diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -995,6 +995,55 @@
   return nullptr;
 }
 
+static bool MatchIntSquared(Value *V, Value *&A) {
+  return match(V, m_Mul(m_Value(A), m_Deferred(A)));
+}
+
+static bool MatchInt2ABPlusASquared(Value *V, Value *A, Value *&B) {
+  return match(V,
+               m_Mul(m_Add(m_Shl(m_Specific(A), m_SpecificInt(1)), m_Value(B)),
+                     m_Deferred(B)));
+}
+
+static bool MatchInt2AB(Value *V, Value *&A, Value *&B) {
+  return match(V, m_Shl(m_Mul(m_Value(A), m_Value(B)), m_SpecificInt(1))) ||
+         match(V, m_Mul(m_Shl(m_Value(A), m_SpecificInt(1)), m_Value(B)));
+}
+
+static bool MatchIntASquaredPlusBSquared(Value *V, Value *A, Value *B) {
+  return match(V, m_Add(m_Mul(m_Specific(A), m_Specific(A)),
+                        m_Mul(m_Specific(B), m_Specific(B)))) ||
+         match(V, m_Add(m_Mul(m_Specific(B), m_Specific(B)),
+                        m_Mul(m_Specific(A), m_Specific(A))));
+}
+
+Instruction *InstCombinerImpl::foldSquareSumInts(BinaryOperator &I) {
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  Value *A, *B;
+
+  // (a * a) + (((a << 1) + b) * b)
+  bool Matches =
+      (MatchIntSquared(LHS, A) && MatchInt2ABPlusASquared(RHS, A, B)) ||
+      (MatchIntSquared(RHS, A) && MatchInt2ABPlusASquared(LHS, A, B));
+
+  // ((a * b) << 1) or ((a << 1) * b)
+  // +
+  // (a * a + b * b) or (b * b + a * a)
+  if (!Matches) {
+    Matches =
+        (MatchInt2AB(LHS, A, B) && MatchIntASquaredPlusBSquared(RHS, A, B)) ||
+        (MatchInt2AB(RHS, A, B) && MatchIntASquaredPlusBSquared(LHS, A, B));
+  }
+
+  // If one of them matches: -> (a + b)^2
+  if (Matches) {
+    Value *AB = Builder.CreateAdd(A, B);
+    return BinaryOperator::CreateMul(AB, AB);
+  }
+
+  return nullptr;
+}
+
 // Matches multiplication expression Op * C where C is a constant. Returns the
 // constant value in C and the other operand in Op. Returns true if such a
 // match is found.
@@ -1615,6 +1664,9 @@
         I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
                                    {Builder.CreateOr(A, B)}));
 
+  if (Instruction *Res = foldSquareSumInts(I))
+    return Res;
+
   if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
     return Res;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -535,6 +535,8 @@
 
   Instruction *foldAddWithConstant(BinaryOperator &Add);
 
+  Instruction *foldSquareSumInts(BinaryOperator &I);
+
   /// Try to rotate an operation below a PHI node, using PHI nodes for
   /// its operands.
   Instruction *foldPHIArgOpIntoPHI(PHINode &PN);
diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -3096,4 +3096,166 @@
   ret i32 %add
 }
 
+define i32 @add_reduce_sqr_sum_nsw(i32 %0, i32 %1) {
+; CHECK-LABEL: @add_reduce_sqr_sum_nsw(
+; CHECK-NEXT:    [[TMP3:%.*]] = add i32 [[TMP0:%.*]], [[TMP1:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i32 [[TMP3]], [[TMP3]]
+; CHECK-NEXT:    ret i32 [[TMP4]]
+;
+  %3 = mul nsw i32 %0, %0
+  %4 = shl i32 %0, 1
+  %5 = add i32 %4, %1
+  %6 = mul i32 %5, %1
+  %7 = add i32 %6, %3
+  ret i32 %7
+}
+
+define i32 @add_reduce_sqr_sum_u(i32 %0, i32 %1) {
+; CHECK-LABEL: @add_reduce_sqr_sum_u(
+; CHECK-NEXT:    [[TMP3:%.*]] = add i32 [[TMP0:%.*]], [[TMP1:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i32 [[TMP3]], [[TMP3]]
+; CHECK-NEXT:    ret i32 [[TMP4]]
+;
+  %3 = mul i32 %0, %0
+  %4 = shl i32 %0, 1
+  %5 = add i32 %4, %1
+  %6 = mul i32 %5, %1
+  %7 = add i32 %6, %3
+  ret i32 %7
+}
+
+define i32 @add_reduce_sqr_sum_nuw(i32 %0, i32 %1) {
+; CHECK-LABEL: @add_reduce_sqr_sum_nuw(
+; CHECK-NEXT:    [[TMP3:%.*]] = add i32 [[TMP0:%.*]], [[TMP1:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i32 [[TMP3]], [[TMP3]]
+; CHECK-NEXT:    ret i32 [[TMP4]]
+;
+  %3 = mul nuw i32 %0, %0
+  %4 = mul i32 %0, 2
+  %5 = add i32 %4, %1
+  %6 = mul nuw i32 %5, %1
+  %7 = add i32 %6, %3
+  ret i32 %7
+}
+
+define i32 @add_reduce_sqr_sum_order2(i32 %a, i32 %b) {
+; CHECK-LABEL: @add_reduce_sqr_sum_order2(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[AB2]]
+;
+  %a_sq = mul nsw i32 %a, %a
+  %twoa = mul i32 %a, 2
+  %twoab = mul i32 %twoa, %b
+  %b_sq = mul i32 %b, %b
+  %twoab_b2 = add i32 %twoab, %b_sq
+  %ab2 = add i32 %a_sq, %twoab_b2
+  ret i32 %ab2
+}
+
+define i32 @add_reduce_sqr_sum_order2_flipped(i32 %a, i32 %b) {
+; CHECK-LABEL: @add_reduce_sqr_sum_order2_flipped(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[AB2]]
+;
+  %a_sq = mul nsw i32 %a, %a
+  %twoa = mul i32 %a, 2
+  %twoab = mul i32 %twoa, %b
+  %b_sq = mul i32 %b, %b
+  %twoab_b2 = add i32 %twoab, %b_sq
+  %ab2 = add i32 %twoab_b2, %a_sq
+  ret i32 %ab2
+}
+
+define i32 @add_reduce_sqr_sum_order3(i32 %a, i32 %b) {
+; CHECK-LABEL: @add_reduce_sqr_sum_order3(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[AB2]]
+;
+  %a_sq = mul nsw i32 %a, %a
+  %twoa = mul i32 %a, 2
+  %twoab = mul i32 %twoa, %b
+  %b_sq = mul i32 %b, %b
+  %a2_b2 = add i32 %a_sq, %b_sq
+  %ab2 = add i32 %twoab, %a2_b2
+  ret i32 %ab2
+}
+
+define i32 @add_reduce_sqr_sum_order3_flipped(i32 %a, i32 %b) {
+; CHECK-LABEL: @add_reduce_sqr_sum_order3_flipped(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[AB2]]
+;
+  %a_sq = mul nsw i32 %a, %a
+  %twoa = mul i32 %a, 2
+  %twoab = mul i32 %twoa, %b
+  %b_sq = mul i32 %b, %b
+  %a2_b2 = add i32 %a_sq, %b_sq
+  %ab2 = add i32 %a2_b2, %twoab
+  ret i32 %ab2
+}
+
+define i32 @add_reduce_sqr_sum_order4(i32 %a, i32 %b) {
+; CHECK-LABEL: @add_reduce_sqr_sum_order4(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[AB2]]
+;
+  %a_sq = mul nsw i32 %a, %a
+  %ab = mul i32 %a, %b
+  %twoab = mul i32 %ab, 2
+  %b_sq = mul i32 %b, %b
+  %a2_b2 = add i32 %a_sq, %b_sq
+  %ab2 = add i32 %twoab, %a2_b2
+  ret i32 %ab2
+}
+
+define i32 @add_reduce_sqr_sum_order4_flipped(i32 %a, i32 %b) {
+; CHECK-LABEL: @add_reduce_sqr_sum_order4_flipped(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[AB2]]
+;
+  %a_sq = mul nsw i32 %a, %a
+  %ab = mul i32 %a, %b
+  %twoab = mul i32 %ab, 2
+  %b_sq = mul i32 %b, %b
+  %a2_b2 = add i32 %a_sq, %b_sq
+  %ab2 = add i32 %a2_b2, %twoab
+  ret i32 %ab2
+}
+
+define i32 @add_reduce_sqr_sum_order5(i32 %a, i32 %b) {
+; CHECK-LABEL: @add_reduce_sqr_sum_order5(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[B:%.*]], [[A:%.*]]
+; CHECK-NEXT:    [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[AB2]]
+;
+  %a_sq = mul nsw i32 %a, %a
+  %twob = mul i32 %b, 2
+  %twoab = mul i32 %twob, %a
+  %b_sq = mul i32 %b, %b
+  %a2_b2 = add i32 %a_sq, %b_sq
+  %ab2 = add i32 %twoab, %a2_b2
+  ret i32 %ab2
+}
+
+define i32 @add_reduce_sqr_sum_order5_flipped(i32 %a, i32 %b) {
+; CHECK-LABEL: @add_reduce_sqr_sum_order5_flipped(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[B:%.*]], [[A:%.*]]
+; CHECK-NEXT:    [[AB2:%.*]] = mul i32 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[AB2]]
+;
+  %a_sq = mul nsw i32 %a, %a
+  %twob = mul i32 %b, 2
+  %twoab = mul i32 %twob, %a
+  %b_sq = mul i32 %b, %b
+  %a2_b2 = add i32 %a_sq, %b_sq
+  %ab2 = add i32 %a2_b2, %twoab
+  ret i32 %ab2
+}
+
 declare void @llvm.assume(i1)
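
Note on correctness and usage: all of the rewrites above rely on the identity a^2 + 2ab + b^2 = (a + b)^2, which holds for n-bit integers modulo 2^n, so the fold is valid regardless of nsw/nuw flags on the matched instructions, and the replacement add/mul are created without flags (as the CHECK lines show). As a quick way to exercise the fold once the patch is applied, the standalone reproducer below (a minimal sketch mirroring @add_reduce_sqr_sum_nsw; the function name is illustrative and not part of the patch) can be run through `opt -passes=instcombine -S`:

define i32 @square_of_sum(i32 %a, i32 %b) {
  %a_sq  = mul i32 %a, %a        ; a*a
  %two_a = shl i32 %a, 1         ; 2*a
  %t     = add i32 %two_a, %b    ; 2*a + b
  %cross = mul i32 %t, %b        ; 2*a*b + b*b
  %sum   = add i32 %cross, %a_sq ; a*a + 2*a*b + b*b
  ret i32 %sum
}

With the patch, instcombine should reduce the body to a single add of %a and %b followed by one mul squaring the result, matching the CHECK lines of the new tests.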