We can do this by breaking the vecreduce input into v16i8 vectors, generating a udot/sdot for each, and concatenating the results.
Details
Diff Detail
- Repository
- rG LLVM Github Monorepo
Event Timeline
Sounds like a nice improvement
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | ||
---|---|---|
15206 | I think this should be something like !(IsValidElementCount && IsValidSize). define i32 @src(ptr %p, i32 %b) { entry: %a64 = load <4 x i8>, ptr %p %a65 = sext <4 x i8> %a64 to <4 x i32> %a66 = mul nsw <4 x i32> %a65, %a65 %a67 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a66) %a = add i32 %a67, %b ret i32 %a } | |
15248 | DotOpcode can be moved out of the loop, and commoned with the version above. Zeroes can be moved up too. |
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | ||
---|---|---|
15206 |
But then we won't cover the v8i8 case. Or shall we change IsValidElementCount to be |
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | ||
---|---|---|
15206 | Sorry - I meant with the Op0VT != MVT::v8i8 too. The condition as written here will bail out if both !IsValidElementCount and !IsValidSize, but it seems like it should be bailing if one of them is false. So: if (Op0VT != MVT::v8i8 && (!IsValidElementCount || !IsValidSize)) It could also do bool IsValidElementCount = Op0VT == MVT::v8i8 || Op0VT.getVectorNumElements() % 16 == 0; and then check that if (!IsValidElementCount || !IsValidSize), if you think that is cleaner. |
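To make the difference between the two conditions concrete, here is a minimal standalone sketch. It does not use the LLVM types from the patch: `isValidElementCount`, `bailsOutOriginal` and `bailsOutSuggested` are hypothetical helpers, and `IsValidSize` is simply passed in as a flag because its definition is not shown in this thread.

```cpp
#include <cassert>

// Hypothetical stand-in for the element-count check suggested above:
// a v8i8 input, or any multiple of 16 elements, is accepted.
bool isValidElementCount(unsigned NumElts) {
  assert(NumElts > 0 && "a vector needs at least one element");
  return NumElts == 8 || NumElts % 16 == 0;
}

// Condition as described in the comment: bails out only when BOTH checks fail.
bool bailsOutOriginal(unsigned NumElts, bool IsValidSize) {
  return !isValidElementCount(NumElts) && !IsValidSize;
}

// Suggested condition: bails out as soon as EITHER check fails.
bool bailsOutSuggested(unsigned NumElts, bool IsValidSize) {
  return !isValidElementCount(NumElts) || !IsValidSize;
}
```

Assuming a 4-element input for which the size check happens to pass, `bailsOutOriginal(4, true)` is false (the combine would go ahead), while `bailsOutSuggested(4, true)` is true.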
- Update IsValidElementCount to check for vectors as multiples of 8.
- Move Zero and DotOpcode outside the for loop
Do you have any tests for cases that are a multiple of 8 but not of 16, like <24 x ...? And can you make sure we have the load <4 x i8> test case?
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | ||
---|---|---|
15206 | Using Op0VT.getVectorNumElements() % 16 == 0 || Op0VT.getVectorNumElements() == 8; would be simpler if you did not care about <24 x types (which might be better as a 16+8, not 3*8). | |
15237 | Some of these types/constants might be incorrect for multiples of 8? |
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | ||
---|---|---|
15235 | vectos -> vectors | |
15247 | Should this be 4? Or 2? (If 4 is correct it can be pulled up out of the if, but I think 2 might be a better value. I'm not 100% sure what happens when the operands and the type of a concat don't match up.) | |
15249 | I += 1 instead of I += Offset? |
Thanks - the patch looks pretty good to me now. The widths that are a multiple of 8, but not of 16 (like 24 and 40), whilst they won't be super common, are not doing as well as they could be. Could they instead generate v16 "chunks", until there is a v8 remainder? So split v40 into v16+v16+v8. It should hopefully reduce the amount of shuffling and extra instructions in the test_udot_v24i8_nomla-like cases, as well as needing fewer s/udots in total.
Generate v16 chunks and then separate the remainder into v8 chunks. This should improve the test case for v24i8.
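The splitting scheme discussed here can be sketched in standalone C++ as below. This is an illustration only, not the code from the patch: `splitIntoDotChunks` is a hypothetical helper that works on plain element counts rather than on DAG nodes.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: split an element count that is a multiple of 8 into
// as many 16-element chunks as fit, followed by at most one 8-element chunk.
std::vector<unsigned> splitIntoDotChunks(unsigned NumElts) {
  assert(NumElts % 8 == 0 && "sketch only handles multiples of 8");
  std::vector<unsigned> Chunks;
  for (unsigned I = 0; I < NumElts / 16; ++I)
    Chunks.push_back(16u); // one full v16 chunk
  if (NumElts % 16 != 0)
    Chunks.push_back(8u);  // single v8 remainder
  return Chunks;
}
```

With this scheme a 40-element input becomes 16+16+8 (three dot operations instead of five 8-wide ones), and a 24-element input becomes 16+8.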
Thanks. The results are looking better now, if we can clean up the code a little then this looks good to me.
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | ||
---|---|---|
15237 | I don't think this needs floor | |
15254 | This can only ever be 0 or 1, so probably doesn't need the loop. Hopefully this can simplify things a little, as we won't need to concat v8 vectors. | |
15277 | They would need to be 0's I think. Would it be better and simpler to just return vecreduce.add(v16s) + vecreduce.add(v8)? |
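As a sanity check on the idea of returning vecreduce.add(v16s) + vecreduce.add(v8), here is a scalar model in standalone C++. It is a sketch under assumptions, not the lowering itself: `dotChunk`, `reduceAdd` and `reduceMulAddSplit` are hypothetical names, the inputs are treated as unsigned 8-bit values (the udot case), and each 32-bit lane starts from a zero accumulator.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of a udot-style step on one chunk: each 32-bit accumulator
// lane receives the sum of four adjacent unsigned 8-bit products.
std::vector<uint32_t> dotChunk(const std::vector<uint8_t> &A,
                               const std::vector<uint8_t> &B,
                               std::size_t Offset, std::size_t Width) {
  assert(A.size() == B.size() && Width % 4 == 0 && Offset + Width <= A.size());
  std::vector<uint32_t> Acc(Width / 4, 0u); // zero-initialised accumulator
  for (std::size_t I = 0; I < Width; ++I)
    Acc[I / 4] += uint32_t(A[Offset + I]) * uint32_t(B[Offset + I]);
  return Acc;
}

// Model of vecreduce.add over the 32-bit lanes.
uint32_t reduceAdd(const std::vector<uint32_t> &V) {
  uint32_t Sum = 0;
  for (uint32_t X : V)
    Sum += X;
  return Sum;
}

// Reduce all v16 chunks, then add the reduction of the optional v8 remainder.
uint32_t reduceMulAddSplit(const std::vector<uint8_t> &A,
                           const std::vector<uint8_t> &B) {
  assert(A.size() % 8 == 0 && "sketch only handles multiples of 8");
  uint32_t Sum = 0;
  std::size_t Offset = 0;
  for (; Offset + 16 <= A.size(); Offset += 16)
    Sum += reduceAdd(dotChunk(A, B, Offset, 16));
  if (Offset < A.size())
    Sum += reduceAdd(dotChunk(A, B, Offset, 8));
  return Sum;
}
```

Because addition is associative, summing the per-chunk reductions gives the same value as reducing the full product vector, so no concat of v8 vectors is needed in this model.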
Thanks for the updates. This looks very general now. If you agree with the comment about the offset, then this LGTM with that change.
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | ||
---|---|---|
15267 | I was thinking this should be 16 because we have extracted I lots of v16 chunks above, and so should be going that far into the original vector. For example, the 24x case should be getting this from offset 16 into the vector. (So in the tests it should come from an ldr d after the existing ldr q. They currently use an ext from the vector.) |
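One way to read the offset comment is that the v8 remainder starts right after the last full v16 chunk. The sketch below models only that arithmetic. `remainderOffset` is a hypothetical helper, and the assumption that the offset generalises to 16 times the number of v16 chunks is my reading of the comment, not something the patch is quoted as doing.

```cpp
#include <cassert>

// Hypothetical helper: element offset of the v8 remainder, assuming the
// input has been consumed in v16 chunks from the start of the vector.
unsigned remainderOffset(unsigned NumElts) {
  assert(NumElts % 16 == 8 && "sketch expects exactly one v8 remainder");
  unsigned NumV16Chunks = NumElts / 16; // the "I lots of v16" above
  return NumV16Chunks * 16;
}
```

For the 24-element case this gives offset 16, and for 40 elements it gives 32.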
It is worth adding a v4i8 test if one doesn't exist already, along the lines of the @src example in the first inline comment above.