diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1029,6 +1029,53 @@
   return nullptr;
 }
 
+// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2 for floating point.
+//
+// \p I is the root fadd and must carry both `reassoc` and `nsz`; this is a
+// precondition that is only asserted here, so callers have to check the
+// flags first. Returns the new fmul on success and nullptr when no variation
+// matches.
+Instruction *InstCombinerImpl::foldSquareSumFloat(BinaryOperator &I) {
+  assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && "Assumption mismatch");
+
+  Value *A, *B;
+
+  // Variation 1: (a * a) + (((a * 2) + b) * b)
+  // The outer multiply is matched commutatively so that the operand order
+  // b * ((a * 2) + b) is recognized as well.
+  bool Matches = match(
+      &I, m_c_FAdd(m_OneUse(m_FMul(m_Value(A), m_Deferred(A))),
+                   m_OneUse(m_c_FMul(
+                       m_c_FAdd(m_FMul(m_Deferred(A), m_SpecificFP(2.0)),
+                                m_Value(B)),
+                       m_Deferred(B)))));
+
+  // Variation 2:
+  //   ((a * b) * 2) or ((a * 2) * b) or (b * (a * 2))
+  //   +
+  //   (a * a + b * b) or (b * b + a * a)
+  if (!Matches) {
+    Matches = match(
+        &I,
+        m_c_FAdd(m_CombineOr(
+                     m_OneUse(m_FMul(m_FMul(m_Value(A), m_Value(B)),
+                                     m_SpecificFP(2.0))),
+                     m_OneUse(m_c_FMul(m_FMul(m_Value(A), m_SpecificFP(2.0)),
+                                       m_Value(B)))),
+                 m_OneUse(m_c_FAdd(m_FMul(m_Deferred(A), m_Deferred(A)),
+                                   m_FMul(m_Deferred(B), m_Deferred(B))))));
+  }
+
+  if (!Matches)
+    return nullptr;
+
+  // One of the variations matched: emit (a + b) * (a + b). Passing &I as the
+  // FMF source gives both new instructions the fast-math flags of the
+  // original fadd.
+  Value *AB = Builder.CreateFAddFMF(A, B, &I);
+  return BinaryOperator::CreateFMulFMF(AB, AB, &I);
+}
+
 // Matches multiplication expression Op * C where C is a constant. Returns the
 // constant value in C and the other operand in Op. Returns true if such a
 // match is found.
@@ -1831,6 +1878,9 @@
     if (Instruction *F = factorizeFAddFSub(I, Builder))
       return F;
 
+    if (Instruction *F = foldSquareSumFloat(I))
+      return F;
+
     // Try to fold fadd into start value of reduction intrinsic.
