Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -770,6 +770,7 @@
   setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN);
   setTargetDAGCombine(ISD::INSERT_VECTOR_ELT);
   setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
+  setTargetDAGCombine(ISD::VECREDUCE_ADD);
 
   setTargetDAGCombine(ISD::GlobalAddress);
 
@@ -10940,6 +10941,50 @@
   return SDValue();
 }
 
+// Turn VECREDUCE_ADD(EXTEND(v16i8)) into VECREDUCE_ADD(DOTv16i8(v16i8)):
+// summing the widened lanes equals a dot product of the narrow vector with
+// an all-ones vector accumulated into zero, i.e. a single UDOT/SDOT.
+static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+                                          const AArch64Subtarget *ST) {
+  SDValue Op0 = N->getOperand(0);
+  if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32) {
+    return SDValue();
+  }
+
+  EVT VT = Op0.getValueType();
+  if (VT.isScalableVector() || VT.getVectorElementType() != MVT::i32) {
+    return SDValue();
+  }
+
+  EVT Op0VT = Op0.getOperand(0).getValueType();
+  unsigned ExtOpcode = Op0.getOpcode();
+  if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND) {
+    return SDValue();
+  }
+
+  // Only fixed-width v16i8 inputs map onto the [US]DOTv16i8 instruction.
+  if (Op0VT.isScalableVector() || Op0VT.getVectorElementType() != MVT::i8 ||
+      Op0VT.getVectorNumElements() != 16) {
+    return SDValue();
+  }
+
+  SDValue Ones =
+      DAG.getConstant(1, SDLoc(Op0), Op0.getOperand(0).getValueType());
+  SDValue Zeros = DAG.getConstant(
+      0, SDLoc(Op0), EVT::getVectorVT(*DAG.getContext(), MVT::i32, 4));
+
+  // NOTE(review): emitting a target machine node from a DAG combine is
+  // unusual; consider a dedicated AArch64ISD dot-product node instead.
+  auto DotOpcode =
+      (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64::UDOTv16i8 : AArch64::SDOTv16i8;
+  MachineSDNode *Dot =
+      DAG.getMachineNode(DotOpcode, SDLoc(Op0), Zeros.getValueType(), Zeros,
+                         Ones, Op0.getOperand(0));
+  // Returning the replacement value is sufficient: the DAGCombiner replaces
+  // all uses of N with the returned node, so no explicit RAUW is needed.
+  return DAG.getNode(ISD::VECREDUCE_ADD, SDLoc(N), N->getValueType(0),
+                     SDValue(Dot, 0));
+}
+
 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -14622,6 +14667,8 @@
     return performPostLD1Combine(N, DCI, true);
   case ISD::EXTRACT_VECTOR_ELT:
     return performExtractVectorEltCombine(N, DAG);
+  case ISD::VECREDUCE_ADD:
+    return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {
Index: llvm/test/CodeGen/AArch64/neon-dot-product.ll
===================================================================
--- llvm/test/CodeGen/AArch64/neon-dot-product.ll
+++ llvm/test/CodeGen/AArch64/neon-dot-product.ll
@@ -255,6 +255,17 @@
   ret i32 %op.extra
 }
 
+define i32 @test_udot_v16i8_2(i8* nocapture readonly %a1) {
+entry:
+; CHECK-LABEL: test_udot_v16i8_2:
+; CHECK: udot {{v[0-9]+}}.4s, {{v[0-9]+}}.16b, {{v[0-9]+}}.16b
+  %0 = bitcast i8* %a1 to <16 x i8>*
+  %1 = load <16 x i8>, <16 x i8>* %0
+  %2 = zext <16 x i8> %1 to <16 x i32>
+  %3 = call i32 @llvm.experimental.vector.reduce.add.v16i32(<16 x i32> %2)
+  ret i32 %3
+}
+
 define i32 @test_sdot_v16i8(i8* nocapture readonly %a, i8* nocapture readonly %b, i32 %sum) {
 entry:
 ; CHECK-LABEL: test_sdot_v16i8:
@@ -270,3 +281,14 @@
   %op.extra = add nsw i32 %7, %sum
   ret i32 %op.extra
 }
+
+define i32 @test_sdot_v16i8_2(i8* nocapture readonly %a1) {
+entry:
+; CHECK-LABEL: test_sdot_v16i8_2:
+; CHECK: sdot {{v[0-9]+}}.4s, {{v[0-9]+}}.16b, {{v[0-9]+}}.16b
+  %0 = bitcast i8* %a1 to <16 x i8>*
+  %1 = load <16 x i8>, <16 x i8>* %0
+  %2 = sext <16 x i8> %1 to <16 x i32>
+  %3 = call i32 @llvm.experimental.vector.reduce.add.v16i32(<16 x i32> %2)
+  ret i32 %3
+}