diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -770,6 +770,7 @@
   setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN);
   setTargetDAGCombine(ISD::INSERT_VECTOR_ELT);
   setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
+  setTargetDAGCombine(ISD::VECREDUCE_ADD);
 
   setTargetDAGCombine(ISD::GlobalAddress);
 
@@ -10940,6 +10941,40 @@
   return SDValue();
 }
 
+// Turn VECREDUCE_ADD(ZERO_EXTEND(v16i8)) into VECREDUCE_ADD(UDOT(0, 1, x))
+// so the i8->i32 widening reduction is done with a single UDOT instruction.
+static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+                                          const AArch64Subtarget *ST) {
+  if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32)
+    return SDValue();
+
+  // Check the opcode before touching operands: operand-less nodes (e.g.
+  // undef) would otherwise assert in getOperand(0).
+  SDValue Op0 = N->getOperand(0);
+  if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
+      Op0.getValueType().getVectorElementType() != MVT::i32)
+    return SDValue();
+
+  // The extend source must be exactly a fixed-width v16i8, the shape UDOT
+  // consumes (scalable nxv16i8 and narrower vectors fail this test).
+  EVT Op0VT = Op0.getOperand(0).getValueType();
+  if (Op0VT != MVT::v16i8)
+    return SDValue();
+
+  // udot zeros, x, ones accumulates the zero-extended lanes of x into a
+  // v4i32; reducing that preserves the original i32 sum. Emit the target
+  // intrinsic rather than a MachineSDNode so the result stays a normal
+  // pre-isel node that later combines and isel can process.
+  SDLoc DL(Op0);
+  SDValue Ones = DAG.getConstant(1, DL, Op0VT);
+  SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32);
+  SDValue Dot =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::v4i32,
+                  DAG.getConstant(Intrinsic::aarch64_neon_udot, DL, MVT::i32),
+                  Zeros, Op0.getOperand(0), Ones);
+  return DAG.getNode(ISD::VECREDUCE_ADD, SDLoc(N), N->getValueType(0), Dot);
+}
+
 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -14622,6 +14657,8 @@
     return performPostLD1Combine(N, DCI, true);
   case ISD::EXTRACT_VECTOR_ELT:
     return performExtractVectorEltCombine(N, DAG);
+  case ISD::VECREDUCE_ADD:
+    return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {
diff --git a/llvm/test/CodeGen/AArch64/neon-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-dot-product.ll
--- a/llvm/test/CodeGen/AArch64/neon-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-dot-product.ll
@@ -255,6 +255,17 @@
   ret i32 %op.extra
 }
 
+define i32 @test_udot_v16i8_2(i8* nocapture readonly %a1) {
+entry:
+; CHECK-LABEL: test_udot_v16i8_2:
+; CHECK: udot {{v[0-9]+}}.4s, {{v[0-9]+}}.16b, {{v[0-9]+}}.16b
+  %0 = bitcast i8* %a1 to <16 x i8>*
+  %1 = load <16 x i8>, <16 x i8>* %0
+  %2 = zext <16 x i8> %1 to <16 x i32>
+  %3 = call i32 @llvm.experimental.vector.reduce.add.v16i32(<16 x i32> %2)
+  ret i32 %3
+}
+
 define i32 @test_sdot_v16i8(i8* nocapture readonly %a, i8* nocapture readonly %b, i32 %sum) {
 entry:
 ; CHECK-LABEL: test_sdot_v16i8: