diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -41799,6 +41799,40 @@
   return SDValue();
 }
 
+// (mul (zext a), (sext b))
+static bool detectExtMul(SelectionDAG &DAG, const SDValue &Mul, SDValue &Op0,
+                         SDValue &Op1) {
+  Op0 = Mul.getOperand(0);
+  Op1 = Mul.getOperand(1);
+
+  // Canonicalize so that the sign-extended operand is Op1.
+  if (Op0.getOpcode() == ISD::SIGN_EXTEND)
+    std::swap(Op0, Op1);
+
+  if (Op0.getOpcode() != ISD::ZERO_EXTEND)
+    return false;
+
+  auto IsFreeTruncation = [](SDValue &Op) -> bool {
+    if ((Op.getOpcode() == ISD::ZERO_EXTEND ||
+         Op.getOpcode() == ISD::SIGN_EXTEND) &&
+        Op.getOperand(0).getScalarValueSizeInBits() <= 8)
+      return true;
+
+    // TODO: Support constant value.
+    return false;
+  };
+
+  // (dpbusd (zext a), (sext b)). Since the first operand must be an unsigned
+  // value, check that Op0 is a zero-extended value. Op1 must be a signed
+  // value, so checking its number of significant bits is sufficient.
+  if ((IsFreeTruncation(Op0) &&
+       DAG.computeKnownBits(Op0).countMaxActiveBits() <= 8) &&
+      (IsFreeTruncation(Op1) && DAG.ComputeMaxSignificantBits(Op1) <= 8))
+    return true;
+
+  return false;
+}
+
 // Given a ABS node, detect the following pattern:
 // (ABS (SUB (ZERO_EXTEND a), (ZERO_EXTEND b))).
 // This is useful as it is the input into a SAD pattern.
@@ -41820,6 +41854,50 @@
   return true;
 }
 
+static SDValue createVPDPBUSD(SelectionDAG &DAG, SDValue LHS, SDValue RHS,
+                              unsigned &LogBias, const SDLoc &DL,
+                              const X86Subtarget &Subtarget) {
+  // Extend or truncate to MVT::i8 first.
+  MVT Vi8VT =
+      MVT::getVectorVT(MVT::i8, LHS.getValueType().getVectorElementCount());
+  LHS = DAG.getZExtOrTrunc(LHS, DL, Vi8VT);
+  RHS = DAG.getSExtOrTrunc(RHS, DL, Vi8VT);
+
+  // VPDPBUSD(<16 x i32> C, <16 x i8> A, <16 x i8> B). For each dst element
+  // C[0] = C[0] + A[0]*B[0] + A[1]*B[1] + A[2]*B[2] + A[3]*B[3].
+  // The src A, B element type is i8, but the dst C element type is i32.
+  // The reduction stage count is computed from the source vXi8 type, so a
+  // log bias of 2 is needed to skip the two stages vpdpbusd already performs.
+  LogBias = 2;
+
+  unsigned RegSize = std::max(128u, (unsigned)Vi8VT.getSizeInBits());
+  if (Subtarget.hasVNNI() && !Subtarget.hasVLX())
+    RegSize = std::max(512u, RegSize);
+
+  // "Zero-extend" the i8 vectors. This is not a per-element zext; rather, we
+  // fill in the missing vector elements with 0.
+  unsigned NumConcat = RegSize / Vi8VT.getSizeInBits();
+  SmallVector<SDValue> Ops(NumConcat, DAG.getConstant(0, DL, Vi8VT));
+  Ops[0] = LHS;
+  MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
+  SDValue DpOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+  Ops[0] = RHS;
+  SDValue DpOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+
+  // Actually build the dot product, split into 256/512-bit pieces for
+  // AVXVNNI/AVX512VNNI.
+  auto DpBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+                      ArrayRef<SDValue> Ops) {
+    MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
+    return DAG.getNode(X86ISD::VPDPBUSD, DL, VT, Ops);
+  };
+  MVT DpVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
+  SDValue Zero = DAG.getConstant(0, DL, DpVT);
+
+  return SplitOpsAndApply(DAG, Subtarget, DL, DpVT, {Zero, DpOp0, DpOp1},
+                          DpBuilder, false);
+}
+
 // Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs
 // to these zexts.
 static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
@@ -42069,6 +42147,77 @@
   return DAG.getNode(ISD::SUB, DL, ExtractVT, Zero, Zext);
 }
 
+static SDValue combineVPDPBUSDPattern(SDNode *Extract, SelectionDAG &DAG,
+                                      const X86Subtarget &Subtarget) {
+  if (!Subtarget.hasVNNI() && !Subtarget.hasAVXVNNI())
+    return SDValue();
+
+  EVT ExtractVT = Extract->getValueType(0);
+  // Verify the type we're extracting is i32, as the output element type of
+  // vpdpbusd is i32.
+  if (ExtractVT != MVT::i32)
+    return SDValue();
+
+  EVT VT = Extract->getOperand(0).getValueType();
+  if (!isPowerOf2_32(VT.getVectorNumElements()))
+    return SDValue();
+
+  // Match shuffle + add pyramid.
+  ISD::NodeType BinOp;
+  SDValue Root = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD});
+
+  // We can't combine to vpdpbusd for zext, because each of the 4 multiplies
+  // done by vpdpbusd computes a signed 16-bit product that is sign-extended
+  // before being added into the accumulator.
+  // TODO:
+  // We also need to verify that the multiply has at least 2x the number of bits
+  // of the input. We shouldn't match
+  // (sign_extend (mul (vXi9 (zext (vXi8 X))), (vXi9 (zext (vXi8 Y)))).
+  // if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND))
+  //   Root = Root.getOperand(0);
+
+  // If there was a match, we want Root to be a mul.
+  if (!Root || Root.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  // Check whether we have an extend-and-mul pattern.
+  SDValue LHS, RHS;
+  if (!detectExtMul(DAG, Root, LHS, RHS))
+    return SDValue();
+
+  // Create the dot product instruction.
+  SDLoc DL(Extract);
+  unsigned StageBias;
+  SDValue DP = createVPDPBUSD(DAG, LHS, RHS, StageBias, DL, Subtarget);
+
+  // If the original vector was wider than 4 elements, sum over the results
+  // in the DP vector.
+  unsigned Stages = Log2_32(VT.getVectorNumElements());
+  EVT DpVT = DP.getValueType();
+
+  if (Stages > StageBias) {
+    unsigned DpElems = DpVT.getVectorNumElements();
+
+    for (unsigned i = Stages - StageBias; i > 0; --i) {
+      SmallVector<int, 16> Mask(DpElems, -1);
+      for (unsigned j = 0, MaskEnd = 1 << (i - 1); j < MaskEnd; ++j)
+        Mask[j] = MaskEnd + j;
+
+      SDValue Shuffle =
+          DAG.getVectorShuffle(DpVT, DL, DP, DAG.getUNDEF(DpVT), Mask);
+      DP = DAG.getNode(ISD::ADD, DL, DpVT, DP, Shuffle);
+    }
+  }
+
+  // Return the lowest ExtractSizeInBits bits.
+  EVT ResVT =
+      EVT::getVectorVT(*DAG.getContext(), ExtractVT,
+                       DpVT.getSizeInBits() / ExtractVT.getSizeInBits());
+  DP = DAG.getBitcast(ResVT, DP);
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT, DP,
+                     Extract->getOperand(1));
+}
+
 static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
                                       const X86Subtarget &Subtarget) {
   // PSADBW is only supported on SSE2 and up.
@@ -42676,6 +42825,9 @@
   if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
     return SAD;
 
+  if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget))
+    return VPDPBUSD;
+
   // Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
   if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget))
     return Cmp;
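As a minimal sketch of the stage arithmetic used by createVPDPBUSD and combineVPDPBUSDPattern above (not part of the patch; the helper name below is invented for illustration): vpdpbusd folds every group of four u8*s8 products into one i32 lane, so the shuffle+add reduction that follows it can skip log2(4) = 2 stages compared with a reduction over the original element count.

// Illustration only; mirrors the Stages/StageBias computation above.
static unsigned exampleShuffleAddStages(unsigned NumMulElts) {
  unsigned Stages = 0;
  for (unsigned N = NumMulElts; N > 1; N /= 2) // log2 of the element count
    ++Stages;
  const unsigned StageBias = 2; // vpdpbusd already sums 4 products per lane
  return Stages > StageBias ? Stages - StageBias : 0;
}

For the 16-element reductions in the tests below this gives 4 - 2 = 2 shuffle+add pairs after vpdpbusd, which matches the generated assembly.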
diff --git a/llvm/lib/Target/X86/X86PartialReduction.cpp b/llvm/lib/Target/X86/X86PartialReduction.cpp
--- a/llvm/lib/Target/X86/X86PartialReduction.cpp
+++ b/llvm/lib/Target/X86/X86PartialReduction.cpp
@@ -13,15 +13,16 @@
 //===----------------------------------------------------------------------===//
 
 #include "X86.h"
+#include "X86TargetMachine.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicsX86.h"
-#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/Pass.h"
-#include "X86TargetMachine.h"
+#include "llvm/Support/KnownBits.h"
 
 using namespace llvm;
 
@@ -49,7 +50,7 @@
   }
 
 private:
-  bool tryMAddReplacement(Instruction *Op);
+  bool tryMAddReplacement(Instruction *Op, bool ReduceInOneBB);
   bool trySADReplacement(Instruction *Op);
 };
 }
@@ -63,7 +64,46 @@
 
 INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE,
                 "X86 Partial Reduction", false, false)
 
-bool X86PartialReduction::tryMAddReplacement(Instruction *Op) {
+// This function should be kept in sync with detectExtMul() in X86ISelLowering.cpp.
+static bool matchVPDPBUSDPattern(const X86Subtarget *ST, BinaryOperator *Mul,
+                                 const DataLayout *DL) {
+  if (!ST->hasVNNI() && !ST->hasAVXVNNI())
+    return false;
+
+  Value *LHS = Mul->getOperand(0);
+  Value *RHS = Mul->getOperand(1);
+
+  if (isa<SExtInst>(LHS))
+    std::swap(LHS, RHS);
+
+  if (!isa<ZExtInst>(LHS))
+    return false;
+
+  auto IsFreeTruncation = [&](Value *Op) {
+    if (auto *Cast = dyn_cast<CastInst>(Op)) {
+      if (Cast->getParent() == Mul->getParent() &&
+          (Cast->getOpcode() == Instruction::SExt ||
+           Cast->getOpcode() == Instruction::ZExt) &&
+          Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 8)
+        return true;
+    }
+    // TODO: Support constant in ISel.
+    return false;
+  };
+
+  // (dpbusd (zext a), (sext b)). Since the first operand must be an unsigned
+  // value, check that LHS is a zero-extended value. RHS must be a signed
+  // value, so checking its number of significant bits is sufficient.
+  if ((IsFreeTruncation(LHS) &&
+       computeKnownBits(LHS, *DL).countMaxActiveBits() <= 8) &&
+      (IsFreeTruncation(RHS) && ComputeMaxSignificantBits(RHS, *DL) <= 8))
+    return true;
+
+  return false;
+}
+
+bool X86PartialReduction::tryMAddReplacement(Instruction *Op,
+                                             bool ReduceInOneBB) {
   if (!ST->hasSSE2())
     return false;
 
@@ -82,6 +122,13 @@
   Value *LHS = Mul->getOperand(0);
   Value *RHS = Mul->getOperand(1);
 
+  // If the target supports VNNI, leave it to ISel to combine the reduce
+  // operation into a VNNI instruction.
+  // TODO: We could also transform the reduce into a VNNI intrinsic across
+  // basic blocks in this pass.
+  if (ReduceInOneBB && matchVPDPBUSDPattern(ST, Mul, DL))
+    return false;
+
   // LHS and RHS should be only used once or if they are the same then only
   // used twice. Only check this when SSE4.1 is enabled and we have zext/sext
   // instructions, otherwise we use punpck to emulate zero extend in stages. The
@@ -300,7 +347,9 @@
 // Walk backwards from the ExtractElementInst and determine if it is the end of
 // a horizontal reduction. Return the input to the reduction if we find one.
-static Value *matchAddReduction(const ExtractElementInst &EE) {
+static Value *matchAddReduction(const ExtractElementInst &EE,
+                                bool &ReduceInOneBB) {
+  ReduceInOneBB = true;
   // Make sure we're extracting index 0.
   auto *Index = dyn_cast<ConstantInt>(EE.getIndexOperand());
   if (!Index || !Index->isNullValue())
     return nullptr;
 
@@ -309,6 +358,8 @@
   const auto *BO = dyn_cast<BinaryOperator>(EE.getVectorOperand());
   if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse())
     return nullptr;
+  if (EE.getParent() != BO->getParent())
+    ReduceInOneBB = false;
 
   unsigned NumElems = cast<FixedVectorType>(BO->getType())->getNumElements();
   // Ensure the reduction size is a power of 2.
@@ -321,6 +372,8 @@
     const auto *BO = dyn_cast<BinaryOperator>(Op);
     if (!BO || BO->getOpcode() != Instruction::Add)
       return nullptr;
+    if (EE.getParent() != BO->getParent())
+      ReduceInOneBB = false;
 
     // If this isn't the first add, then it should only have 2 users, the
     // shuffle and another add which we checked in the previous iteration.
@@ -460,9 +513,10 @@
       if (!EE)
         continue;
 
+      bool ReduceInOneBB;
       // First find a reduction tree.
       // FIXME: Do we need to handle other opcodes than Add?
-      Value *Root = matchAddReduction(*EE);
+      Value *Root = matchAddReduction(*EE, ReduceInOneBB);
       if (!Root)
         continue;
 
@@ -470,7 +524,7 @@
       collectLeaves(Root, Leaves);
 
       for (Instruction *I : Leaves) {
-        if (tryMAddReplacement(I)) {
+        if (tryMAddReplacement(I, ReduceInOneBB)) {
          MadeChange = true;
          continue;
        }
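For reference, the IR in the tests below is the shape a vectorizer produces for a u8-by-s8 dot product; a minimal source-level sketch (the function name and types are illustrative only, not taken from the tests):

// Loops like this become reduce.add(mul(zext <N x i8>, sext <N x i8>)),
// which is exactly the pattern detectExtMul/matchVPDPBUSDPattern look for.
int dot_u8s8(const unsigned char *a, const signed char *b, int c, int n) {
  int sum = c;
  for (int i = 0; i < n; ++i)
    sum += int(a[i]) * int(b[i]);
  return sum;
}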
diff --git a/llvm/test/CodeGen/X86/dpbusd.ll b/llvm/test/CodeGen/X86/dpbusd.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/X86/dpbusd.ll
@@ -0,0 +1,548 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avxvnni | FileCheck %s --check-prefixes=AVXVNNI
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni | FileCheck %s --check-prefixes=AVX512,AVX512VNNI
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni -mattr=+avx512vl | FileCheck %s --check-prefixes=AVX512,AVX512VLVNNI
+
+define i32 @no_dpbusd(i8 *%a, i8 *%b, i32 %c, i32 %n) {
+; AVXVNNI-LABEL: no_dpbusd:
+; AVXVNNI: # %bb.0: # %entry
+; AVXVNNI-NEXT: vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVXVNNI-NEXT: vpmovzxbw {{.*#+}} ymm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVXVNNI-NEXT: vpmaddwd %ymm0, %ymm1, %ymm0
+; AVXVNNI-NEXT: vextracti128 $1, %ymm0, %xmm1
+; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
+; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
+; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0
+; AVXVNNI-NEXT: vmovd %xmm0, %eax
+; AVXVNNI-NEXT: addl %edx, %eax
+; AVXVNNI-NEXT: vzeroupper
+; AVXVNNI-NEXT: retq
+;
+; AVX512-LABEL: no_dpbusd:
+; AVX512: # %bb.0: # %entry
+; AVX512-NEXT: vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVX512-NEXT: vpmovzxbw {{.*#+}} ymm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+; AVX512-NEXT: vpmaddwd %ymm0, %ymm1, %ymm0
+; AVX512-NEXT: vextracti128 $1, %ymm0, %xmm1
+; AVX512-NEXT: vpaddd %ymm1, %ymm0, %ymm0
+; AVX512-NEXT: vpshufd {{.*#+}}
xmm1 = xmm0[2,3,2,3] +; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: vmovd %xmm0, %eax +; AVX512-NEXT: addl %edx, %eax +; AVX512-NEXT: vzeroupper +; AVX512-NEXT: retq +entry: + %0 = bitcast i8* %a to <16 x i8>* + %1 = load <16 x i8>, <16 x i8>* %0, align 16 + %2 = zext <16 x i8> %1 to <16 x i32> + %3 = bitcast i8* %b to <16 x i8>* + %4 = load <16 x i8>, <16 x i8>* %3, align 16 + %5 = zext <16 x i8> %4 to <16 x i32> + %6 = mul nsw <16 x i32> %5, %2 + %7 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %6) + %op.extra = add nsw i32 %7, %c + ret i32 %op.extra +} + +define i32 @vpdpbusd_mutate(i8 *%a, i8 *%b, i32 %c, i32 %n) { +; AVXVNNI-LABEL: vpdpbusd_mutate: +; AVXVNNI: # %bb.0: # %entry +; AVXVNNI-NEXT: vmovdqa (%rsi), %xmm0 +; AVXVNNI-NEXT: vpxor %xmm1, %xmm1, %xmm1 +; AVXVNNI-NEXT: {vex} vpdpbusd (%rdi), %xmm0, %xmm1 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm0 = xmm1[2,3,2,3] +; AVXVNNI-NEXT: vpaddd %xmm0, %xmm1, %xmm0 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vmovd %xmm0, %eax +; AVXVNNI-NEXT: addl %edx, %eax +; AVXVNNI-NEXT: retq +; +; AVX512VNNI-LABEL: vpdpbusd_mutate: +; AVX512VNNI: # %bb.0: # %entry +; AVX512VNNI-NEXT: vmovdqa (%rdi), %xmm0 +; AVX512VNNI-NEXT: vmovdqa (%rsi), %xmm1 +; AVX512VNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVX512VNNI-NEXT: vpdpbusd %zmm0, %zmm1, %zmm2 +; AVX512VNNI-NEXT: vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3] +; AVX512VNNI-NEXT: vpaddd %xmm0, %xmm2, %xmm0 +; AVX512VNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVX512VNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512VNNI-NEXT: vmovd %xmm0, %eax +; AVX512VNNI-NEXT: addl %edx, %eax +; AVX512VNNI-NEXT: vzeroupper +; AVX512VNNI-NEXT: retq +; +; AVX512VLVNNI-LABEL: vpdpbusd_mutate: +; AVX512VLVNNI: # %bb.0: # %entry +; AVX512VLVNNI-NEXT: vmovdqa (%rsi), %xmm0 +; AVX512VLVNNI-NEXT: vpxor %xmm1, %xmm1, %xmm1 +; AVX512VLVNNI-NEXT: vpdpbusd (%rdi), %xmm0, %xmm1 +; AVX512VLVNNI-NEXT: vpshufd {{.*#+}} xmm0 = xmm1[2,3,2,3] +; AVX512VLVNNI-NEXT: vpaddd %xmm0, %xmm1, %xmm0 +; AVX512VLVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVX512VLVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512VLVNNI-NEXT: vmovd %xmm0, %eax +; AVX512VLVNNI-NEXT: addl %edx, %eax +; AVX512VLVNNI-NEXT: retq +entry: + %0 = bitcast i8* %a to <16 x i8>* + %1 = load <16 x i8>, <16 x i8>* %0, align 16 + %2 = sext <16 x i8> %1 to <16 x i32> + %3 = bitcast i8* %b to <16 x i8>* + %4 = load <16 x i8>, <16 x i8>* %3, align 16 + %5 = zext <16 x i8> %4 to <16 x i32> + %6 = mul nsw <16 x i32> %5, %2 + %7 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %6) + %op.extra = add nsw i32 %7, %c + ret i32 %op.extra +} + +define i32 @mul_zext(i8 *%a, i8 *%b, i32 %c, i32 %n) { +; AVXVNNI-LABEL: mul_zext: +; AVXVNNI: # %bb.0: # %entry +; AVXVNNI-NEXT: vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero +; AVXVNNI-NEXT: vpmovsxbw (%rsi), %ymm1 +; AVXVNNI-NEXT: vpmullw %ymm0, %ymm1, %ymm0 +; AVXVNNI-NEXT: vextracti128 $1, %ymm0, %xmm1 +; AVXVNNI-NEXT: vpmovzxwd {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero +; AVXVNNI-NEXT: vpmovzxwd {{.*#+}} ymm0 = 
xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero +; AVXVNNI-NEXT: vpaddd %ymm1, %ymm0, %ymm0 +; AVXVNNI-NEXT: vextracti128 $1, %ymm0, %xmm1 +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vmovd %xmm0, %eax +; AVXVNNI-NEXT: addl %edx, %eax +; AVXVNNI-NEXT: vzeroupper +; AVXVNNI-NEXT: retq +; +; AVX512-LABEL: mul_zext: +; AVX512: # %bb.0: # %entry +; AVX512-NEXT: vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero +; AVX512-NEXT: vpmovsxbw (%rsi), %ymm1 +; AVX512-NEXT: vpmullw %ymm0, %ymm1, %ymm0 +; AVX512-NEXT: vpmovzxwd {{.*#+}} zmm0 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero,ymm0[8],zero,ymm0[9],zero,ymm0[10],zero,ymm0[11],zero,ymm0[12],zero,ymm0[13],zero,ymm0[14],zero,ymm0[15],zero +; AVX512-NEXT: vextracti64x4 $1, %zmm0, %ymm1 +; AVX512-NEXT: vpaddd %zmm1, %zmm0, %zmm0 +; AVX512-NEXT: vextracti128 $1, %ymm0, %xmm1 +; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] +; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: vmovd %xmm0, %eax +; AVX512-NEXT: addl %edx, %eax +; AVX512-NEXT: vzeroupper +; AVX512-NEXT: retq +entry: + %0 = bitcast i8* %a to <16 x i8>* + %1 = load <16 x i8>, <16 x i8>* %0, align 16 + %2 = zext <16 x i8> %1 to <16 x i16> + %3 = bitcast i8* %b to <16 x i8>* + %4 = load <16 x i8>, <16 x i8>* %3, align 16 + %5 = sext <16 x i8> %4 to <16 x i16> + %6 = mul nsw <16 x i16> %5, %2 + ; We can't combine to vpdpbusd for zext, because each of the 4 multiplies + ; done by vpdpbusd compute a signed 16-bit product that will be sign extended + ; before adding into the accumulator. 
+ %7 = zext <16 x i16> %6 to <16 x i32> + %8 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %7) + %op.extra = add nsw i32 %8, %c + ret i32 %op.extra +} + +define i32 @mul_sext(i8 *%a, i8 *%b, i32 %c, i32 %n) { +; AVXVNNI-LABEL: mul_sext: +; AVXVNNI: # %bb.0: # %entry +; AVXVNNI-NEXT: vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero +; AVXVNNI-NEXT: vpmovsxbw (%rsi), %ymm1 +; AVXVNNI-NEXT: vpmullw %ymm0, %ymm1, %ymm0 +; AVXVNNI-NEXT: vextracti128 $1, %ymm0, %xmm1 +; AVXVNNI-NEXT: vpmovsxwd %xmm1, %ymm1 +; AVXVNNI-NEXT: vpmovsxwd %xmm0, %ymm0 +; AVXVNNI-NEXT: vpaddd %ymm1, %ymm0, %ymm0 +; AVXVNNI-NEXT: vextracti128 $1, %ymm0, %xmm1 +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vmovd %xmm0, %eax +; AVXVNNI-NEXT: addl %edx, %eax +; AVXVNNI-NEXT: vzeroupper +; AVXVNNI-NEXT: retq +; +; AVX512-LABEL: mul_sext: +; AVX512: # %bb.0: # %entry +; AVX512-NEXT: vpmovzxbw {{.*#+}} ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero +; AVX512-NEXT: vpmovsxbw (%rsi), %ymm1 +; AVX512-NEXT: vpmullw %ymm0, %ymm1, %ymm0 +; AVX512-NEXT: vpmovsxwd %ymm0, %zmm0 +; AVX512-NEXT: vextracti64x4 $1, %zmm0, %ymm1 +; AVX512-NEXT: vpaddd %zmm1, %zmm0, %zmm0 +; AVX512-NEXT: vextracti128 $1, %ymm0, %xmm1 +; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] +; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: vmovd %xmm0, %eax +; AVX512-NEXT: addl %edx, %eax +; AVX512-NEXT: vzeroupper +; AVX512-NEXT: retq +entry: + %0 = bitcast i8* %a to <16 x i8>* + %1 = load <16 x i8>, <16 x i8>* %0, align 16 + %2 = zext <16 x i8> %1 to <16 x i16> + %3 = bitcast i8* %b to <16 x i8>* + %4 = load <16 x i8>, <16 x i8>* %3, align 16 + %5 = sext <16 x i8> %4 to <16 x i16> + %6 = mul nsw <16 x i16> %5, %2 + ; TODO: + ; We also need to verify that the multiply has at least 2x the number of bits + ; of the input. We shouldn't match + ; (sign_extend (mul (vXi9 (zext (vXi8 X))), (vXi9 (zext (vXi8 Y)))). 
+ %7 = sext <16 x i16> %6 to <16 x i32> + %8 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %7) + %op.extra = add nsw i32 %8, %c + ret i32 %op.extra +} + +define i32 @vpdpbusd_512(i8 *%a, i8 *%b, i32 %c, i32 %n) { +; AVXVNNI-LABEL: vpdpbusd_512: +; AVXVNNI: # %bb.0: # %entry +; AVXVNNI-NEXT: vmovdqa (%rdi), %xmm0 +; AVXVNNI-NEXT: vpxor %xmm1, %xmm1, %xmm1 +; AVXVNNI-NEXT: {vex} vpdpbusd (%rsi), %xmm0, %xmm1 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm0 = xmm1[2,3,2,3] +; AVXVNNI-NEXT: vpaddd %xmm0, %xmm1, %xmm0 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vmovd %xmm0, %eax +; AVXVNNI-NEXT: addl %edx, %eax +; AVXVNNI-NEXT: retq +; +; AVX512VNNI-LABEL: vpdpbusd_512: +; AVX512VNNI: # %bb.0: # %entry +; AVX512VNNI-NEXT: vmovdqa (%rdi), %xmm0 +; AVX512VNNI-NEXT: vmovdqa (%rsi), %xmm1 +; AVX512VNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVX512VNNI-NEXT: vpdpbusd %zmm1, %zmm0, %zmm2 +; AVX512VNNI-NEXT: vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3] +; AVX512VNNI-NEXT: vpaddd %xmm0, %xmm2, %xmm0 +; AVX512VNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVX512VNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512VNNI-NEXT: vmovd %xmm0, %eax +; AVX512VNNI-NEXT: addl %edx, %eax +; AVX512VNNI-NEXT: vzeroupper +; AVX512VNNI-NEXT: retq +; +; AVX512VLVNNI-LABEL: vpdpbusd_512: +; AVX512VLVNNI: # %bb.0: # %entry +; AVX512VLVNNI-NEXT: vmovdqa (%rdi), %xmm0 +; AVX512VLVNNI-NEXT: vpxor %xmm1, %xmm1, %xmm1 +; AVX512VLVNNI-NEXT: vpdpbusd (%rsi), %xmm0, %xmm1 +; AVX512VLVNNI-NEXT: vpshufd {{.*#+}} xmm0 = xmm1[2,3,2,3] +; AVX512VLVNNI-NEXT: vpaddd %xmm0, %xmm1, %xmm0 +; AVX512VLVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVX512VLVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512VLVNNI-NEXT: vmovd %xmm0, %eax +; AVX512VLVNNI-NEXT: addl %edx, %eax +; AVX512VLVNNI-NEXT: retq +entry: + %0 = bitcast i8* %a to <16 x i8>* + %1 = load <16 x i8>, <16 x i8>* %0, align 16 + %2 = zext <16 x i8> %1 to <16 x i32> + %3 = bitcast i8* %b to <16 x i8>* + %4 = load <16 x i8>, <16 x i8>* %3, align 16 + %5 = sext <16 x i8> %4 to <16 x i32> + %6 = mul nsw <16 x i32> %5, %2 + %7 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %6) + %op.extra = add nsw i32 %7, %c + ret i32 %op.extra +} + +declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>) + +define i32 @vpdpbusd_256(i8 *%a, i8 *%b, i32 %c, i32 %n) { +; AVXVNNI-LABEL: vpdpbusd_256: +; AVXVNNI: # %bb.0: # %entry +; AVXVNNI-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero +; AVXVNNI-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero +; AVXVNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVXVNNI-NEXT: {vex} vpdpbusd %xmm0, %xmm1, %xmm2 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm0 = xmm2[1,1,1,1] +; AVXVNNI-NEXT: vpaddd %xmm0, %xmm2, %xmm0 +; AVXVNNI-NEXT: vmovd %xmm0, %eax +; AVXVNNI-NEXT: addl %edx, %eax +; AVXVNNI-NEXT: retq +; +; AVX512VNNI-LABEL: vpdpbusd_256: +; AVX512VNNI: # %bb.0: # %entry +; AVX512VNNI-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero +; AVX512VNNI-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero +; AVX512VNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVX512VNNI-NEXT: vpdpbusd %zmm0, %zmm1, %zmm2 +; AVX512VNNI-NEXT: vpshufd {{.*#+}} xmm0 = xmm2[1,1,1,1] +; AVX512VNNI-NEXT: vpaddd %xmm0, %xmm2, %xmm0 +; AVX512VNNI-NEXT: vmovd %xmm0, %eax +; AVX512VNNI-NEXT: addl %edx, %eax +; AVX512VNNI-NEXT: vzeroupper +; AVX512VNNI-NEXT: retq +; +; AVX512VLVNNI-LABEL: vpdpbusd_256: +; AVX512VLVNNI: # %bb.0: # %entry +; AVX512VLVNNI-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero +; AVX512VLVNNI-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero +; AVX512VLVNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; 
AVX512VLVNNI-NEXT: vpdpbusd %xmm0, %xmm1, %xmm2 +; AVX512VLVNNI-NEXT: vpshufd {{.*#+}} xmm0 = xmm2[1,1,1,1] +; AVX512VLVNNI-NEXT: vpaddd %xmm0, %xmm2, %xmm0 +; AVX512VLVNNI-NEXT: vmovd %xmm0, %eax +; AVX512VLVNNI-NEXT: addl %edx, %eax +; AVX512VLVNNI-NEXT: retq +entry: + %0 = bitcast i8* %a to <8 x i8>* + %1 = load <8 x i8>, <8 x i8>* %0, align 8 + %2 = zext <8 x i8> %1 to <8 x i32> + %3 = bitcast i8* %b to <8 x i8>* + %4 = load <8 x i8>, <8 x i8>* %3, align 8 + %5 = sext <8 x i8> %4 to <8 x i32> + %6 = mul nsw <8 x i32> %5, %2 + %7 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %6) + %op.extra = add nsw i32 %7, %c + ret i32 %op.extra +} + +declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>) + +define i32 @vpdpbusd_128(i8 *%a, i8 *%b, i32 %c, i32 %n) { +; AVXVNNI-LABEL: vpdpbusd_128: +; AVXVNNI: # %bb.0: # %entry +; AVXVNNI-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero +; AVXVNNI-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero +; AVXVNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVXVNNI-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0,1],xmm2[2,3,4,5,6,7] +; AVXVNNI-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3,4,5,6,7] +; AVXVNNI-NEXT: {vex} vpdpbusd %xmm1, %xmm0, %xmm2 +; AVXVNNI-NEXT: vmovd %xmm2, %eax +; AVXVNNI-NEXT: addl %edx, %eax +; AVXVNNI-NEXT: retq +; +; AVX512VNNI-LABEL: vpdpbusd_128: +; AVX512VNNI: # %bb.0: # %entry +; AVX512VNNI-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero +; AVX512VNNI-NEXT: vpxor %xmm1, %xmm1, %xmm1 +; AVX512VNNI-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3,4,5,6,7] +; AVX512VNNI-NEXT: vmovq {{.*#+}} xmm2 = mem[0],zero +; AVX512VNNI-NEXT: vpblendw {{.*#+}} xmm1 = xmm2[0,1],xmm1[2,3,4,5,6,7] +; AVX512VNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVX512VNNI-NEXT: vpdpbusd %zmm0, %zmm1, %zmm2 +; AVX512VNNI-NEXT: vmovd %xmm2, %eax +; AVX512VNNI-NEXT: addl %edx, %eax +; AVX512VNNI-NEXT: vzeroupper +; AVX512VNNI-NEXT: retq +; +; AVX512VLVNNI-LABEL: vpdpbusd_128: +; AVX512VLVNNI: # %bb.0: # %entry +; AVX512VLVNNI-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero +; AVX512VLVNNI-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero +; AVX512VLVNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVX512VLVNNI-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0,1],xmm2[2,3,4,5,6,7] +; AVX512VLVNNI-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3,4,5,6,7] +; AVX512VLVNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVX512VLVNNI-NEXT: vpdpbusd %xmm1, %xmm0, %xmm2 +; AVX512VLVNNI-NEXT: vmovd %xmm2, %eax +; AVX512VLVNNI-NEXT: addl %edx, %eax +; AVX512VLVNNI-NEXT: retq +entry: + %0 = bitcast i8* %a to <4 x i8>* + %1 = load <4 x i8>, <4 x i8>* %0, align 8 + %2 = zext <4 x i8> %1 to <4 x i32> + %3 = bitcast i8* %b to <4 x i8>* + %4 = load <4 x i8>, <4 x i8>* %3, align 8 + %5 = sext <4 x i8> %4 to <4 x i32> + %6 = mul nsw <4 x i32> %5, %2 + %7 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %6) + %op.extra = add nsw i32 %7, %c + ret i32 %op.extra +} + +declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>) + +define i32 @vpdpbusd_2xi32(i8 *%a, i8 *%b, i32 %c, i32 %n) { +; AVXVNNI-LABEL: vpdpbusd_2xi32: +; AVXVNNI: # %bb.0: # %entry +; AVXVNNI-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero +; AVXVNNI-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero +; AVXVNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVXVNNI-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0],xmm2[1,2,3,4,5,6,7] +; AVXVNNI-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0],xmm2[1,2,3,4,5,6,7] +; AVXVNNI-NEXT: {vex} vpdpbusd %xmm1, %xmm0, %xmm2 +; AVXVNNI-NEXT: vmovd %xmm2, %eax +; AVXVNNI-NEXT: addl %edx, %eax +; AVXVNNI-NEXT: retq +; +; AVX512VNNI-LABEL: vpdpbusd_2xi32: +; AVX512VNNI: # %bb.0: # %entry +; AVX512VNNI-NEXT: vmovq {{.*#+}} xmm0 
= mem[0],zero +; AVX512VNNI-NEXT: vmovdqa {{.*#+}} xmm1 = [65535,0,0,0] +; AVX512VNNI-NEXT: vpandq %zmm1, %zmm0, %zmm0 +; AVX512VNNI-NEXT: vmovq {{.*#+}} xmm2 = mem[0],zero +; AVX512VNNI-NEXT: vpandq %zmm1, %zmm2, %zmm1 +; AVX512VNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVX512VNNI-NEXT: vpdpbusd %zmm0, %zmm1, %zmm2 +; AVX512VNNI-NEXT: vmovd %xmm2, %eax +; AVX512VNNI-NEXT: addl %edx, %eax +; AVX512VNNI-NEXT: vzeroupper +; AVX512VNNI-NEXT: retq +; +; AVX512VLVNNI-LABEL: vpdpbusd_2xi32: +; AVX512VLVNNI: # %bb.0: # %entry +; AVX512VLVNNI-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero +; AVX512VLVNNI-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero +; AVX512VLVNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVX512VLVNNI-NEXT: vpblendw {{.*#+}} xmm1 = xmm1[0],xmm2[1,2,3,4,5,6,7] +; AVX512VLVNNI-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0],xmm2[1,2,3,4,5,6,7] +; AVX512VLVNNI-NEXT: vpdpbusd %xmm1, %xmm0, %xmm2 +; AVX512VLVNNI-NEXT: vmovd %xmm2, %eax +; AVX512VLVNNI-NEXT: addl %edx, %eax +; AVX512VLVNNI-NEXT: retq +entry: + %0 = bitcast i8* %a to <2 x i8>* + %1 = load <2 x i8>, <2 x i8>* %0, align 8 + %2 = zext <2 x i8> %1 to <2 x i32> + %3 = bitcast i8* %b to <2 x i8>* + %4 = load <2 x i8>, <2 x i8>* %3, align 8 + %5 = sext <2 x i8> %4 to <2 x i32> + %6 = mul nsw <2 x i32> %5, %2 + %7 = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %6) + %op.extra = add nsw i32 %7, %c + ret i32 %op.extra +} + +declare i32 @llvm.vector.reduce.add.v2i32(<2 x i32>) + +define i32 @vpdpbusd_32xi32(i8 *%a, i8 *%b, i32 %c, i32 %n) { +; AVXVNNI-LABEL: vpdpbusd_32xi32: +; AVXVNNI: # %bb.0: # %entry +; AVXVNNI-NEXT: vmovdqu (%rdi), %ymm0 +; AVXVNNI-NEXT: vpxor %xmm1, %xmm1, %xmm1 +; AVXVNNI-NEXT: {vex} vpdpbusd (%rsi), %ymm0, %ymm1 +; AVXVNNI-NEXT: vextracti128 $1, %ymm1, %xmm0 +; AVXVNNI-NEXT: vpaddd %xmm0, %xmm1, %xmm0 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vmovd %xmm0, %eax +; AVXVNNI-NEXT: addl %edx, %eax +; AVXVNNI-NEXT: vzeroupper +; AVXVNNI-NEXT: retq +; +; AVX512VNNI-LABEL: vpdpbusd_32xi32: +; AVX512VNNI: # %bb.0: # %entry +; AVX512VNNI-NEXT: vmovdqu (%rdi), %ymm0 +; AVX512VNNI-NEXT: vmovdqu (%rsi), %ymm1 +; AVX512VNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVX512VNNI-NEXT: vpdpbusd %zmm1, %zmm0, %zmm2 +; AVX512VNNI-NEXT: vextracti128 $1, %ymm2, %xmm0 +; AVX512VNNI-NEXT: vpaddd %xmm0, %xmm2, %xmm0 +; AVX512VNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] +; AVX512VNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512VNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVX512VNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512VNNI-NEXT: vmovd %xmm0, %eax +; AVX512VNNI-NEXT: addl %edx, %eax +; AVX512VNNI-NEXT: vzeroupper +; AVX512VNNI-NEXT: retq +; +; AVX512VLVNNI-LABEL: vpdpbusd_32xi32: +; AVX512VLVNNI: # %bb.0: # %entry +; AVX512VLVNNI-NEXT: vmovdqu (%rdi), %ymm0 +; AVX512VLVNNI-NEXT: vpxor %xmm1, %xmm1, %xmm1 +; AVX512VLVNNI-NEXT: vpdpbusd (%rsi), %ymm0, %ymm1 +; AVX512VLVNNI-NEXT: vextracti128 $1, %ymm1, %xmm0 +; AVX512VLVNNI-NEXT: vpaddd %xmm0, %xmm1, %xmm0 +; AVX512VLVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] +; AVX512VLVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512VLVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVX512VLVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512VLVNNI-NEXT: vmovd %xmm0, %eax +; AVX512VLVNNI-NEXT: addl %edx, %eax +; AVX512VLVNNI-NEXT: vzeroupper +; AVX512VLVNNI-NEXT: retq +entry: + %0 = bitcast i8* %a to <32 x i8>* + %1 = load <32 x i8>, <32 x i8>* %0, 
align 16 + %2 = zext <32 x i8> %1 to <32 x i32> + %3 = bitcast i8* %b to <32 x i8>* + %4 = load <32 x i8>, <32 x i8>* %3, align 16 + %5 = sext <32 x i8> %4 to <32 x i32> + %6 = mul nsw <32 x i32> %5, %2 + %7 = call i32 @llvm.vector.reduce.add.v32i32(<32 x i32> %6) + %op.extra = add nsw i32 %7, %c + ret i32 %op.extra +} + +declare i32 @llvm.vector.reduce.add.v32i32(<32 x i32>) + +define i32 @vpdpbusd_64xi32(i8 *%a, i8 *%b, i32 %c, i32 %n) { +; AVXVNNI-LABEL: vpdpbusd_64xi32: +; AVXVNNI: # %bb.0: # %entry +; AVXVNNI-NEXT: vmovdqu (%rdi), %ymm0 +; AVXVNNI-NEXT: vmovdqu 32(%rdi), %ymm1 +; AVXVNNI-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; AVXVNNI-NEXT: vpxor %xmm3, %xmm3, %xmm3 +; AVXVNNI-NEXT: {vex} vpdpbusd 32(%rsi), %ymm1, %ymm3 +; AVXVNNI-NEXT: {vex} vpdpbusd (%rsi), %ymm0, %ymm2 +; AVXVNNI-NEXT: vpaddd %ymm3, %ymm2, %ymm0 +; AVXVNNI-NEXT: vextracti128 $1, %ymm0, %xmm1 +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVXVNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVXVNNI-NEXT: vmovd %xmm0, %eax +; AVXVNNI-NEXT: addl %edx, %eax +; AVXVNNI-NEXT: vzeroupper +; AVXVNNI-NEXT: retq +; +; AVX512-LABEL: vpdpbusd_64xi32: +; AVX512: # %bb.0: # %entry +; AVX512-NEXT: vmovdqu64 (%rdi), %zmm0 +; AVX512-NEXT: vpxor %xmm1, %xmm1, %xmm1 +; AVX512-NEXT: vpdpbusd (%rsi), %zmm0, %zmm1 +; AVX512-NEXT: vextracti64x4 $1, %zmm1, %ymm0 +; AVX512-NEXT: vpaddd %zmm0, %zmm1, %zmm0 +; AVX512-NEXT: vextracti128 $1, %ymm0, %xmm1 +; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] +; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; AVX512-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: vmovd %xmm0, %eax +; AVX512-NEXT: addl %edx, %eax +; AVX512-NEXT: vzeroupper +; AVX512-NEXT: retq +entry: + %0 = bitcast i8* %a to <64 x i8>* + %1 = load <64 x i8>, <64 x i8>* %0, align 16 + %2 = zext <64 x i8> %1 to <64 x i32> + %3 = bitcast i8* %b to <64 x i8>* + %4 = load <64 x i8>, <64 x i8>* %3, align 16 + %5 = sext <64 x i8> %4 to <64 x i32> + %6 = mul nsw <64 x i32> %5, %2 + %7 = call i32 @llvm.vector.reduce.add.v64i32(<64 x i32> %6) + %op.extra = add nsw i32 %7, %c + ret i32 %op.extra +} + +declare i32 @llvm.vector.reduce.add.v64i32(<64 x i32>) diff --git a/llvm/test/CodeGen/X86/dpbusd_i4.ll b/llvm/test/CodeGen/X86/dpbusd_i4.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/X86/dpbusd_i4.ll @@ -0,0 +1,131 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni -mattr=+avx512vl | FileCheck %s + +declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>) + +define i32 @mul_i8i8(i8 *%a, <16 x i8> %b, i32 %c) { +; CHECK-LABEL: mul_i8i8: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vmovdqa (%rdi), %xmm1 +; CHECK-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; CHECK-NEXT: vpdpbusd %xmm0, %xmm1, %xmm2 +; CHECK-NEXT: vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3] +; CHECK-NEXT: vpaddd %xmm0, %xmm2, %xmm0 +; CHECK-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; CHECK-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; CHECK-NEXT: vmovd %xmm0, %eax +; CHECK-NEXT: addl %esi, %eax +; CHECK-NEXT: retq +entry: + %0 = bitcast i8* %a to <16 x i8>* + %1 = load <16 x i8>, <16 x i8>* %0, align 16 + %2 = zext <16 x i8> %1 to <16 x i32> + %3 = sext <16 x i8> %b to <16 x i32> + %4 = mul nsw <16 x i32> %2, %3 + %5 = call i32 @llvm.vector.reduce.add.v16i32(<16 x 
i32> %4) + %op.extra = add nsw i32 %5, %c + ret i32 %op.extra +} + +define i32 @mul_i4i8(<16 x i4> %a, <16 x i8> %b, i32 %c) { +; CHECK-LABEL: mul_i4i8: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 +; CHECK-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; CHECK-NEXT: vpdpbusd %xmm1, %xmm0, %xmm2 +; CHECK-NEXT: vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3] +; CHECK-NEXT: vpaddd %xmm0, %xmm2, %xmm0 +; CHECK-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; CHECK-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; CHECK-NEXT: vmovd %xmm0, %eax +; CHECK-NEXT: addl %edi, %eax +; CHECK-NEXT: retq +entry: + %0 = zext <16 x i4> %a to <16 x i32> + %1 = sext <16 x i8> %b to <16 x i32> + %2 = mul nsw <16 x i32> %0, %1 + %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2) + %op.extra = add nsw i32 %3, %c + ret i32 %op.extra +} + +define i32 @mul_i4i4(<16 x i4> %a, <16 x i4> %b, i32 %c) { +; CHECK-LABEL: mul_i4i4: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vpsllw $4, %xmm1, %xmm1 +; CHECK-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1 +; CHECK-NEXT: vpsrlw $4, %xmm1, %xmm1 +; CHECK-NEXT: vmovdqa {{.*#+}} xmm2 = [8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8] +; CHECK-NEXT: vpxor %xmm2, %xmm1, %xmm1 +; CHECK-NEXT: vpsubb %xmm2, %xmm1, %xmm1 +; CHECK-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 +; CHECK-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; CHECK-NEXT: vpdpbusd %xmm1, %xmm0, %xmm2 +; CHECK-NEXT: vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3] +; CHECK-NEXT: vpaddd %xmm0, %xmm2, %xmm0 +; CHECK-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; CHECK-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; CHECK-NEXT: vmovd %xmm0, %eax +; CHECK-NEXT: addl %edi, %eax +; CHECK-NEXT: retq +entry: + %0 = zext <16 x i4> %a to <16 x i32> + %1 = sext <16 x i4> %b to <16 x i32> + %2 = mul nsw <16 x i32> %0, %1 + %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2) + %op.extra = add nsw i32 %3, %c + ret i32 %op.extra +} + +define i32 @mul_sext_i4i4(<16 x i4> %a, <16 x i4> %b, i32 %c) { +; CHECK-LABEL: mul_sext_i4i4: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vpmovzxbw {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero,xmm0[8],zero,xmm0[9],zero,xmm0[10],zero,xmm0[11],zero,xmm0[12],zero,xmm0[13],zero,xmm0[14],zero,xmm0[15],zero +; CHECK-NEXT: vpmovzxbw {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero,xmm1[4],zero,xmm1[5],zero,xmm1[6],zero,xmm1[7],zero,xmm1[8],zero,xmm1[9],zero,xmm1[10],zero,xmm1[11],zero,xmm1[12],zero,xmm1[13],zero,xmm1[14],zero,xmm1[15],zero +; CHECK-NEXT: vpsllw $12, %ymm1, %ymm1 +; CHECK-NEXT: vpsraw $12, %ymm1, %ymm1 +; CHECK-NEXT: vpsllw $12, %ymm0, %ymm0 +; CHECK-NEXT: vpsraw $12, %ymm0, %ymm0 +; CHECK-NEXT: vpmaddwd %ymm1, %ymm0, %ymm0 +; CHECK-NEXT: vextracti128 $1, %ymm0, %xmm1 +; CHECK-NEXT: vpaddd %ymm1, %ymm0, %ymm0 +; CHECK-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3] +; CHECK-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; CHECK-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; CHECK-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; CHECK-NEXT: vmovd %xmm0, %eax +; CHECK-NEXT: addl %edi, %eax +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: retq +entry: + %0 = sext <16 x i4> %a to <16 x i32> + %1 = sext <16 x i4> %b to <16 x i32> + %2 = mul nsw <16 x i32> %0, %1 + %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2) + %op.extra = add nsw i32 %3, %c + ret i32 %op.extra +} + +define i32 @mul_zext_i4i4(<16 x i4> %a, <16 x i4> %b, i32 %c) { +; CHECK-LABEL: mul_zext_i4i4: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vmovdqa {{.*#+}} xmm2 = 
[15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15] +; CHECK-NEXT: vpand %xmm2, %xmm1, %xmm1 +; CHECK-NEXT: vpand %xmm2, %xmm0, %xmm0 +; CHECK-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; CHECK-NEXT: vpdpbusd %xmm1, %xmm0, %xmm2 +; CHECK-NEXT: vpshufd {{.*#+}} xmm0 = xmm2[2,3,2,3] +; CHECK-NEXT: vpaddd %xmm0, %xmm2, %xmm0 +; CHECK-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1] +; CHECK-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; CHECK-NEXT: vmovd %xmm0, %eax +; CHECK-NEXT: addl %edi, %eax +; CHECK-NEXT: retq +entry: + %0 = zext <16 x i4> %a to <16 x i32> + %1 = zext <16 x i4> %b to <16 x i32> + %2 = mul nsw <16 x i32> %0, %1 + %3 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %2) + %op.extra = add nsw i32 %3, %c + ret i32 %op.extra +}
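As a closing sanity check on the comment in mul_zext above ("each of the 4 multiplies done by vpdpbusd compute a signed 16-bit product that will be sign extended before adding into the accumulator"), the ranges involved can be confirmed at compile time. A standalone sketch, not part of the tests:

// An unsigned-i8 by signed-i8 product always fits a signed 16-bit value,
// and four such products summed cannot overflow the i32 accumulator.
static_assert(255 * 127 <= 32767 && 255 * -128 >= -32768,
              "u8 * s8 fits in i16");
static_assert(4 * 255 * 127 <= 2147483647 && 4 * 255 * -128 >= -2147483647 - 1,
              "sum of four u8 * s8 products fits in i32");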