This is an archive of the discontinued LLVM Phabricator instance.

[InstCombine] Add constant combines for `(urem/srem (shl X, Y), (shl X, Z))`
Closed · Public

Authored by goldstein.w.n on Feb 16 2023, 3:29 PM.

Details

Summary

Forked from D142901 to deduce more nsw/nuw flags for the output shl.

We can handle the following cases + some nsw/nuw flags:

The rationale for doing this all in InstCombine, rather than handling
the constant shl cases in InstSimplify, is that we often create a new
instruction because we are able to deduce more nsw/nuw flags than
the original instruction had.
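
As a rough illustration of the arithmetic this kind of fold leans on (not the patch's case list or its flag deduction), the sketch below brute-forces the identity `(x * y) urem (x * z) == x * (y urem z)` over 8-bit values when neither multiply wraps. The helper names `mulNoUnsignedWrap` and `countCounterexamples` are hypothetical and are not LLVM API.

```cpp
#include <cassert>
#include <cstdint>

// True if x * c fits in 8 bits, i.e. the multiply would be `nuw` on i8.
// (Hypothetical helper for this sketch, not an LLVM function.)
bool mulNoUnsignedWrap(uint32_t x, uint32_t c) { return x * c <= 0xFFu; }

// Exhaustively checks, on 8-bit values, that
//   (x * y) urem (x * z) == x * (y urem z)
// whenever neither multiply wraps and the divisor is non-zero.
// Returns the number of (x, y, z) triples that violate the identity.
unsigned countCounterexamples() {
  unsigned bad = 0;
  for (uint32_t x = 1; x <= 0xFF; ++x)
    for (uint32_t y = 0; y <= 0xFF; ++y)
      for (uint32_t z = 1; z <= 0xFF; ++z) {
        if (!mulNoUnsignedWrap(x, y) || !mulNoUnsignedWrap(x, z))
          continue; // skip inputs where a multiply would wrap
        assert(x * z != 0 && "divisor must be non-zero");
        uint32_t lhs = (x * y) % (x * z);
        uint32_t rhs = x * (y % z);
        if (lhs != rhs)
          ++bad;
      }
  return bad;
}
```

Since a constant shl is a multiply by a power of two, the same identity covers the shl forms; the wrap check is the reason the nuw/nsw flags matter.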

Event Timeline

goldstein.w.n created this revision. Feb 16 2023, 3:29 PM
Herald added a project: Restricted Project. Feb 16 2023, 3:29 PM
Herald added a subscriber: hiraditya.
goldstein.w.n requested review of this revision. Feb 16 2023, 3:29 PM

Use match instead of handwritten logic

sdesmalen added inline comments. Mar 14 2023, 2:23 AM
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
1766

It might be a bit easier to follow if you explicitly do the scaling while doing the matching, i.e.:

APInt Y, Z;
const APInt *MatchY = nullptr, *MatchZ = nullptr;

// Match and normalise shift-amounts to multiplications
if (match(Op0, m_c_Mul(m_Value(X), m_APInt(MatchY)))) {
  Y = *MatchY;
  if (match(Op1, m_Shl(m_Specific(X), m_APInt(MatchZ))))
    //  rem(mul(x, y), shl(x, z))
    Z = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchZ);
  else if (match(Op1, m_c_Mul(m_Specific(X), m_APInt(MatchZ))))
    //  rem(mul(x, y), mul(x, z))
    Z = *MatchZ;
} else if (match(Op0, m_Shl(m_Value(X), m_APInt(MatchY)))) {
  Y = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchY);
  if (match(Op1, m_Shl(m_Specific(X), m_APInt(MatchZ))))
    //  rem(shl(x, y), shl(x, z))
    Z = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchZ);
  else if (match(Op1, m_c_Mul(m_Specific(X), m_APInt(MatchZ))))
    //  rem(shl(x, y), mul(x, z))
    Z = *MatchZ;
}

if (!MatchY || !MatchZ)
  return nullptr;
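
The normalisation step in the snippet above (turning a constant shift amount into a multiplier) can be sanity-checked outside LLVM. This is a minimal sketch on plain 32-bit integers rather than APInt; the names `shiftAmountToMultiplier` and `shlEqualsMul` are made up for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Converts a constant shift amount into the equivalent multiplier on a
// 32-bit value: shl(x, c) == mul(x, 1 << c) modulo 2^32.
// Shift amounts >= 32 are rejected here, since shl by the bit width or more
// yields poison in LLVM IR.
uint32_t shiftAmountToMultiplier(uint32_t c) {
  assert(c < 32 && "shift amount must be smaller than the bit width");
  return uint32_t{1} << c;
}

// Checks that shifting and multiplying by the normalised constant agree,
// including when the high bits are shifted out.
bool shlEqualsMul(uint32_t x, uint32_t c) {
  return (x << c) == x * shiftAmountToMultiplier(c);
}
```

With both operands expressed as multipliers, the four shl/mul combinations reduce to a single `rem(mul(x, y), mul(x, z))` shape.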
1781–1783

I can't see this case being tested anywhere. I'd suggest singling this case out from the other logic and handling it separately. Maybe you can move that to another patch (or make it part of D143417, which tries to handle a similar case). That makes this patch a bit simpler and avoids the need for GetBinOpOut.

Split off the (shl Z, X), (shl Y, X) case. Remove the Shift{Y|Z|X} flags.

goldstein.w.n marked an inline comment as done.
goldstein.w.n added inline comments.
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
1766

Inlined the ShiftY/ShiftZ

I'd prefer to keep the 4 explicit cases (I removed the ShiftX case and moved it to the next patch). I think it's clearer to have each case laid out explicitly rather than having nested if/else statements. LMK if that's okay; I'll change it if you feel strongly.

1781–1783

There are some tests for it, though they are all non-constant cases (so it starts to matter in the next patch).

I split it into a new patch (tests: D147107, implementation: D147108) and added tests for the constant version of the case.

goldstein.w.n marked an inline comment as done. Mar 28 2023, 9:06 PM
sdesmalen accepted this revision. Apr 19 2023, 6:41 AM

LGTM with comment about redundant condition addressed.

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
1766

This is merely a suggestion, I'll leave it to you whether to adopt.

From looking at the new way you've structured the code, it occurred to me that it could also be written like this:

// If V is not nullptr, it will be matched using m_Specific.
auto MatchShiftOrMul = [](Value *Op, Value *&V, APInt &C) -> bool {
  const APInt *Tmp = nullptr;
  if ((!V && match(Op, m_c_Mul(m_Value(V), m_APInt(Tmp)))) ||
      (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp)))))
    C = *Tmp;
  else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) ||
           (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp)))))
    C = APInt(Tmp->getBitWidth(), 1) << *Tmp;
  return Tmp != nullptr;
};

APInt Y, Z;
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr;
if (!MatchShiftOrMul(Op0, X, Y) || !MatchShiftOrMul(Op1, X, Z))
  return nullptr;

Which avoids having to spell out all the permutations. It also avoids the need for the AdjustedY and AdjustedZ variables.
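
To make the "bind on the first call, require the same value on the second" idea concrete, here is a toy model that runs without LLVM. Everything in it (`Operand`, `Kind`, `kUnbound`, `matchShiftOrMul`, `matchBoth`) is a hypothetical stand-in for illustration, with variables modelled as integer ids instead of `Value *`.

```cpp
#include <cassert>
#include <cstdint>

// Toy stand-in for an IR operand: either `var * c` or `var << c`.
enum class Kind { Mul, Shl };
struct Operand {
  Kind kind;
  int var;    // id of the non-constant operand
  uint32_t c; // constant multiplier or shift amount
};

const int kUnbound = -1; // sentinel: no variable has been bound yet

// If boundVar is kUnbound, bind it to the operand's variable; otherwise
// require the operand to use that same variable. On success, stores the
// operand's constant in `multiplier`, normalised to a multiplier.
bool matchShiftOrMul(const Operand &op, int &boundVar, uint32_t &multiplier) {
  if (boundVar != kUnbound && boundVar != op.var)
    return false;
  if (op.kind == Kind::Shl && op.c >= 32)
    return false; // over-wide shift: no 32-bit multiplier represents it
  boundVar = op.var;
  multiplier = (op.kind == Kind::Mul) ? op.c : (uint32_t{1} << op.c);
  return true;
}

struct Match {
  bool ok;
  uint32_t y, z;
};

// Matches both operands against one shared variable, in order.
Match matchBoth(const Operand &op0, const Operand &op1) {
  int x = kUnbound;
  Match m{false, 0, 0};
  m.ok = matchShiftOrMul(op0, x, m.y) && matchShiftOrMul(op1, x, m.z);
  assert((!m.ok || x != kUnbound) && "a successful match always binds x");
  return m;
}
```

For example, `matchBoth({Kind::Shl, 0, 3}, {Kind::Mul, 0, 6})` succeeds with multipliers 8 and 6, while operands that name different variables fail on the second call.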

1774–1776

You can remove this condition, because InstCombine will already have canonicalised the constant to the RHS.

This revision is now accepted and ready to land. Apr 19 2023, 6:41 AM

Use cleaner method for matching

goldstein.w.n marked 2 inline comments as done. Apr 19 2023, 8:32 PM
goldstein.w.n added inline comments.
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
1766

Done, likewise for D147108. I couldn't find a clean way to do it for
D143417, though: there the matches are no longer only APInt, so it would
need branching logic for either V or Y/Z (Tmp) being nullptr.

goldstein.w.n marked an inline comment as done.

Rebase