if (match(&I, m_c_FAdd(m_OneUse(m_Intrinsic( m_AnyZeroFP(), m_Value(X))), diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -543,6 +543,7 @@ Instruction *foldAddWithConstant(BinaryOperator &Add); Instruction *foldSquareSumInts(BinaryOperator &I); + Instruction *foldSquareSumFloat(BinaryOperator &I); /// Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. diff --git a/llvm/test/Transforms/InstCombine/fadd.ll b/llvm/test/Transforms/InstCombine/fadd.ll --- a/llvm/test/Transforms/InstCombine/fadd.ll +++ b/llvm/test/Transforms/InstCombine/fadd.ll @@ -620,11 +620,8 @@ define float @fadd_reduce_sqr_sum_varA(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varA( -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A:%.*]], [[A]] -; CHECK-NEXT: [[TWO_A:%.*]] = fmul float [[A]], 2.000000e+00 -; CHECK-NEXT: [[TWO_A_PLUS_B:%.*]] = fadd float [[TWO_A]], [[B:%.*]] -; CHECK-NEXT: [[MUL:%.*]] = fmul float [[TWO_A_PLUS_B]], [[B]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[MUL]], [[A_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_sq = fmul float %a, %a @@ -637,11 +634,8 @@ define float @fadd_reduce_sqr_sum_varA_order2(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varA_order2( -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A:%.*]], [[A]] -; CHECK-NEXT: [[TWO_A:%.*]] = fmul float [[A]], 2.000000e+00 -; CHECK-NEXT: [[TWO_A_PLUS_B:%.*]] = fadd float [[TWO_A]], [[B:%.*]] -; CHECK-NEXT: [[MUL:%.*]] = fmul float [[TWO_A_PLUS_B]], [[B]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_SQ]], [[MUL]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul 
reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_sq = fmul float %a, %a @@ -654,11 +648,8 @@ define float @fadd_reduce_sqr_sum_varA_order3(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varA_order3( -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A:%.*]], [[A]] -; CHECK-NEXT: [[TWO_A:%.*]] = fmul float [[A]], 2.000000e+00 -; CHECK-NEXT: [[TWO_A_PLUS_B:%.*]] = fadd float [[TWO_A]], [[B:%.*]] -; CHECK-NEXT: [[MUL:%.*]] = fmul float [[TWO_A_PLUS_B]], [[B]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[MUL]], [[A_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_sq = fmul float %a, %a @@ -671,11 +662,8 @@ define float @fadd_reduce_sqr_sum_varA_order4(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varA_order4( -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A:%.*]], [[A]] -; CHECK-NEXT: [[TWO_A:%.*]] = fmul float [[A]], 2.000000e+00 -; CHECK-NEXT: [[TWO_A_PLUS_B:%.*]] = fadd float [[TWO_A]], [[B:%.*]] -; CHECK-NEXT: [[MUL:%.*]] = fmul float [[TWO_A_PLUS_B]], [[B]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[MUL]], [[A_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_sq = fmul float %a, %a @@ -688,11 +676,8 @@ define float @fadd_reduce_sqr_sum_varA_order5(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varA_order5( -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A:%.*]], [[A]] -; CHECK-NEXT: [[TWO_A:%.*]] = fmul float [[A]], 2.000000e+00 -; CHECK-NEXT: [[TWO_A_PLUS_B:%.*]] = fadd float [[TWO_A]], [[B:%.*]] -; CHECK-NEXT: [[MUL:%.*]] = fmul float [[TWO_A_PLUS_B]], [[B]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[MUL]], [[A_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = 
fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_sq = fmul float %a, %a @@ -705,12 +690,8 @@ define float @fadd_reduce_sqr_sum_varB(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB( -; CHECK-NEXT: [[A_B:%.*]] = fmul float [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_B]], 2.000000e+00 -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_b = fmul float %a, %b @@ -724,12 +705,8 @@ define float @fadd_reduce_sqr_sum_varB_order1(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB_order1( -; CHECK-NEXT: [[A_B:%.*]] = fmul float [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_B]], 2.000000e+00 -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_SQ_B_SQ]], [[A_B_2]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_b = fmul float %a, %b @@ -743,12 +720,8 @@ define float @fadd_reduce_sqr_sum_varB_order2(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB_order2( -; CHECK-NEXT: [[A_B:%.*]] = fmul float [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_B]], 2.000000e+00 -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[B_SQ]], [[A_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc 
nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_b = fmul float %a, %b @@ -762,12 +735,8 @@ define float @fadd_reduce_sqr_sum_varB_order3(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB_order3( -; CHECK-NEXT: [[A_B:%.*]] = fmul float [[B:%.*]], [[A:%.*]] -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_B]], 2.000000e+00 -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[B:%.*]], [[A:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_b = fmul float %b, %a @@ -781,12 +750,8 @@ define float @fadd_reduce_sqr_sum_varB2(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB2( -; CHECK-NEXT: [[A_2:%.*]] = fmul float [[A:%.*]], 2.000000e+00 -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_2]], [[B:%.*]] -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_2 = fmul float %a, 2.0 @@ -800,12 +765,8 @@ define float @fadd_reduce_sqr_sum_varB2_order1(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB2_order1( -; CHECK-NEXT: [[A_2:%.*]] = fmul float [[A:%.*]], 2.000000e+00 -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_2]], [[B:%.*]] -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: 
[[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_SQ_B_SQ]], [[A_B_2]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_2 = fmul float %a, 2.0 @@ -819,12 +780,8 @@ define float @fadd_reduce_sqr_sum_varB2_order2(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB2_order2( -; CHECK-NEXT: [[A_2:%.*]] = fmul float [[A:%.*]], 2.000000e+00 -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_2]], [[B:%.*]] -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_2 = fmul float %a, 2.0 @@ -838,12 +795,8 @@ define float @fadd_reduce_sqr_sum_varB2_order3(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB2_order3( -; CHECK-NEXT: [[A_2:%.*]] = fmul float [[A:%.*]], 2.000000e+00 -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_2]], [[B:%.*]] -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_2 = fmul float 2.0, %a