diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -11747,31 +11747,46 @@ // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce // vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one)) +// vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B)) static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG, const AArch64Subtarget *ST) { SDValue Op0 = N->getOperand(0); - if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32) - return SDValue(); - - if (Op0.getValueType().getVectorElementType() != MVT::i32) + if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32 || + Op0.getValueType().getVectorElementType() != MVT::i32) return SDValue(); unsigned ExtOpcode = Op0.getOpcode(); + SDValue A = Op0; + SDValue B; + if (ExtOpcode == ISD::MUL) { + A = Op0.getOperand(0); + B = Op0.getOperand(1); + if (A.getOpcode() != B.getOpcode() || + A.getOperand(0).getValueType() != B.getOperand(0).getValueType()) + return SDValue(); + ExtOpcode = A.getOpcode(); + } if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND) return SDValue(); - EVT Op0VT = Op0.getOperand(0).getValueType(); + EVT Op0VT = A.getOperand(0).getValueType(); if (Op0VT != MVT::v8i8 && Op0VT != MVT::v16i8) return SDValue(); SDLoc DL(Op0); - SDValue Ones = DAG.getConstant(1, DL, Op0VT); + // Without a mul this is a plain extend reduction, so B is the constant 1. + // With a mul (MLA), B is replaced by the operand of its extend. + if (!B) + B = DAG.getConstant(1, DL, Op0VT); + else + B = B.getOperand(0); + SDValue Zeros = DAG.getConstant(0, DL, Op0VT == MVT::v8i8 ? MVT::v2i32 : MVT::v4i32); auto DotOpcode = (ExtOpcode == ISD::ZERO_EXTEND) ? 
AArch64ISD::UDOT : AArch64ISD::SDOT; SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, - Ones, Op0.getOperand(0)); + A.getOperand(0), B); return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot); } diff --git a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll --- a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll +++ b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll @@ -9,11 +9,10 @@ ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: ldr d0, [x0] ; CHECK-NEXT: ldr d1, [x1] -; CHECK-NEXT: dup v2.2s, wzr +; CHECK-NEXT: movi v2.2d, #0000000000000000 ; CHECK-NEXT: udot v2.2s, v1.8b, v0.8b ; CHECK-NEXT: addp v0.2s, v2.2s, v2.2s -; CHECK-NEXT: fmov x0, d0 -; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0 +; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret entry: %0 = bitcast i8* %a to <8 x i8>* @@ -33,7 +32,7 @@ ; CHECK-NEXT: ldr d0, [x0] ; CHECK-NEXT: movi v1.2d, #0000000000000000 ; CHECK-NEXT: movi v2.8b, #1 -; CHECK-NEXT: udot v1.2s, v2.8b, v0.8b +; CHECK-NEXT: udot v1.2s, v0.8b, v2.8b ; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret @@ -50,11 +49,10 @@ ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: ldr d0, [x0] ; CHECK-NEXT: ldr d1, [x1] -; CHECK-NEXT: dup v2.2s, wzr +; CHECK-NEXT: movi v2.2d, #0000000000000000 ; CHECK-NEXT: sdot v2.2s, v1.8b, v0.8b ; CHECK-NEXT: addp v0.2s, v2.2s, v2.2s -; CHECK-NEXT: fmov x0, d0 -; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0 +; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret entry: %0 = bitcast i8* %a to <8 x i8>* @@ -74,7 +72,7 @@ ; CHECK-NEXT: ldr d0, [x0] ; CHECK-NEXT: movi v1.2d, #0000000000000000 ; CHECK-NEXT: movi v2.8b, #1 -; CHECK-NEXT: sdot v1.2s, v2.8b, v0.8b +; CHECK-NEXT: sdot v1.2s, v0.8b, v2.8b ; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret @@ -92,7 +90,7 @@ ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: ldr q0, [x0] ; CHECK-NEXT: ldr q1, [x1] -; CHECK-NEXT: dup v2.4s, wzr +; CHECK-NEXT: movi v2.2d, 
#0000000000000000 ; CHECK-NEXT: udot v2.4s, v1.16b, v0.16b ; CHECK-NEXT: addv s0, v2.4s ; CHECK-NEXT: fmov w8, s0 @@ -117,7 +115,7 @@ ; CHECK-NEXT: ldr q0, [x0] ; CHECK-NEXT: movi v1.16b, #1 ; CHECK-NEXT: movi v2.2d, #0000000000000000 -; CHECK-NEXT: udot v2.4s, v1.16b, v0.16b +; CHECK-NEXT: udot v2.4s, v0.16b, v1.16b ; CHECK-NEXT: addv s0, v2.4s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret @@ -134,7 +132,7 @@ ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: ldr q0, [x0] ; CHECK-NEXT: ldr q1, [x1] -; CHECK-NEXT: dup v2.4s, wzr +; CHECK-NEXT: movi v2.2d, #0000000000000000 ; CHECK-NEXT: sdot v2.4s, v1.16b, v0.16b ; CHECK-NEXT: addv s0, v2.4s ; CHECK-NEXT: fmov w8, s0 @@ -159,7 +157,7 @@ ; CHECK-NEXT: ldr q0, [x0] ; CHECK-NEXT: movi v1.16b, #1 ; CHECK-NEXT: movi v2.2d, #0000000000000000 -; CHECK-NEXT: sdot v2.4s, v1.16b, v0.16b +; CHECK-NEXT: sdot v2.4s, v0.16b, v1.16b ; CHECK-NEXT: addv s0, v2.4s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret @@ -175,20 +173,10 @@ define i32 @test_udot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) { ; CHECK-LABEL: test_udot_v8i8_double: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: ushll v0.8h, v0.8b, #0 -; CHECK-NEXT: ushll v1.8h, v1.8b, #0 -; CHECK-NEXT: ushll v2.8h, v2.8b, #0 -; CHECK-NEXT: ushll v3.8h, v3.8b, #0 -; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8 -; CHECK-NEXT: ext v5.16b, v1.16b, v1.16b, #8 -; CHECK-NEXT: umull v0.4s, v0.4h, v1.4h -; CHECK-NEXT: ext v1.16b, v2.16b, v2.16b, #8 -; CHECK-NEXT: umull v2.4s, v2.4h, v3.4h -; CHECK-NEXT: ext v3.16b, v3.16b, v3.16b, #8 -; CHECK-NEXT: umlal v0.4s, v4.4h, v5.4h -; CHECK-NEXT: umlal v2.4s, v1.4h, v3.4h -; CHECK-NEXT: add v0.4s, v0.4s, v2.4s -; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: movi v4.2d, #0000000000000000 +; CHECK-NEXT: udot v4.2s, v2.8b, v3.8b +; CHECK-NEXT: udot v4.2s, v0.8b, v1.8b +; CHECK-NEXT: addp v0.2s, v4.2s, v4.2s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret entry: @@ -209,8 +197,8 @@ ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: movi v1.2d, #0000000000000000 
; CHECK-NEXT: movi v3.8b, #1 -; CHECK-NEXT: udot v1.2s, v3.8b, v2.8b -; CHECK-NEXT: udot v1.2s, v3.8b, v0.8b +; CHECK-NEXT: udot v1.2s, v2.8b, v3.8b +; CHECK-NEXT: udot v1.2s, v0.8b, v3.8b ; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret @@ -226,30 +214,10 @@ define i32 @test_udot_v16i8_double(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) { ; CHECK-LABEL: test_udot_v16i8_double: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: ushll2 v4.8h, v0.16b, #0 -; CHECK-NEXT: ushll v0.8h, v0.8b, #0 -; CHECK-NEXT: ushll2 v5.8h, v1.16b, #0 -; CHECK-NEXT: ushll v1.8h, v1.8b, #0 -; CHECK-NEXT: ext v6.16b, v4.16b, v4.16b, #8 -; CHECK-NEXT: ext v7.16b, v5.16b, v5.16b, #8 -; CHECK-NEXT: umull2 v16.4s, v0.8h, v1.8h -; CHECK-NEXT: umlal v16.4s, v6.4h, v7.4h -; CHECK-NEXT: ushll2 v6.8h, v2.16b, #0 -; CHECK-NEXT: ushll v2.8h, v2.8b, #0 -; CHECK-NEXT: ushll2 v7.8h, v3.16b, #0 -; CHECK-NEXT: ushll v3.8h, v3.8b, #0 -; CHECK-NEXT: umull v0.4s, v0.4h, v1.4h -; CHECK-NEXT: ext v1.16b, v6.16b, v6.16b, #8 -; CHECK-NEXT: umlal v0.4s, v4.4h, v5.4h -; CHECK-NEXT: ext v4.16b, v7.16b, v7.16b, #8 -; CHECK-NEXT: umull v5.4s, v2.4h, v3.4h -; CHECK-NEXT: umull2 v2.4s, v2.8h, v3.8h -; CHECK-NEXT: umlal v2.4s, v1.4h, v4.4h -; CHECK-NEXT: umlal v5.4s, v6.4h, v7.4h -; CHECK-NEXT: add v0.4s, v0.4s, v16.4s -; CHECK-NEXT: add v1.4s, v5.4s, v2.4s -; CHECK-NEXT: add v0.4s, v0.4s, v1.4s -; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: movi v4.2d, #0000000000000000 +; CHECK-NEXT: udot v4.4s, v2.16b, v3.16b +; CHECK-NEXT: udot v4.4s, v0.16b, v1.16b +; CHECK-NEXT: addv s0, v4.4s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret entry: @@ -270,8 +238,8 @@ ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: movi v1.16b, #1 ; CHECK-NEXT: movi v3.2d, #0000000000000000 -; CHECK-NEXT: udot v3.4s, v1.16b, v2.16b -; CHECK-NEXT: udot v3.4s, v1.16b, v0.16b +; CHECK-NEXT: udot v3.4s, v2.16b, v1.16b +; CHECK-NEXT: udot v3.4s, v0.16b, v1.16b ; CHECK-NEXT: addv s0, v3.4s ; CHECK-NEXT: fmov w0, s0 ; 
CHECK-NEXT: ret @@ -287,20 +255,10 @@ define i32 @test_sdot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) { ; CHECK-LABEL: test_sdot_v8i8_double: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: sshll v0.8h, v0.8b, #0 -; CHECK-NEXT: sshll v1.8h, v1.8b, #0 -; CHECK-NEXT: sshll v2.8h, v2.8b, #0 -; CHECK-NEXT: sshll v3.8h, v3.8b, #0 -; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8 -; CHECK-NEXT: ext v5.16b, v1.16b, v1.16b, #8 -; CHECK-NEXT: smull v0.4s, v0.4h, v1.4h -; CHECK-NEXT: ext v1.16b, v2.16b, v2.16b, #8 -; CHECK-NEXT: smull v2.4s, v2.4h, v3.4h -; CHECK-NEXT: ext v3.16b, v3.16b, v3.16b, #8 -; CHECK-NEXT: smlal v0.4s, v4.4h, v5.4h -; CHECK-NEXT: smlal v2.4s, v1.4h, v3.4h -; CHECK-NEXT: add v0.4s, v0.4s, v2.4s -; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: movi v4.2d, #0000000000000000 +; CHECK-NEXT: sdot v4.2s, v2.8b, v3.8b +; CHECK-NEXT: sdot v4.2s, v0.8b, v1.8b +; CHECK-NEXT: addp v0.2s, v4.2s, v4.2s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret entry: @@ -321,8 +279,8 @@ ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: movi v1.2d, #0000000000000000 ; CHECK-NEXT: movi v3.8b, #1 -; CHECK-NEXT: sdot v1.2s, v3.8b, v2.8b -; CHECK-NEXT: sdot v1.2s, v3.8b, v0.8b +; CHECK-NEXT: sdot v1.2s, v2.8b, v3.8b +; CHECK-NEXT: sdot v1.2s, v0.8b, v3.8b ; CHECK-NEXT: addp v0.2s, v1.2s, v1.2s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret @@ -338,30 +296,10 @@ define i32 @test_sdot_v16i8_double(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) { ; CHECK-LABEL: test_sdot_v16i8_double: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: sshll2 v4.8h, v0.16b, #0 -; CHECK-NEXT: sshll v0.8h, v0.8b, #0 -; CHECK-NEXT: sshll2 v5.8h, v1.16b, #0 -; CHECK-NEXT: sshll v1.8h, v1.8b, #0 -; CHECK-NEXT: ext v6.16b, v4.16b, v4.16b, #8 -; CHECK-NEXT: ext v7.16b, v5.16b, v5.16b, #8 -; CHECK-NEXT: smull2 v16.4s, v0.8h, v1.8h -; CHECK-NEXT: smlal v16.4s, v6.4h, v7.4h -; CHECK-NEXT: sshll2 v6.8h, v2.16b, #0 -; CHECK-NEXT: sshll v2.8h, v2.8b, #0 -; CHECK-NEXT: sshll2 v7.8h, v3.16b, #0 -; CHECK-NEXT: sshll 
v3.8h, v3.8b, #0 -; CHECK-NEXT: smull v0.4s, v0.4h, v1.4h -; CHECK-NEXT: ext v1.16b, v6.16b, v6.16b, #8 -; CHECK-NEXT: smlal v0.4s, v4.4h, v5.4h -; CHECK-NEXT: ext v4.16b, v7.16b, v7.16b, #8 -; CHECK-NEXT: smull v5.4s, v2.4h, v3.4h -; CHECK-NEXT: smull2 v2.4s, v2.8h, v3.8h -; CHECK-NEXT: smlal v2.4s, v1.4h, v4.4h -; CHECK-NEXT: smlal v5.4s, v6.4h, v7.4h -; CHECK-NEXT: add v0.4s, v0.4s, v16.4s -; CHECK-NEXT: add v1.4s, v5.4s, v2.4s -; CHECK-NEXT: add v0.4s, v0.4s, v1.4s -; CHECK-NEXT: addv s0, v0.4s +; CHECK-NEXT: movi v4.2d, #0000000000000000 +; CHECK-NEXT: sdot v4.4s, v2.16b, v3.16b +; CHECK-NEXT: sdot v4.4s, v0.16b, v1.16b +; CHECK-NEXT: addv s0, v4.4s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret entry: @@ -382,8 +320,8 @@ ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: movi v1.16b, #1 ; CHECK-NEXT: movi v3.2d, #0000000000000000 -; CHECK-NEXT: sdot v3.4s, v1.16b, v2.16b -; CHECK-NEXT: sdot v3.4s, v1.16b, v0.16b +; CHECK-NEXT: sdot v3.4s, v2.16b, v1.16b +; CHECK-NEXT: sdot v3.4s, v0.16b, v1.16b ; CHECK-NEXT: addv s0, v3.4s ; CHECK-NEXT: fmov w0, s0 ; CHECK-NEXT: ret