This is an archive of the discontinued LLVM Phabricator instance.

[X86] Combine reduce (add (mul x, y)) to VNNI instruction.
Closed, Public

Authored by LuoYuanke on Dec 20 2021, 6:57 AM.

Details

Summary

For the C code below, we can use VNNI to combine the mul and add operations.

int usdot_prod_qi(unsigned char *restrict a, char *restrict b, int c, int n) {
  int i;
  for (i = 0; i < 32; i++) {
    c += ((int)a[i] * (int)b[i]);
  }
  return c;
}

This patch does not support the combine across basic blocks.
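As a rough illustration of why the loop in the summary maps onto VPDPBUSD, here is a scalar sketch that splits the 32 bytes into eight dword lanes, accumulates four byte products per lane, and then adds the lanes horizontally. The helper names (usdot_prod_ref, dpbusd_lane, usdot_prod_lanes) are hypothetical and are not part of the patch. The sketch uses signed char where the summary uses plain char, on the assumption that char is signed, as it is by default on x86.

```c
#include <stdint.h>

/* Scalar reference: the reduction loop from the summary. */
static int usdot_prod_ref(const unsigned char *a, const signed char *b, int c) {
    for (int i = 0; i < 32; i++)
        c += (int)a[i] * (int)b[i];
    return c;
}

/* Hypothetical model of one 32-bit lane of VPDPBUSD: four products of an
 * unsigned byte and a signed byte are added to the lane's accumulator. */
static int32_t dpbusd_lane(int32_t acc, const uint8_t *u, const int8_t *s) {
    for (int k = 0; k < 4; k++)
        acc += (int32_t)u[k] * (int32_t)s[k];
    return acc;
}

/* 32 bytes form 8 dword lanes. Accumulate each lane starting from zero,
 * then do the horizontal add (the "reduce") and add the scalar c. */
static int usdot_prod_lanes(const unsigned char *a, const signed char *b, int c) {
    int32_t lanes[8] = {0};
    for (int l = 0; l < 8; l++)
        lanes[l] = dpbusd_lane(lanes[l], a + 4 * l, b + 4 * l);

    int32_t sum = 0;
    for (int l = 0; l < 8; l++)
        sum += lanes[l];
    return c + sum;
}
```

Both functions should return the same value for any input, since integer addition is associative and each byte product fits comfortably in 32 bits.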

Diff Detail

Unit Tests: Failed

Event Timeline

LuoYuanke created this revision.Dec 20 2021, 6:57 AM
LuoYuanke requested review of this revision.Dec 20 2021, 6:57 AM
Herald added a project: Restricted Project.Dec 20 2021, 6:57 AM
LuoYuanke edited the summary of this revision.
LuoYuanke edited the summary of this revision.
LuoYuanke edited the summary of this revision.

LuoYuanke updated this revision to Diff 395436.Dec 20 2021, 7:03 AM

update the title.

LuoYuanke retitled this revision from [X86] Combine reduce(mul x, y) to VNNI instruction. to [X86] Combine reduce (add (mul x, y)) to VNNI instruction..Dec 20 2021, 7:03 AM

Any explicit checks for extension/truncation and their bitwidth delta instantly make me suspicious nowadays.

Please comment on the code directly, so that I can follow more easily.

Does this deal with commutativity?

This is handled at X86ISelLowering.cpp:41763.

I think what you want to check is the number of known sign bits / known leading zero bits.

Please comment on the code directly, so that I can follow more easily.

lebedev.ri added inline comments.Dec 20 2021, 7:10 AM
llvm/lib/Target/X86/X86ISelLowering.cpp
41756–41774

Any explicit checks for extension/truncation and their bitwidth delta instantly make me suspicious nowadays.
Does this deal with commutativity?
I think what you want to check is the number of known sign bits / known leading zero bits.

craig.topper added inline comments.Dec 20 2021, 8:38 AM
llvm/lib/Target/X86/X86ISelLowering.cpp
41756–41774

If the sign extend side can be proven to be positive, the sign extend might be hidden as zero extend. This is why tryMAddReplacement checks for FreeTruncations and calls ComputeNumSignBits.

41801

Opcode is only used once. Why not just make it part of the getNode call?
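Craig's point on lines 41756–41774, that a sign extend of a provably non-negative value can show up as a zero extend, can be illustrated on concrete bytes. This is only a sketch with hypothetical helper names (sext8, zext8, extensions_agree). The real code reasons about SelectionDAG nodes via ComputeNumSignBits rather than about concrete values.

```c
#include <stdint.h>

/* Sign-extend a byte to 32 bits: bytes >= 128 represent negative values. */
static int32_t sext8(uint8_t byte) {
    return byte >= 128 ? (int32_t)byte - 256 : (int32_t)byte;
}

/* Zero-extend a byte to 32 bits. */
static int32_t zext8(uint8_t byte) {
    return (int32_t)byte;
}

/* Returns 1 when both extensions produce the same 32-bit value, which is
 * exactly when the top bit of the byte is clear (the value is non-negative). */
static int extensions_agree(uint8_t byte) {
    return sext8(byte) == zext8(byte);
}
```

For bytes 0x00 through 0x7F the two extensions agree, so an optimizer is free to rewrite one as the other. For 0x80 through 0xFF they differ, e.g. sext8(0xFF) is -1 while zext8(0xFF) is 255. Matching only on the extend opcode can therefore miss valid cases.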

LuoYuanke updated this revision to Diff 395661.Dec 21 2021, 6:13 AM

Address Craig and Roman's comments.

LuoYuanke marked an inline comment as done.Dec 21 2021, 6:15 AM
LuoYuanke added inline comments.
llvm/lib/Target/X86/X86ISelLowering.cpp
41756–41774

Thanks Craig and Roman. I enhanced it in the new patch.

LuoYuanke updated this revision to Diff 395662.Dec 21 2021, 6:19 AM

Remove debug code.

Please fix the patch description - both extensions there are signed, is that actually the specification for the intrinsic?

llvm/lib/Target/X86/X86ISelLowering.cpp
41790

I'm not sure I follow.
Why is this okay with negative numbers?

41790–41791

This still does not handle the commutative variant.

41791

You want ComputeMinSignedBits() <= 8 to check for sext-like
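To make the "ComputeMinSignedBits() <= 8" suggestion concrete, the sketch below computes the minimum number of two's-complement bits for a single 32-bit value. The names min_signed_bits and is_sext_like_i8 are hypothetical. LLVM's version works conservatively on DAG nodes rather than on known constants, so this only shows what the threshold means.

```c
#include <stdint.h>

/* Minimum number of bits needed to represent v in two's complement,
 * including the sign bit. */
static unsigned min_signed_bits(int32_t v) {
    /* For negative v, ~v is non-negative and has the same number of
     * significant bits below the sign bit. */
    uint32_t m = (uint32_t)(v < 0 ? ~v : v);
    unsigned bits = 1; /* the sign bit itself */
    while (m != 0) {
        bits++;
        m >>= 1;
    }
    return bits;
}

/* A value is "sext-like" from i8 if it fits in 8 signed bits,
 * i.e. it lies in the range [-128, 127]. */
static int is_sext_like_i8(int32_t v) {
    return min_signed_bits(v) <= 8;
}
```

With this definition 127 and -128 both need 8 bits and pass the check, while 128 and -129 need 9 bits and fail it.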

RKSimon added inline comments.Dec 21 2021, 8:21 AM
llvm/lib/Target/X86/X86ISelLowering.cpp
41812

createVPDPBUSD ?

42108

Should this be called combineVPDPBUSDPattern? VNNI is the ISA no?

llvm/test/CodeGen/X86/dpbusd.ll
5

You might be able to use a common prefix for some of these to reduce check duplication

7

Drop dso_local?

LuoYuanke added inline comments.Dec 21 2021, 11:13 PM
llvm/lib/Target/X86/X86ISelLowering.cpp
41790

I'm not sure I follow.
Why is this okay with negative numbers?

Here is the description of VPDPBUSD from https://software.intel.com/content/www/us/en/develop/download/intel-architecture-instruction-set-extensions-programming-reference.html?wapkw=instruction. The first operand is unsigned, and the second operand is signed.

Multiplies the individual unsigned bytes of the first source operand by the corresponding signed bytes of the second source operand, producing intermediate signed word results. The word results are then summed and accumulated in the destination dword element size operand.

41790–41791

This still does not handle the commutative variant.

Sorry, I don't quite understand. The commutative case is handled with std::swap(Op0, Op1) at line 41764.

41791

You want ComputeMinSignedBits() <= 8 to check for sext-like

Address Roman and Simon's comments. Thanks for the review.

LuoYuanke marked 5 inline comments as done.Dec 21 2021, 11:17 PM

Could you please explain why there are both the knownbits-based checks, and checks for ISD::SIGN/ZERO_EXTEND nodes?

Could you please explain why there are both the knownbits-based checks, and checks for ISD::SIGN/ZERO_EXTEND nodes?

The VPDPBUSD multiplies the individual unsigned bytes of the first source operand by the corresponding signed bytes of the second source operand, producing intermediate signed word results. The word results are then summed and accumulated in the destination dword element size operand.

src2 is a signed value, so we don't need to check for ISD::SIGN_EXTEND nodes: if the sign bits are 1 the value is negative, and if they are 0 it is positive.
src1, however, is an unsigned value. If it is non-negative that is fine, but if it may be negative we can't use VPDPBUSD to combine the original nodes. See the test cases mul_sext_i4i4 and mul_zext_i4i4 in dpbusd_i4.ll: for mul_zext_i4i4 we can use VPDPBUSD, but for mul_sext_i4i4 we can't, because src1 may be negative.
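The asymmetry between src1 and src2 can be checked on 4-bit inputs. The sketch below is loosely modelled on the idea behind mul_sext_i4i4 and mul_zext_i4i4, but it does not reproduce those tests, and all function names are hypothetical. It compares the product computed at 32 bits with the product obtained after truncating both operands to bytes and reading the first as unsigned and the second as signed.

```c
#include <stdint.h>

/* One byte product as VPDPBUSD sees it: src1 unsigned, src2 signed. */
static int32_t busd_product(uint8_t src1, int8_t src2) {
    return (int32_t)src1 * (int32_t)src2;
}

/* Extend the low 4 bits of a value to 32 bits. */
static int32_t sext_i4(uint8_t nibble) {
    int32_t v = nibble & 0xF;
    return v >= 8 ? v - 16 : v;
}

static int32_t zext_i4(uint8_t nibble) {
    return nibble & 0xF;
}

/* Zero-extended inputs: does the byte-wise product equal the wide product? */
static int zext_case_matches(uint8_t x, uint8_t y) {
    int32_t wide = zext_i4(x) * zext_i4(y);
    int32_t narrow = busd_product((uint8_t)zext_i4(x), (int8_t)zext_i4(y));
    return wide == narrow;
}

/* Sign-extended inputs: a negative x becomes a large unsigned byte. */
static int sext_case_matches(uint8_t x, uint8_t y) {
    int32_t wide = sext_i4(x) * sext_i4(y);
    int32_t narrow = busd_product((uint8_t)sext_i4(x), (int8_t)sext_i4(y));
    return wide == narrow;
}
```

With x = 0xF and y = 3 the sign-extended case computes -1 * 3 = -3 at 32 bits, but the byte-wise form reads x as 255 and yields 765. Swapping the roles (x = 3, y = 0xF) gives -3 both ways, because a negative value in the signed src2 position is harmless.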

LuoYuanke added inline comments.Dec 21 2021, 11:44 PM
llvm/lib/Target/X86/X86ISelLowering.cpp
41780

Maybe we can remove IsFreeTruncation() check as Roman mentions.
Roman, do you mean to remove IsFreeTruncation() check?

LuoYuanke added inline comments.Dec 22 2021, 12:32 AM
llvm/lib/Target/X86/X86ISelLowering.cpp
41780

Maybe we can remove IsFreeTruncation() check as Roman mentions.
Roman, do you mean to remove IsFreeTruncation() check?

If I remove the ISD::SIGN/ZERO_EXTEND check, I get a crash in createVPDPBUSD() with the test case below. I think there is room to improve the patch to cover more patterns, but to be conservative I'd like to do that in a separate patch, so that if we have a regression we can revert less code.
Hi Roman,
What do you think?

declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

define dso_local i32 @mul_i4i2(<16 x i4> %b, i32 %c) {
entry:
  %0 = trunc <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15> to <16 x i16>
  %1 = zext <16 x i16> %0 to <16 x i32>
  %2 = zext <16 x i4> %b to <16 x i32>
  %3 = mul nsw <16 x i32> %2, %1
  %4 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %3)
  %op.extra = add nsw i32 %4, %c
  ret i32 %op.extra
}

Format the code as Lint suggested?

llvm/lib/Target/X86/X86ISelLowering.cpp
41810

Can we return false here if Op0 is not ZERO_EXTEND?

41813–41814

How about ANY_EXTEND ? The same below.

41826

It's better to hoist to line 41810. How about ANY_EXTEND?

41870

Is it always 2? Better to add a comment to explain.

41887–41888

AVXVNNI implies AVX2. We won't need to split to 128 bits.

42156–42157

Can we generate i32 first then do the truncation?

42162–42163

Can check DCI.isAfterLegalizeDAG() before calling the function instead?

42210

Can use ExtractVT.getSizeInBits() directly.

llvm/test/CodeGen/X86/dpbusd.ll
13

This is the only difference from the AVX512 code, and a strange one. Is something wrong in one of them?

Are we missing test cases for 32 x i8 and 64 x i8?

craig.topper added inline comments.Dec 25 2021, 10:13 PM
llvm/lib/Target/X86/X86ISelLowering.cpp
42162–42163

I don't think that works. We should be able to handle 32 x i8 and 64 x i8 which would have zero_extend and sign_extend with illegal result types.

LuoYuanke marked 3 inline comments as done.Dec 26 2021, 5:30 AM

I'll update the patch according to Phoebe and Craig's comments.

llvm/lib/Target/X86/X86ISelLowering.cpp
41813–41814

This checks the opcode, so we need to check both zero extend and sign extend. I'm not sure whether any extend also works, because the upper bits are undefined. What is the sign bit for any extend?

42156–42157

I'm not sure whether, if the result overflows, truncating back to i16 or some other type preserves the value. How about leaving it as an enhancement?

llvm/test/CodeGen/X86/dpbusd.ll
13

This test doesn't generate a vpdpbusd instruction, so AVX512VNNI and AVX512VL generate the same code. For the other test cases, AVX512VNNI alone can only use zmm registers, but AVX512VNNI + AVX512VL can use xmm registers.

LuoYuanke updated this revision to Diff 396222.Dec 26 2021, 5:41 AM

Address Phoebe and Craig's comments.

LuoYuanke added inline comments.Dec 26 2021, 5:43 AM
llvm/test/CodeGen/X86/dpbusd.ll
13

For vpdpbusd_64xi32, the result is the same.

craig.topper added inline comments.Dec 27 2021, 10:34 AM
llvm/lib/Target/X86/X86ISelLowering.cpp
41813–41814

It is undefined and ComputeMinSignedBits will return BitWidth - 1 for it.

41818

You can use Op.getOperand(0).getScalarValueSizeInBits() to simplify this

41828

Can we use DAG.computeKnownBits(Op0).countMaxActiveBits() <= 8 to make this more readable?

41854

Why are Ext0 and Ext1 passed by const reference? SDValue should be passed by value.

41862

Can we just TRUNCATE the Ext nodes without assuming they are extend nodes. That way it just works when you support constants in the future?

41868

Can we do this without assuming the node is a SIGN/ZERO_EXTEND? Just truncate the original node to Vi8VT.

42157

This code isn't handling vpdpwssd so why mention it here?

42169

Is this code valid for this transform? There's a large comment of justification for why it is ok for SAD. I think I only saw a test for the SIGN_EXTEND case?

craig.topper added inline comments.Dec 27 2021, 10:44 AM
llvm/lib/Target/X86/X86ISelLowering.cpp
42169

Oops I see the other test. I need to think about the math.

craig.topper added inline comments.Dec 27 2021, 11:02 AM
llvm/lib/Target/X86/X86ISelLowering.cpp
42169

I don't think we can do this if the multiply result is zero extended. Each of the 4 multiplies done by vpdpbusd compute a signed 16-bit product that will be sign extended before adding into the accumulator.

I think we also need to verify that the multiply has at least 2x the number of bits of the input. We shouldn't match (sign_extend (mul (vXi9 (zext (vXi8 X))), (vXi9 (zext (vXi8 Y)))). Does anything prevent that right now?
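Both of Craig's concerns at line 42169 can be seen on concrete numbers. This is a sketch with hypothetical helper names and plain integers standing in for vector elements. The first part contrasts sign-extending and zero-extending the 16-bit product before accumulation. The second models a multiply carried out at only 9 bits, as in the (vXi9 ...) pattern he mentions.

```c
#include <stdint.h>

/* 16-bit product of an unsigned byte and a signed byte, as vpdpbusd forms
 * it. The range is [-32640, 32385], so it always fits in int16_t. */
static int16_t word_product(uint8_t u, int8_t s) {
    return (int16_t)((int32_t)u * (int32_t)s);
}

/* What the instruction does: sign-extend the word, then accumulate. */
static int32_t accumulate_sext(int32_t acc, int16_t w) {
    return acc + (int32_t)w;
}

/* What a zero-extended multiply result in the IR would mean instead. */
static int32_t accumulate_zext(int32_t acc, int16_t w) {
    return acc + (int32_t)(uint16_t)w;
}

/* Multiply two zero-extended bytes at 9 bits: keep only the low 9 bits of
 * the product, then sign-extend from bit 8. */
static int32_t mul_i9_sext(uint8_t x, uint8_t y) {
    int32_t low9 = ((int32_t)x * (int32_t)y) & 0x1FF;
    return (low9 & 0x100) ? low9 - 0x200 : low9;
}
```

For 255 * -128 the word product is -32640. Sign-extending it accumulates -32640, while zero-extending it accumulates 32896, so a zero-extended multiply cannot be matched. For 255 * 127 the full product is 32385, but a 9-bit multiply wraps to 129, which is why the multiply needs at least twice the input width.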

LuoYuanke added inline comments.Dec 28 2021, 7:13 PM
llvm/lib/Target/X86/X86ISelLowering.cpp
41862

Can we just TRUNCATE the Ext nodes without assuming they are extend nodes. That way it just works when you support constants in the future?

Good idea. :) I'll update my patch.

41868

Can we do this without assuming the node is a SIGN/ZERO_EXTEND? Just truncate the original node to Vi8VT.

Yes, that's better.

42157

This code isn't handling vpdpwssd so why mention it here?

My original code covered both vpdpbusd and vpdpwssd. I'll clean it up.

42169

I don't think we can do this if the multiply result is zero extended. Each of the 4 multiplies done by vpdpbusd compute a signed 16-bit product that will be sign extended before adding into the accumulator.

I think we also need to verify that the multiply has at least 2x the number of bits of the input. We shouldn't match (sign_extend (mul (vXi9 (zext (vXi8 X))), (vXi9 (zext (vXi8 Y)))). Does anything prevent that right now?

Really good catch. Thanks.

LuoYuanke updated this revision to Diff 396454.Dec 28 2021, 7:21 PM

Address Craig's comments.

Align the check in X86ISelLowering.cpp and X86PartialReduction.cpp.
Add test case for 2 x i32.

Fix lint issues.

craig.topper added inline comments.Dec 29 2021, 2:59 PM
llvm/lib/Target/X86/X86ISelLowering.cpp
41858

Just name the parameters ZExt0 and SExt1 and get rid of Ext0 and Ext1?

41868

This is just Vi8VT, right?

41893

Is this lambda getting anything via & in the capture list, or can it just be []?

llvm/lib/Target/X86/X86PartialReduction.cpp
99

Why pass 0, nullptr, nullptr when those have default values?

LuoYuanke updated this revision to Diff 396585.Dec 29 2021, 5:46 PM

Address Craig's comments. Thanks, Craig.

craig.topper accepted this revision.Jan 6 2022, 9:48 AM

LGTM other than that one comment.

llvm/lib/Target/X86/X86ISelLowering.cpp
41860

You can drop these ifs. getZExtOrTrunc/getSExtOrTrunc will do nothing if the type already matches.

This revision is now accepted and ready to land.Jan 6 2022, 9:48 AM
LuoYuanke updated this revision to Diff 398047.Jan 6 2022, 8:55 PM

Address Craig's comments and rebase.

This revision was landed with ongoing or failed builds.Jan 7 2022, 5:14 AM
This revision was automatically updated to reflect the committed changes.