diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -790,6 +790,7 @@
   setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN);
   setTargetDAGCombine(ISD::INSERT_VECTOR_ELT);
   setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
+  setTargetDAGCombine(ISD::VECREDUCE_ADD);
 
   setTargetDAGCombine(ISD::GlobalAddress);
 
@@ -10989,6 +10990,44 @@
   return SDValue();
 }
 
+// Rewrite an i32 add-reduction of a zero- or sign-extended v16i8 vector as a
+// dot product against an all-ones v16i8 vector, accumulated into a zeroed
+// v4i32, followed by a reduction of the four i32 lanes:
+//   VECREDUCE_ADD( EXTEND(v16i8_type) ) to
+//   VECREDUCE_ADD( DOTv16i8(v16i8_type) )
+// ZERO_EXTEND maps to the udot intrinsic and SIGN_EXTEND to sdot. Returns an
+// empty SDValue when the subtarget has no dot-product support or the node
+// does not match this pattern.
+static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+                                          const AArch64Subtarget *ST) {
+  SDValue Op0 = N->getOperand(0);
+  if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32)
+    return SDValue();
+
+  if (Op0.getValueType().getVectorElementType() != MVT::i32)
+    return SDValue();
+
+  unsigned ExtOpcode = Op0.getOpcode();
+  if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
+    return SDValue();
+
+  EVT Op0VT = Op0.getOperand(0).getValueType();
+  if (Op0VT != MVT::v16i8)
+    return SDValue();
+
+  SDLoc DL(Op0);
+  // With an all-ones multiplicand the dot product just accumulates the input.
+  SDValue Ones = DAG.getConstant(1, DL, Op0VT);
+  SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32);
+  auto DotIntrinsic = (ExtOpcode == ISD::ZERO_EXTEND)
+                          ? Intrinsic::aarch64_neon_udot
+                          : Intrinsic::aarch64_neon_sdot;
+  SDValue Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Zeros.getValueType(),
+                            DAG.getConstant(DotIntrinsic, DL, MVT::i32), Zeros,
+                            Ones, Op0.getOperand(0));
+  return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
+}
+
 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -14671,6 +14710,8 @@
     return performPostLD1Combine(N, DCI, true);
   case ISD::EXTRACT_VECTOR_ELT:
     return performExtractVectorEltCombine(N, DAG);
+  case ISD::VECREDUCE_ADD:
+    return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {
diff --git a/llvm/test/CodeGen/AArch64/neon-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-dot-product.ll
--- a/llvm/test/CodeGen/AArch64/neon-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-dot-product.ll
@@ -255,6 +255,20 @@
   ret i32 %op.extra
 }
 
+define i32 @test_udot_v16i8_2(i8* nocapture readonly %a1) {
+; CHECK-LABEL: test_udot_v16i8_2:
+; CHECK: movi {{v[0-9]+}}.16b, #1
+; CHECK: movi {{v[0-9]+}}.2d, #0000000000000000
+; CHECK: udot {{v[0-9]+}}.4s, {{v[0-9]+}}.16b, {{v[0-9]+}}.16b
+; CHECK: addv s0, {{v[0-9]+}}.4s
+entry:
+  %0 = bitcast i8* %a1 to <16 x i8>*
+  %1 = load <16 x i8>, <16 x i8>* %0
+  %2 = zext <16 x i8> %1 to <16 x i32>
+  %3 = call i32 @llvm.experimental.vector.reduce.add.v16i32(<16 x i32> %2)
+  ret i32 %3
+}
+
 define i32 @test_sdot_v16i8(i8* nocapture readonly %a, i8* nocapture readonly %b, i32 %sum) {
 entry:
 ; CHECK-LABEL: test_sdot_v16i8:
@@ -270,3 +284,17 @@
   %op.extra = add nsw i32 %7, %sum
   ret i32 %op.extra
 }
+
+define i32 @test_sdot_v16i8_2(i8* nocapture readonly %a1) {
+; CHECK-LABEL: test_sdot_v16i8_2:
+; CHECK: movi {{v[0-9]+}}.16b, #1
+; CHECK: movi {{v[0-9]+}}.2d, #0000000000000000
+; CHECK: sdot {{v[0-9]+}}.4s, {{v[0-9]+}}.16b, {{v[0-9]+}}.16b
+; CHECK: addv s0, {{v[0-9]+}}.4s
+entry:
+  %0 = bitcast i8* %a1 to <16 x i8>*
+  %1 = load <16 x i8>, <16 x i8>* %0
+  %2 = sext <16 x i8> %1 to <16 x i32>
+  %3 = call i32 @llvm.experimental.vector.reduce.add.v16i32(<16 x i32> %2)
+  ret i32 %3
+}