diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -11745,8 +11745,8 @@ return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0)); } -// VECREDUCE_ADD( EXTEND(v16i8_type) ) to -// VECREDUCE_ADD( DOTv16i8(v16i8_type) ) +// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce +// vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one)) static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG, const AArch64Subtarget *ST) { SDValue Op0 = N->getOperand(0); @@ -11761,12 +11761,13 @@ return SDValue(); EVT Op0VT = Op0.getOperand(0).getValueType(); - if (Op0VT != MVT::v16i8) + if (Op0VT != MVT::v8i8 && Op0VT != MVT::v16i8) return SDValue(); SDLoc DL(Op0); SDValue Ones = DAG.getConstant(1, DL, Op0VT); - SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32); + SDValue Zeros = + DAG.getConstant(0, DL, Op0VT == MVT::v8i8 ? MVT::v2i32 : MVT::v4i32); auto DotOpcode = (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT; SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, diff --git a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll --- a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll +++ b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll @@ -31,10 +31,10 @@ ; CHECK-LABEL: test_udot_v8i8_nomla: ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: ldr d0, [x0] -; CHECK-NEXT: ushll v0.8h, v0.8b, #0 -; CHECK-NEXT: ushll v1.4s, v0.4h, #0 -; CHECK-NEXT: uaddw2 v0.4s, v1.4s, v0.8h -; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: movi v1.2d, #0000000000000000 +; CHECK-NEXT: movi v2.8b, #1 +; CHECK-NEXT: udot v1.2s, v2.8b, v0.8b +; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret entry: @@ -72,10 +72,10 @@ ; CHECK-LABEL: test_sdot_v8i8_nomla: ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: ldr d0, [x0] -; CHECK-NEXT: sshll v0.8h, v0.8b, #0 -; CHECK-NEXT: sshll v1.4s, v0.4h, #0 -; CHECK-NEXT: saddw2 v0.4s, v1.4s, v0.8h -; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: movi v1.2d, #0000000000000000 +; CHECK-NEXT: movi v2.8b, #1 +; CHECK-NEXT: sdot v1.2s, v2.8b, v0.8b +; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret entry: @@ -207,14 +207,11 @@ define i32 @test_udot_v8i8_double_nomla(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) { ; CHECK-LABEL: test_udot_v8i8_double_nomla: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: ushll v0.8h, v0.8b, #0 -; CHECK-NEXT: ushll v1.8h, v2.8b, #0 -; CHECK-NEXT: ushll v2.4s, v0.4h, #0 -; CHECK-NEXT: ushll v3.4s, v1.4h, #0 -; CHECK-NEXT: uaddw2 v0.4s, v2.4s, v0.8h -; CHECK-NEXT: uaddw2 v1.4s, v3.4s, v1.8h -; CHECK-NEXT: add v0.4s, v0.4s, v1.4s -; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: movi v1.2d, #0000000000000000 +; CHECK-NEXT: movi v3.8b, #1 +; CHECK-NEXT: udot v1.2s, v3.8b, v2.8b +; CHECK-NEXT: udot v1.2s, v3.8b, v0.8b +; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret entry: @@ -322,14 +319,11 @@ define i32 @test_sdot_v8i8_double_nomla(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) { ; CHECK-LABEL: test_sdot_v8i8_double_nomla: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: sshll v0.8h, v0.8b, #0 -; CHECK-NEXT: sshll v1.8h, v2.8b, #0 -; CHECK-NEXT: sshll v2.4s, v0.4h, #0 -; CHECK-NEXT: sshll v3.4s, v1.4h, #0 -; CHECK-NEXT: saddw2 v0.4s, v2.4s, v0.8h -; CHECK-NEXT: saddw2 v1.4s, v3.4s, v1.8h -; CHECK-NEXT: add v0.4s, v0.4s, v1.4s -; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: movi v1.2d, #0000000000000000 +; CHECK-NEXT: movi v3.8b, #1 +; CHECK-NEXT: sdot v1.2s, v3.8b, v2.8b +; CHECK-NEXT: sdot v1.2s, v3.8b, v0.8b +; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret entry: