diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -995,37 +995,74 @@ return nullptr; } -// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2 -Instruction *InstCombinerImpl::foldSquareSumInts(BinaryOperator &I) { - Value *A, *B; +template +using conditional_value = typename std::conditional< + B, std::integral_constant, + std::integral_constant>::type; - // (a * a) + (((a << 1) + b) * b) - bool Matches = match( - &I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_Deferred(A))), - m_OneUse(m_Mul(m_c_Add(m_Shl(m_Deferred(A), m_SpecificInt(1)), - m_Value(B)), - m_Deferred(B))))); +// match variations of a^2 + 2*a*b + b^2 +// +// to reuse the code between the FP and Int versions, the instruction OpCodes +// and constant types have been turned into template parameters. +// +// Mul2Rhs: The constant to perform the multiplicative equivalent of X*2 with; +// should be `m_SpecificFP(2.0)` for FP and `m_SpecificInt(1)` for Int +// (we're matching `X<<1` instead of `X*2` for Int) +template ::value, + unsigned AddOp = + conditional_value::value, + unsigned Mul2Op = + conditional_value::value> +static bool matchesSquareSum(BinaryOperator &I, Mul2Rhs M2Rhs, Value *&A, + Value *&B) { + // (a * a) + (((a * 2) + b) * b) + if (match(&I, m_c_BinOp( + AddOp, m_OneUse(m_BinOp(MupOp, m_Value(A), m_Deferred(A))), + m_OneUse(m_BinOp( + MupOp, + m_c_BinOp(AddOp, m_BinOp(Mul2Op, m_Deferred(A), M2Rhs), + m_Value(B)), + m_Deferred(B)))))) + return true; - // ((a * b) << 1) or ((a << 1) * b) + // ((a * b) * 2) or ((a * 2) * b) // + // (a * a + b * b) or (b * b + a * a) - if (!Matches) { - Matches = match( - &I, - m_c_Add(m_CombineOr(m_OneUse(m_Shl(m_Mul(m_Value(A), m_Value(B)), - m_SpecificInt(1))), - m_OneUse(m_Mul(m_Shl(m_Value(A), m_SpecificInt(1)), - m_Value(B)))), - m_OneUse(m_c_Add(m_Mul(m_Deferred(A), m_Deferred(A)), - m_Mul(m_Deferred(B), m_Deferred(B)))))); - } - - // if one of them matches: -> (a + b)^2 - if (Matches) { + return match( + &I, + m_c_BinOp(AddOp, + m_CombineOr( + m_OneUse(m_BinOp( + Mul2Op, m_BinOp(MupOp, m_Value(A), m_Value(B)), M2Rhs)), + m_OneUse(m_BinOp(MupOp, m_BinOp(Mul2Op, m_Value(A), M2Rhs), + m_Value(B)))), + m_OneUse(m_c_BinOp( + AddOp, m_BinOp(MupOp, m_Deferred(A), m_Deferred(A)), + m_BinOp(MupOp, m_Deferred(B), m_Deferred(B)))))); +} + +// Fold integer variations of a^2 + 2*a*b + b^2 -> (a + b)^2 +Instruction *InstCombinerImpl::foldSquareSumInt(BinaryOperator &I) { + Value *A, *B; + if (matchesSquareSum(I, m_SpecificInt(1), A, B)) { Value *AB = Builder.CreateAdd(A, B); return BinaryOperator::CreateMul(AB, AB); } + return nullptr; +} +// Fold floating point variations of a^2 + 2*a*b + b^2 -> (a + b)^2 +// Requires `nsz` and `reassoc`. +Instruction *InstCombinerImpl::foldSquareSumFP(BinaryOperator &I) { + assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && "Assumption mismatch"); + Value *A, *B; + if (matchesSquareSum(I, m_SpecificFP(2.0), A, B)) { + Value *AB = Builder.CreateFAddFMF(A, B, &I); + return BinaryOperator::CreateFMulFMF(AB, AB, &I); + } return nullptr; } @@ -1667,7 +1704,7 @@ I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()}, {Builder.CreateOr(A, B)})); - if (Instruction *Res = foldSquareSumInts(I)) + if (Instruction *Res = foldSquareSumInt(I)) return Res; if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) @@ -1849,6 +1886,9 @@ if (Instruction *F = factorizeFAddFSub(I, Builder)) return F; + if (Instruction *F = foldSquareSumFP(I)) + return F; + // Try to fold fadd into start value of reduction intrinsic. if (match(&I, m_c_FAdd(m_OneUse(m_Intrinsic( m_AnyZeroFP(), m_Value(X))), diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -542,7 +542,8 @@ Instruction *foldAddWithConstant(BinaryOperator &Add); - Instruction *foldSquareSumInts(BinaryOperator &I); + Instruction *foldSquareSumInt(BinaryOperator &I); + Instruction *foldSquareSumFP(BinaryOperator &I); /// Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. diff --git a/llvm/test/Transforms/InstCombine/fadd.ll b/llvm/test/Transforms/InstCombine/fadd.ll --- a/llvm/test/Transforms/InstCombine/fadd.ll +++ b/llvm/test/Transforms/InstCombine/fadd.ll @@ -620,11 +620,8 @@ define float @fadd_reduce_sqr_sum_varA(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varA( -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A:%.*]], [[A]] -; CHECK-NEXT: [[TWO_A:%.*]] = fmul float [[A]], 2.000000e+00 -; CHECK-NEXT: [[TWO_A_PLUS_B:%.*]] = fadd float [[TWO_A]], [[B:%.*]] -; CHECK-NEXT: [[MUL:%.*]] = fmul float [[TWO_A_PLUS_B]], [[B]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[MUL]], [[A_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_sq = fmul float %a, %a @@ -637,11 +634,8 @@ define float @fadd_reduce_sqr_sum_varA_order2(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varA_order2( -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A:%.*]], [[A]] -; CHECK-NEXT: [[TWO_A:%.*]] = fmul float [[A]], 2.000000e+00 -; CHECK-NEXT: [[TWO_A_PLUS_B:%.*]] = fadd float [[TWO_A]], [[B:%.*]] -; CHECK-NEXT: [[MUL:%.*]] = fmul float [[TWO_A_PLUS_B]], [[B]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_SQ]], [[MUL]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_sq = fmul float %a, %a @@ -654,11 +648,8 @@ define float @fadd_reduce_sqr_sum_varA_order3(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varA_order3( -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A:%.*]], [[A]] -; CHECK-NEXT: [[TWO_A:%.*]] = fmul float [[A]], 2.000000e+00 -; CHECK-NEXT: [[TWO_A_PLUS_B:%.*]] = fadd float [[TWO_A]], [[B:%.*]] -; CHECK-NEXT: [[MUL:%.*]] = fmul float [[TWO_A_PLUS_B]], [[B]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[MUL]], [[A_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_sq = fmul float %a, %a @@ -671,11 +662,8 @@ define float @fadd_reduce_sqr_sum_varA_order4(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varA_order4( -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A:%.*]], [[A]] -; CHECK-NEXT: [[TWO_A:%.*]] = fmul float [[A]], 2.000000e+00 -; CHECK-NEXT: [[TWO_A_PLUS_B:%.*]] = fadd float [[TWO_A]], [[B:%.*]] -; CHECK-NEXT: [[MUL:%.*]] = fmul float [[TWO_A_PLUS_B]], [[B]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[MUL]], [[A_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_sq = fmul float %a, %a @@ -688,11 +676,8 @@ define float @fadd_reduce_sqr_sum_varA_order5(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varA_order5( -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A:%.*]], [[A]] -; CHECK-NEXT: [[TWO_A:%.*]] = fmul float [[A]], 2.000000e+00 -; CHECK-NEXT: [[TWO_A_PLUS_B:%.*]] = fadd float [[TWO_A]], [[B:%.*]] -; CHECK-NEXT: [[MUL:%.*]] = fmul float [[TWO_A_PLUS_B]], [[B]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[MUL]], [[A_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_sq = fmul float %a, %a @@ -705,12 +690,8 @@ define float @fadd_reduce_sqr_sum_varB(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB( -; CHECK-NEXT: [[A_B:%.*]] = fmul float [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_B]], 2.000000e+00 -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_b = fmul float %a, %b @@ -724,12 +705,8 @@ define float @fadd_reduce_sqr_sum_varB_order1(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB_order1( -; CHECK-NEXT: [[A_B:%.*]] = fmul float [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_B]], 2.000000e+00 -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_SQ_B_SQ]], [[A_B_2]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_b = fmul float %a, %b @@ -743,12 +720,8 @@ define float @fadd_reduce_sqr_sum_varB_order2(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB_order2( -; CHECK-NEXT: [[A_B:%.*]] = fmul float [[A:%.*]], [[B:%.*]] -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_B]], 2.000000e+00 -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[B_SQ]], [[A_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_b = fmul float %a, %b @@ -762,12 +735,8 @@ define float @fadd_reduce_sqr_sum_varB_order3(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB_order3( -; CHECK-NEXT: [[A_B:%.*]] = fmul float [[B:%.*]], [[A:%.*]] -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_B]], 2.000000e+00 -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[B:%.*]], [[A:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_b = fmul float %b, %a @@ -781,12 +750,8 @@ define float @fadd_reduce_sqr_sum_varB2(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB2( -; CHECK-NEXT: [[A_2:%.*]] = fmul float [[A:%.*]], 2.000000e+00 -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_2]], [[B:%.*]] -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_2 = fmul float %a, 2.0 @@ -800,12 +765,8 @@ define float @fadd_reduce_sqr_sum_varB2_order1(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB2_order1( -; CHECK-NEXT: [[A_2:%.*]] = fmul float [[A:%.*]], 2.000000e+00 -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_2]], [[B:%.*]] -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_SQ_B_SQ]], [[A_B_2]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_2 = fmul float %a, 2.0 @@ -819,12 +780,8 @@ define float @fadd_reduce_sqr_sum_varB2_order2(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB2_order2( -; CHECK-NEXT: [[A_2:%.*]] = fmul float [[A:%.*]], 2.000000e+00 -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_2]], [[B:%.*]] -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_2 = fmul float %a, 2.0 @@ -838,12 +795,8 @@ define float @fadd_reduce_sqr_sum_varB2_order3(float %a, float %b) { ; CHECK-LABEL: @fadd_reduce_sqr_sum_varB2_order3( -; CHECK-NEXT: [[A_2:%.*]] = fmul float [[A:%.*]], 2.000000e+00 -; CHECK-NEXT: [[A_B_2:%.*]] = fmul float [[A_2]], [[B:%.*]] -; CHECK-NEXT: [[A_SQ:%.*]] = fmul float [[A]], [[A]] -; CHECK-NEXT: [[B_SQ:%.*]] = fmul float [[B]], [[B]] -; CHECK-NEXT: [[A_SQ_B_SQ:%.*]] = fadd float [[A_SQ]], [[B_SQ]] -; CHECK-NEXT: [[ADD:%.*]] = fadd reassoc nsz float [[A_B_2]], [[A_SQ_B_SQ]] +; CHECK-NEXT: [[TMP1:%.*]] = fadd reassoc nsz float [[A:%.*]], [[B:%.*]] +; CHECK-NEXT: [[ADD:%.*]] = fmul reassoc nsz float [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret float [[ADD]] ; %a_2 = fmul float 2.0, %a