This is an archive of the discontinued LLVM Phabricator instance.

[InstCombine] Contracting x^2 + 2*x*y + y^2 to (x + y)^2 (float)
ClosedPublic

Authored by rainerzufalldererste on Aug 16 2023, 6:52 AM.

Event Timeline

Herald added a project: Restricted Project. Aug 16 2023, 6:52 AM
Herald added a subscriber: hiraditya.
rainerzufalldererste requested review of this revision. Aug 16 2023, 6:52 AM
goldstein.w.n added inline comments.
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
1088

This match code is basically identical to foldSquareSumInts. Apart from FMul vs. Mul, the only difference is that you match FMul(A, 2) for floats and m_Shl(A, 1) for ints.
Can you turn the match code into a helper that takes the fmul/shl "times two" matcher (or just a lambda wrapping it), so it can be shared by foldSquareSumFloat and foldSquareSumInts?

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
1088

Does that imply that m_c_FAdd can simply be replaced with m_c_Add and will continue to match properly for floating point values as well?
I presume that would entail matching part of the pattern first and deferring the actual check for the times-two match, since BinaryOp_match<LHS, RHS, Opcode> has different opcodes for FMul and Shl. That sounds like a huge mess to me; is there a cleaner way to do it?

Something like this sadly doesn't compile: the two lambdas have distinct types and return different matcher types, so the conditional operator cannot find a common type:

const auto FpMul2Matcher = [](auto &value) {
  return m_FMul(value, m_SpecificFP(2.0));
};
const auto IntMul2Matcher = [](auto &value) {
  return m_Shl(value, m_SpecificInt(1));
};
const auto Mul2Matcher = FP ? FpMul2Matcher : IntMul2Matcher;
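A standalone sketch of why the conditional fails and one way around it, using toy `int` lambdas instead of LLVM matchers (all names here are hypothetical): every lambda expression has its own unique closure type, so `FP ? A : B` has no common type, but passing each lambda to a function template in its own branch compiles, because each branch instantiates the helper separately.

```cpp
#include <cassert>
#include <type_traits>

// Two lambdas with the same signature still have distinct closure types.
const auto Mul2ViaMul = [](int V) { return V * 2; };
const auto Mul2ViaShl = [](int V) { return V << 1; };
static_assert(!std::is_same_v<decltype(Mul2ViaMul), decltype(Mul2ViaShl)>,
              "each lambda expression has a unique closure type");

// Hypothetical helper, templated on the callable, so any lambda type works.
template <typename Mul2Fn> int applyMul2(const Mul2Fn &Mul2, int V) {
  return Mul2(V);
}

// Branch first, then call the helper: both branches yield int, so the
// conditional operator has a common type here.
int dispatchMul2(bool FP, int V) {
  return FP ? applyMul2(Mul2ViaMul, V) : applyMul2(Mul2ViaShl, V);
}
```

This is the same shape foldSquareSum takes further down in the thread, where each branch calls MatchesSquareSum with its own set of callables.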
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
1088

Even something like this wouldn't work:

template <typename TMul2, typename TCAdd, typename TMul>
static bool MatchesSquareSum(BinaryOperator &I, Value *&A, Value *&B,
                             const TMul2 &Mul2, const TCAdd &CAdd,
                             const TMul &Mul) {

  // (a * a) + (((a * 2) + b) * b)
  bool Matches =
      match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
                     m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
                                  m_Deferred(B)))));

  // ((a * b) * 2)  or ((a * 2) * b)
  // +
  // (a * a + b * b) or (b * b + a * a)
  if (!Matches) {
    Matches =
        match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
                                   m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
                       m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
                                     Mul(m_Deferred(B), m_Deferred(B))))));
  }

  return Matches;
}

I agree that duplicate code is messy, but since the matchers take their opcodes as template parameters, I don't see a way to do this nicely without template specialization, and with template specialization it becomes even more of a beast.
Am I missing some obvious way built into LLVM/InstCombine to do this nicely?

goldstein.w.n added inline comments. Aug 16 2023, 5:07 PM
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
1088

Why doesn't that code work?

rainerzufalldererste added inline comments.
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
1088

Assuming TMul2 etc. are lambdas, the return type couldn't be consistent: for both m_FMul and m_Shl it would be BinaryOp_match<LHS, RHS, Opcode>, with the same Opcode for each invocation but different LHS and RHS. One could make this work with macros, but I don't know the LLVM stance on macros, or with template specialization, where a specialized struct provides three functions (Mul2, Mul, CAdd) that simply map to the correct matchers for FAdd/Add etc.
However, I honestly think that the current implementation is the cleanest way to do it. I'm also not a big fan of code duplication, but the discussed alternatives seem a lot messier to me.

rainerzufalldererste marked an inline comment as done. Aug 18 2023, 5:02 PM
rainerzufalldererste added inline comments.
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
1088

Have you been able to come up with any better ideas? Maybe it's not _that_ terrible to go down the template specialization route, as many of the integer optimizations may have similar counterparts in FP with nsz and reassoc. I'm not sure how many of them are already handled twice, but one could perhaps simplify this by providing template-specialized m_XAdd<IsFP>(LHS, RHS) etc. However, I'm not sure I'm the right person to pass judgement on something that large, as I'm still very new to both LLVM and InstCombine.

goldstein.w.n added inline comments. Aug 18 2023, 10:51 PM
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
1088

> Assuming TMul2 etc. are lambdas, the return type couldn't be consistent: for both m_FMul and m_Shl it would be BinaryOp_match<LHS, RHS, Opcode>, with the same Opcode for each invocation but different LHS and RHS. One could make this work with macros, but I don't know the LLVM stance on macros, or with template specialization, where a specialized struct provides three functions (Mul2, Mul, CAdd) that simply map to the correct matchers for FAdd/Add etc.

For TMul2, don't you only need a single Value?
Instead of passing a BinaryOperator, you could just pass a lambda, i.e.:

auto FPMul2 = [](Value *V, Value *&A) {
  return match(V, m_FMul(m_Value(A), m_SpecificFP(2.0)));
};

...
auto IntMul2 = [](Value *V, Value *&A) {
  return match(V, m_Shl(m_Value(A), m_SpecificInt(1)));
};

I don't see why the same isn't true for mul/add (although with two values then).
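A toy model of the callback shape suggested above, written against a hypothetical mini IR (`Node`, `Op`) rather than LLVM's `Value`/`match`: the helper matches the outer multiply itself and delegates only the "times two" check to a caller-supplied lambda that binds A. It also shows the cost raised in the reply: whatever is not covered by a callback has to be matched piece by piece by hand.

```cpp
#include <cassert>

// Hypothetical toy IR standing in for llvm::Value / BinaryOperator.
enum class Op { Leaf, Mul, Shl, FMul };
struct Node {
  Op Opc;
  const Node *L = nullptr;
  const Node *R = nullptr;
  bool IsConst = false;
  double Const = 0;
};

static bool isConst(const Node *N, double C) {
  return N && N->IsConst && N->Const == C;
}

// Callbacks: recognise `V == A * 2` in int (shl A, 1) or FP (fmul A, 2.0)
// form and bind A on success.
const auto IntMul2 = [](const Node *V, const Node *&A) {
  if (!V || V->Opc != Op::Shl || !isConst(V->R, 1))
    return false;
  A = V->L;
  return true;
};
const auto FPMul2 = [](const Node *V, const Node *&A) {
  if (!V || V->Opc != Op::FMul || !isConst(V->R, 2.0))
    return false;
  A = V->L;
  return true;
};

// Matches `(A * 2) * B` or `B * (A * 2)`: the multiply opcode is a plain
// value, and the "times two" check is delegated to Mul2.
template <typename Mul2Fn>
bool matchTwoTimesAB(const Node *N, Op MulOp, const Mul2Fn &Mul2,
                     const Node *&A, const Node *&B) {
  if (!N || N->Opc != MulOp)
    return false;
  if (Mul2(N->L, A)) {
    B = N->R;
    return true;
  }
  if (Mul2(N->R, A)) {
    B = N->L;
    return true;
  }
  return false;
}
```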

> However, I honestly think that the current implementation is the cleanest way to do it. I'm also not a big fan of code duplication, but the discussed alternatives seem a lot messier to me.

rainerzufalldererste marked an inline comment as done. Aug 19 2023, 7:24 AM
rainerzufalldererste added inline comments.
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
1088

Regarding LHS and RHS, you are correct, I misspoke: the Opcode and RHS are consistent, but LHS isn't. TMul2 is used in several places:

Mul2(m_Deferred(A))
Mul2(Mul(m_Value(A), m_Value(B)))
Mul2(m_Value(A))

All of these arguments have different types, so the return type of this lambda would also differ in every case. If the parameter were a Value *&, this wouldn't be a problem at all, but that's simply not the case. Is there a way to cast these types to Value *& somehow (without capturing them separately and then matching them again with the sub-match lambda)?

m_Deferred(A) returns deferredval_ty<Value>.
Mul(m_Value(A), m_Value(B)) returns either BinaryOp_match<bind_ty<Value>, bind_ty<Value>, Instruction::FMul> or BinaryOp_match<bind_ty<Value>, bind_ty<Value>, Instruction::Mul>.
m_Value(A) returns bind_ty<Value>.

These types aren't compatible, so the template can't deduce a consistent type, even from lambdas with auto parameters. The same goes for Mul and CAdd.

Apart from that, I'm a bit confused about the match in your comment, as it isn't quite applicable unless we first match parts of the pattern and then check them against this follow-up matcher lambda. Even if we did that, it would end up a large mess, since this applies not only to Mul2 but also to CAdd and Mul, turning these two large matches into a ton of tiny matches.

Otherwise, I'm not quite sure why I'm explaining compilation errors here, unless I'm missing something very obvious or am completely missing the point.
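For comparison, a standalone C++17 sketch with hypothetical toy types (not LLVM's `bind_ty`/`deferredval_ty`): a generic lambda's call operator is itself a template, so each call with a new argument type instantiates it anew and may return a different type. If that holds, differing return types may not be the real blocker; a more likely culprit in the attempt below is that its lambdas take `auto &`, which cannot bind to temporaries such as the result of `m_Value(A)`, whereas `const auto &` can.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for two unrelated matcher types.
struct BindTy { int Tag = 1; };
struct DeferredTy { int Tag = 2; };

// Hypothetical "times two" pattern wrapping whatever sub-pattern it is given.
template <typename Sub> struct TimesTwo { Sub Inner; };

// Generic lambda: `const auto &` binds to temporaries (`auto &` would not),
// and each new argument type gets its own instantiation and return type.
const auto Mul2 = [](const auto &V) {
  return TimesTwo<std::decay_t<decltype(V)>>{V};
};

// A helper templated on the callable, like MatchesSquareSum, can call it
// with differently typed sub-patterns.
template <typename Mul2Fn> int sumTags(const Mul2Fn &M2) {
  auto X = M2(BindTy{});     // TimesTwo<BindTy>
  auto Y = M2(DeferredTy{}); // TimesTwo<DeferredTy>
  static_assert(!std::is_same_v<decltype(X), decltype(Y)>,
                "one generic lambda, two different return types");
  return X.Inner.Tag + Y.Inner.Tag;
}
```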

This, however, isn't valid C++ code:

template <typename TMul2, typename TCAdd, typename TMul>
static std::tuple<bool, Value *, Value *>
MatchesSquareSum(BinaryOperator &I, const TMul2 &Mul2, const TCAdd &CAdd,
                 const TMul &Mul) {
  Value *A, *B;

  // (a * a) + (((a * 2) + b) * b)
  if (match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
                     m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
                                  m_Deferred(B))))))
    return std::make_tuple(true, A, B);

  // ((a * b) * 2)  or ((a * 2) * b)
  // +
  // (a * a + b * b) or (b * b + a * a)
  return std::make_tuple(
      match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
                                 m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
                     m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
                                   Mul(m_Deferred(B), m_Deferred(B)))))),
      A, B);
}

// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
// if `FP`: requires `nsz` and `reassoc`.
Instruction *InstCombinerImpl::foldSquareSum(BinaryOperator &I, const bool FP) {
  if (FP) {
    assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
           "Assumption mismatch");
  }

  std::tuple<bool, Value *, Value *> Match;

  if (FP) {
    Match = MatchesSquareSum(
        I, [](auto &V) { return m_FMul(V, m_SpecificFP(2.0)); },
        [](auto &L, auto &R) { return m_c_FAdd(L, R); },
        [](auto &L, auto &R) { return m_FMul(L, R); });
  } else {
    Match = MatchesSquareSum(
        I, [](auto &V) { return m_Shl(V, m_SpecificInt(1)); },
        [](auto &L, auto &R) { return m_c_Add(L, R); },
        [](auto &L, auto &R) { return m_Mul(L, R); });
  }

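  // if one of them matches: -> (a + b)^2
  if (std::get<0>(Match)) {
    Value *A = std::get<1>(Match), *B = std::get<2>(Match);
    if (FP) {
      Value *AB = Builder.CreateFAddFMF(A, B, &I);
      return BinaryOperator::CreateFMulFMF(AB, AB, &I);
    }
    Value *AB = Builder.CreateAdd(A, B);
    return BinaryOperator::CreateMul(AB, AB);
  }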

  return nullptr;
}

This _is_ valid C++ code, but uses template specialization to get around the previous type-ambiguity issues:

template <bool IsFP> struct XMul;

template <> struct XMul<false> {
  template <typename LHS, typename RHS>
  inline auto operator()(const LHS &L, const RHS &R) const {
    return m_Mul(L, R);
  }
};

template <> struct XMul<true> {
  template <typename LHS, typename RHS>
  inline auto operator()(const LHS &L, const RHS &R) const {
    return m_FMul(L, R);
  }
};

template <bool IsFP> struct XCAdd;

template <> struct XCAdd<false> {
  template <typename LHS, typename RHS>
  inline auto operator()(const LHS &L, const RHS &R) const {
    return m_c_Add(L, R);
  }
};

template <> struct XCAdd<true> {
  template <typename LHS, typename RHS>
  inline auto operator()(const LHS &L, const RHS &R) const {
    return m_c_FAdd(L, R);
  }
};

template <bool IsFP> struct XMul2;

template <> struct XMul2<false> {
  template <typename LHS> inline auto operator()(const LHS &L) const {
    return m_Shl(L, m_SpecificInt(1));
  }
};

template <> struct XMul2<true> {
  template <typename LHS> inline auto operator()(const LHS &L) const {
    return m_FMul(L, m_SpecificFP(2.0));
  }
};

template <typename TMul2, typename TCAdd, typename TMul>
static std::tuple<bool, Value *, Value *>
MatchesSquareSum(BinaryOperator &I, const TMul2 &Mul2, const TCAdd &CAdd,
                 const TMul &Mul) {
  Value *A = nullptr, *B = nullptr;

  // (a * a) + (((a * 2) + b) * b)
  if (match(&I, CAdd(m_OneUse(Mul(m_Value(A), m_Deferred(A))),
                     m_OneUse(Mul(CAdd(Mul2(m_Deferred(A)), m_Value(B)),
                                  m_Deferred(B))))))
    return std::make_tuple(true, A, B);

  // ((a * b) * 2)  or ((a * 2) * b)
  // +
  // (a * a + b * b) or (b * b + a * a)
  return std::make_tuple(
      match(&I, CAdd(m_CombineOr(m_OneUse(Mul2(Mul(m_Value(A), m_Value(B)))),
                                 m_OneUse(Mul(Mul2(m_Value(A)), m_Value(B)))),
                     m_OneUse(CAdd(Mul(m_Deferred(A), m_Deferred(A)),
                                   Mul(m_Deferred(B), m_Deferred(B)))))),
      A, B);
}

// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
// if `FP`: requires `nsz` and `reassoc`.
Instruction *InstCombinerImpl::foldSquareSum(BinaryOperator &I, const bool FP) {
  if (FP) {
    assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
           "Assumption mismatch");
  }

  const std::tuple<bool, Value *, Value *> Match =
      FP ? MatchesSquareSum(I, XMul2<true>(), XCAdd<true>(), XMul<true>())
         : MatchesSquareSum(I, XMul2<false>(), XCAdd<false>(), XMul<false>());

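  // if one of them matches: -> (a + b)^2
  if (std::get<0>(Match)) {
    Value *A = std::get<1>(Match), *B = std::get<2>(Match);
    if (FP) {
      Value *AB = Builder.CreateFAddFMF(A, B, &I);
      return BinaryOperator::CreateFMulFMF(AB, AB, &I);
    }
    Value *AB = Builder.CreateAdd(A, B);
    return BinaryOperator::CreateMul(AB, AB);
  }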

  return nullptr;
}
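Assuming C++17 is available, the paired specializations could likely be collapsed into one functor per operation with `if constexpr`: inside a template the untaken branch is discarded, so each instantiation deduces its own return type. A minimal sketch with hypothetical toy pattern types standing in for the results of `m_Mul`/`m_FMul`:

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for the integer and FP multiply patterns.
template <typename L, typename R> struct IntMulPat { L Lhs; R Rhs; };
template <typename L, typename R> struct FPMulPat { L Lhs; R Rhs; };

// One functor instead of XMul<false> and XMul<true> specializations.
template <bool IsFP> struct XMul {
  template <typename LHS, typename RHS>
  auto operator()(const LHS &L, const RHS &R) const {
    if constexpr (IsFP)
      return FPMulPat<LHS, RHS>{L, R};
    else
      return IntMulPat<LHS, RHS>{L, R};
  }
};

static_assert(std::is_same_v<decltype(XMul<true>()(1, 2.0)),
                             FPMulPat<int, double>>);
static_assert(std::is_same_v<decltype(XMul<false>()(1, 2.0)),
                             IntMulPat<int, double>>);
```

The same trick would apply to XCAdd and XMul2; whether it reads better than explicit specializations is a matter of taste.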
goldstein.w.n added inline comments. Aug 19 2023, 11:25 AM
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
1088

How about something along the lines of:

template <unsigned OpcMul, unsigned OpcAdd, unsigned OpcMul2, typename Mul2Rhs>
static bool foldSquareSum(BinaryOperator &I, Mul2Rhs MRhs, Value *&AOut,
                          Value *&BOut) {
  Value *A = nullptr, *B = nullptr;
  bool Matches = match(
      &I,
      m_c_BinOp(OpcAdd, m_OneUse(m_BinOp(OpcMul, m_Value(A), m_Deferred(A))),
                m_OneUse(m_BinOp(
                    OpcMul,
                    m_c_BinOp(OpcAdd, m_BinOp(OpcMul2, m_Deferred(A), MRhs),
                              m_Value(B)),
                    m_Deferred(B)))));
  if (!Matches) {
    Matches = match(
        &I,
        m_c_BinOp(
            OpcAdd,
            m_CombineOr(
                m_OneUse(m_BinOp(
                    OpcMul2, m_BinOp(OpcMul, m_Value(A), m_Value(B)), MRhs)),
                m_OneUse(m_BinOp(OpcMul, m_BinOp(OpcMul2, m_Value(A), MRhs),
                                 m_Value(B)))),
            m_OneUse(
                m_c_BinOp(OpcAdd, m_BinOp(OpcMul, m_Deferred(A), m_Deferred(A)),
                          m_BinOp(OpcMul, m_Deferred(B), m_Deferred(B))))));
  }
  AOut = A;
  BOut = B;
  return Matches;
}


// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
Instruction *InstCombinerImpl::foldSquareSumInts(BinaryOperator &I) {
  Value *A, *B;

  bool Matches =
      foldSquareSum<Instruction::Mul, Instruction::Add, Instruction::Shl>(
          I, m_SpecificInt(1), A, B);
  // if one of them matches: -> (a + b)^2
  if (Matches) {
    Value *AB = Builder.CreateAdd(A, B);
    return BinaryOperator::CreateMul(AB, AB);
  }

  return nullptr;
}


// Fold variations of a^2 + 2*a*b + b^2 -> (a + b)^2
// Requires `nsz` and `reassoc`.
Instruction *InstCombinerImpl::foldSquareSumFloat(BinaryOperator &I) {
  Value *A, *B;

  assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && "Assumption mismatch");

  bool Matches =
      foldSquareSum<Instruction::FMul, Instruction::FAdd, Instruction::FMul>(
          I, m_SpecificFP(2.0), A, B);

  // if one of them matches: -> (a + b)^2
  if (Matches) {
    Value *AB = Builder.CreateFAddFMF(A, B, &I);
    return BinaryOperator::CreateFMulFMF(AB, AB, &I);
  }

  return nullptr;
}

Needs comments and whatnot, but I don't see why this would fall short.
All the InstCombine tests pass with this (I assume including all the tests relevant to the int/FP versions of this).
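A toy model of why passing opcodes as plain values works, using a hypothetical mini IR rather than LLVM's PatternMatch: a matcher that compares against a run-time opcode has the same type whatever the opcode, so one helper covers both the int and FP spellings. It mirrors only the first pattern, (a * a) + (((a * 2) + b) * b), and tries only the outer add in both operand orders.

```cpp
#include <cassert>

// Hypothetical opcodes standing in for Instruction::Add, Instruction::FMul...
enum : unsigned { Leaf, Add, Mul, Shl, FAdd, FMul };

struct Node {
  unsigned Opc;
  const Node *L = nullptr;
  const Node *R = nullptr;
  bool IsConst = false;
  double Const = 0;
};

// Matches `N == L <Opc> R` for an opcode chosen at run time, binding operands.
static bool isBinOp(const Node *N, unsigned Opc, const Node *&L,
                    const Node *&R) {
  if (!N || N->Opc != Opc)
    return false;
  L = N->L;
  R = N->R;
  return true;
}

static bool isConst(const Node *N, double C) {
  return N && N->IsConst && N->Const == C;
}

// (a * a) + (((a <Mul2Op> Mul2Rhs) + b) * b), every opcode passed as a value.
static bool matchSquareSum(const Node *I, unsigned MulOp, unsigned AddOp,
                           unsigned Mul2Op, double Mul2Rhs, const Node *&A,
                           const Node *&B) {
  const Node *X, *Y;
  if (!isBinOp(I, AddOp, X, Y))
    return false;
  auto TryOrder = [&](const Node *Square, const Node *Rest) {
    const Node *A0, *A1, *Sum, *B0, *Twice, *B1, *A2, *K;
    if (isBinOp(Square, MulOp, A0, A1) && A0 == A1 && // a * a
        isBinOp(Rest, MulOp, Sum, B0) &&              // (...) * b
        isBinOp(Sum, AddOp, Twice, B1) && B1 == B0 && // (a * 2) + b
        isBinOp(Twice, Mul2Op, A2, K) && A2 == A0 && isConst(K, Mul2Rhs)) {
      A = A0;
      B = B0;
      return true;
    }
    return false;
  };
  return TryOrder(X, Y) || TryOrder(Y, X);
}
```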

rainerzufalldererste marked an inline comment as done.

How do you like this? I've made most of the template parameters default to the correct type to keep the invocation cleaner. I'm not sure whether template specializing (only for m_SpecificInt / m_SpecificFP in the matcher function) would be a good idea, as it would make the invocation even cleaner but the matcher a bit more complicated. However, considering that this little template monster replaces only slight code duplication, this may be our implementation of choice. Let me know what you think!

rainerzufalldererste marked 2 inline comments as done. Aug 25 2023, 6:34 AM
goldstein.w.n added inline comments. Aug 25 2023, 10:53 AM
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
1013

In fact, if you take the bool template approach, I don't think MulOp, AddOp, or Mul2Op even need to be template parameters. You can just set them as unsigned values at the top of matchesSquareSum,
e.g.: unsigned MulOp = FP ? Instruction::FMul : Instruction::Mul;
IMO that ends up being cleaner.

1042

Comment what false means, e.g. </*FP*/false>.

1054

ibid.
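A minimal sketch of the two suggestions above, with hypothetical opcode values standing in for `Instruction::Mul` and friends: the only template parameter is the bool, the opcodes are derived from it as ordinary values, and call sites name the flag with an argument comment.

```cpp
#include <cassert>

// Hypothetical opcode values standing in for Instruction::Add, ::FMul, etc.
enum : unsigned { Add = 1, Mul, Shl, FAdd, FMul };

// Opcodes (plus the constant in the "times two" step) for one flavour.
struct SquareSumOps {
  unsigned MulOp, AddOp, Mul2Op;
  double Mul2Rhs; // 1 for `shl x, 1`, 2.0 for `fmul x, 2.0`
};

// Only FP is a template parameter; the opcodes are picked as plain values.
template <bool FP> SquareSumOps squareSumOps() {
  unsigned MulOp = FP ? FMul : Mul;
  unsigned AddOp = FP ? FAdd : Add;
  unsigned Mul2Op = FP ? FMul : Shl;
  double Mul2Rhs = FP ? 2.0 : 1.0;
  return {MulOp, AddOp, Mul2Op, Mul2Rhs};
}
```

Usage, with the flag spelled out at each call site:

    SquareSumOps IntOps = squareSumOps</*FP=*/false>();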

Well spotted, much cleaner now. Comments added as requested.

rainerzufalldererste marked 3 inline comments as done. Aug 25 2023, 11:25 AM

LGTM.
I'm by no means an expert in FP semantics. @arsenm, any chance you could quickly verify that the FP checks are correct?
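For a concrete sense of why the FP fold needs fast-math flags at all, a small check (function names are hypothetical) shows the two forms disagree under plain IEEE double arithmetic: with a = 1e200 and b = -1e200 the squares overflow, so the expanded form ends up adding inf and -inf (NaN), while (a + b)^2 is exactly 0. This only illustrates that the rewrite is not value-preserving; it does not by itself settle which flags are sufficient.

```cpp
#include <cassert>
#include <cmath>

// Expanded form, as in the source pattern: a*a + 2*a*b + b*b.
// With A = 1e200, B = -1e200: inf + (-inf) + inf, i.e. NaN.
double expandedSquareSum(double A, double B) {
  return A * A + 2.0 * A * B + B * B;
}

// Folded form produced by the transform: (a + b) * (a + b).
double foldedSquareSum(double A, double B) {
  double AB = A + B;
  return AB * AB;
}
```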

goldstein.w.n accepted this revision. Aug 25 2023, 11:29 AM

Please wait a few days, or until Matt signs off as well, before pushing.

This revision is now accepted and ready to land. Aug 25 2023, 11:29 AM

All good, I don't have commit access anyway.

Updated to current head. Please commit for me when the build completes successfully. Thanks!