diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -175,6 +175,9 @@
   FMADD_XA,
   FMSUB,
   FNMSUB,
+
+  // X86 VNNI: split VPDPWSSD into VPMADDWD + VPADDD.
+  DPWSSD,
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/CodeGen/MachineCombiner.cpp b/llvm/lib/CodeGen/MachineCombiner.cpp
--- a/llvm/lib/CodeGen/MachineCombiner.cpp
+++ b/llvm/lib/CodeGen/MachineCombiner.cpp
@@ -404,8 +404,16 @@
 
   // Account for the latency of the inserted and deleted instructions by
   unsigned NewRootLatency, RootLatency;
-  std::tie(NewRootLatency, RootLatency) =
-      getLatenciesForInstrSequences(*Root, InsInstrs, DelInstrs, BlockTrace);
+  if (Pattern == MachineCombinerPattern::DPWSSD) {
+    // X86 splits vpdpwssd into vpmaddwd + vpaddd. NewRootDepth already covers
+    // the vpmaddwd through its result register, so summing the sequence would
+    // count it twice; compare only the final vpaddd against Root.
+    NewRootLatency = TSchedModel.computeInstrLatency(InsInstrs.back());
+    RootLatency = TSchedModel.computeInstrLatency(Root);
+  } else {
+    std::tie(NewRootLatency, RootLatency) =
+        getLatenciesForInstrSequences(*Root, InsInstrs, DelInstrs, BlockTrace);
+  }
 
   unsigned RootSlack = BlockTrace.getInstrSlack(*Root);
   unsigned NewCycleCount = NewRootDepth + NewRootLatency;
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -602,6 +602,24 @@
   std::optional<DestSourcePair>
   isCopyInstrImpl(const MachineInstr &MI) const override;
 
+  /// Return true if at least one machine combiner pattern for \p Root was
+  /// appended to \p Patterns. Opcodes without an X86-specific pattern are
+  /// handled by TargetInstrInfo.
+  bool
+  getMachineCombinerPatterns(MachineInstr &Root,
+                             SmallVectorImpl<MachineCombinerPattern> &Patterns,
+                             bool DoRegPressureReduce) const override;
+
+  /// Build the replacement for \p Root under \p Pattern: new instructions go
+  /// to \p InsInstrs, replaced ones to \p DelInstrs, and each new virtual
+  /// register is mapped to the index of its defining instruction in
+  /// \p InsInstrs via \p InstrIdxForVirtReg.
+  void genAlternativeCodeSequence(
+      MachineInstr &Root, MachineCombinerPattern Pattern,
+      SmallVectorImpl<MachineInstr *> &InsInstrs,
+      SmallVectorImpl<MachineInstr *> &DelInstrs,
+      DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;
+
 private:
   /// This is a helper for convertToThreeAddress for 8 and 16-bit instructions.
   /// We use 32-bit LEA to form 3-address code by promoting to a 32-bit
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -22,6 +22,7 @@
 #include "llvm/CodeGen/LiveIntervals.h"
 #include "llvm/CodeGen/LivePhysRegs.h"
 #include "llvm/CodeGen/LiveVariables.h"
+#include "llvm/CodeGen/MachineCombinerPattern.h"
 #include "llvm/CodeGen/MachineConstantPool.h"
 #include "llvm/CodeGen/MachineDominators.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
@@ -9749,5 +9750,79 @@
   return It;
 }
 
+bool X86InstrInfo::getMachineCombinerPatterns(
+    MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
+    bool DoRegPressureReduce) const {
+  switch (Root.getOpcode()) {
+  case X86::VPDPWSSDYrr:
+  case X86::VPDPWSSDYrm:
+    Patterns.push_back(MachineCombinerPattern::DPWSSD);
+    return true;
+  default:
+    return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
+                                                       DoRegPressureReduce);
+  }
+}
+
+void X86InstrInfo::genAlternativeCodeSequence(
+    MachineInstr &Root, MachineCombinerPattern Pattern,
+    SmallVectorImpl<MachineInstr *> &InsInstrs,
+    SmallVectorImpl<MachineInstr *> &DelInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
+  if (Pattern != MachineCombinerPattern::DPWSSD) {
+    // Reassociation and the other target-independent patterns.
+    TargetInstrInfo::genAlternativeCodeSequence(Root, Pattern, InsInstrs,
+                                                DelInstrs, InstrIdxForVirtReg);
+    return;
+  }
+
+  // Split the dot product into a multiply-add and a separate accumulation so
+  // the multiply is taken off the accumulator's dependency chain:
+  //   vpdpwssd ymm2, ymm3, YMMWORD PTR [r8+0x20]
+  //   -->
+  //   vpmaddwd ymm4, ymm3, YMMWORD PTR [r8+0x20]
+  //   vpaddd   ymm2, ymm2, ymm4
+  unsigned MaddOpc;
+  switch (Root.getOpcode()) {
+  case X86::VPDPWSSDYrr:
+    MaddOpc = X86::VPMADDWDYrr;
+    break;
+  case X86::VPDPWSSDYrm:
+    MaddOpc = X86::VPMADDWDYrm;
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode for the DPWSSD combiner pattern");
+  }
+
+  MachineFunction *MF = Root.getMF();
+  MachineRegisterInfo &RegInfo = MF->getRegInfo();
+  Register DstReg = Root.getOperand(0).getReg();
+
+  // Create the vpmaddwd: clone Root so its source operands carry over, then
+  // drop the tied accumulator (operand 1) and define a fresh virtual register.
+  const TargetRegisterClass *RC = RegInfo.getRegClass(DstReg);
+  Register NewReg = RegInfo.createVirtualRegister(RC);
+  MachineInstr *Madd = MF->CloneMachineInstr(&Root);
+  Madd->setDesc(get(MaddOpc));
+  Madd->untieRegOperand(1);
+  Madd->removeOperand(1);
+  Madd->getOperand(0).setReg(NewReg);
+  // NewReg is defined only by Madd, which is not inserted in any block yet,
+  // so the combiner has to find that definition through InstrIdxForVirtReg.
+  // Madd is the first entry pushed to InsInstrs, hence index 0.
+  InstrIdxForVirtReg.insert(std::make_pair(NewReg, 0));
+
+  // Create the vpaddd that accumulates into Root's destination register.
+  const MachineOperand &AccMO = Root.getOperand(1);
+  MachineInstr *Add =
+      BuildMI(*MF, MIMetadata(Root), get(X86::VPADDDYrr), DstReg)
+          .addReg(AccMO.getReg(), getKillRegState(AccMO.isKill()))
+          .addReg(NewReg, RegState::Kill);
+
+  InsInstrs.push_back(Madd);
+  InsInstrs.push_back(Add);
+  DelInstrs.push_back(&Root);
+}
+
 #define GET_INSTRINFO_HELPERS
 #include "X86GenInstrInfo.inc"
diff --git a/llvm/test/CodeGen/X86/avxvnni-combine.ll b/llvm/test/CodeGen/X86/avxvnni-combine.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/X86/avxvnni-combine.ll
@@ -0,0 +1,133 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=alderlake | FileCheck %s
+
+; __m256i foo(int cnt, __m256i c, __m256i b, __m256i *p) {
+;   for (int i = 0; i < cnt; ++i) {
+;     __m256i a = p[i];
+;     __m256i m = _mm256_madd_epi16 (b, a);
+;     c = _mm256_add_epi32(m, c);
+;   }
+;   return c;
+; }
+
+define dso_local <4 x i64> @foo(i32 %0, <4 x i64> %1, <4 x i64> %2, ptr %3) {
+; CHECK-LABEL: foo:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    testl %edi, %edi
+; CHECK-NEXT:    jle .LBB0_6
+; CHECK-NEXT:  # %bb.1:
+; CHECK-NEXT:    movl %edi, %edx
+; CHECK-NEXT:    movl %edx, %eax
+; CHECK-NEXT:    andl $3, %eax
+; CHECK-NEXT:    cmpl $4, %edi
+; CHECK-NEXT:    jae .LBB0_7
+; CHECK-NEXT:  # %bb.2:
+; CHECK-NEXT:    xorl %ecx, %ecx
+; CHECK-NEXT:    jmp .LBB0_3
+; CHECK-NEXT:  .LBB0_7:
+; CHECK-NEXT:    andl $-4, %edx
+; CHECK-NEXT:    leaq 96(%rsi), %rdi
+; CHECK-NEXT:    xorl %ecx, %ecx
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  .LBB0_8: # =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    {vex} vpdpwssd -96(%rdi), %ymm1, %ymm0
+; CHECK-NEXT:    vpmaddwd -64(%rdi), %ymm1, %ymm2
+; CHECK-NEXT:    vpmaddwd -32(%rdi), %ymm1, %ymm3
+; CHECK-NEXT:    vpaddd %ymm2, %ymm0, %ymm0
+; CHECK-NEXT:    vpaddd %ymm3, %ymm0, %ymm0
+; CHECK-NEXT:    vpmaddwd (%rdi), %ymm1, %ymm2
+; CHECK-NEXT:    vpaddd %ymm2, %ymm0, %ymm0
+; CHECK-NEXT:    addq $4, %rcx
+; CHECK-NEXT:    subq $-128, %rdi
+; CHECK-NEXT:    cmpq %rcx, %rdx
+; CHECK-NEXT:    jne .LBB0_8
+; CHECK-NEXT:  .LBB0_3:
+; CHECK-NEXT:    testq %rax, %rax
+; CHECK-NEXT:    je .LBB0_6
+; CHECK-NEXT:  # %bb.4: # %.preheader
+; CHECK-NEXT:    shlq $5, %rcx
+; CHECK-NEXT:    addq %rcx, %rsi
+; CHECK-NEXT:    shlq $5, %rax
+; CHECK-NEXT:    xorl %ecx, %ecx
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  .LBB0_5: # =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    {vex} vpdpwssd (%rsi,%rcx), %ymm1, %ymm0
+; CHECK-NEXT:    addq $32, %rcx
+; CHECK-NEXT:    cmpq %rcx, %rax
+; CHECK-NEXT:    jne .LBB0_5
+; CHECK-NEXT:  .LBB0_6:
+; CHECK-NEXT:    retq
+  %5 = icmp sgt i32 %0, 0
+  br i1 %5, label %6, label %33
+
+6:                                                ; preds = %4
+  %7 = bitcast <4 x i64> %2 to <16 x i16>
+  %8 = bitcast <4 x i64> %1 to <8 x i32>
+  %9 = zext i32 %0 to i64
+  %10 = and i64 %9, 3
+  %11 = icmp ult i32 %0, 4
+  br i1 %11, label %14, label %12
+
+12:                                               ; preds = %6
+  %13 = and i64 %9, 4294967292
+  br label %35
+
+14:                                               ; preds = %35, %6
+  %15 = phi <8 x i32> [ undef, %6 ], [ %57, %35 ]
+  %16 = phi i64 [ 0, %6 ], [ %58, %35 ]
+  %17 = phi <8 x i32> [ %8, %6 ], [ %57, %35 ]
+  %18 = icmp eq i64 %10, 0
+  br i1 %18, label %30, label %19
+
+19:                                               ; preds = %14, %19
+  %20 = phi i64 [ %27, %19 ], [ %16, %14 ]
+  %21 = phi <8 x i32> [ %26, %19 ], [ %17, %14 ]
+  %22 = phi i64 [ %28, %19 ], [ 0, %14 ]
+  %23 = getelementptr inbounds <4 x i64>, ptr %3, i64 %20
+  %24 = load <16 x i16>, ptr %23, align 32
+  %25 = tail call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %7, <16 x i16> %24)
+  %26 = add <8 x i32> %25, %21
+  %27 = add nuw nsw i64 %20, 1
+  %28 = add i64 %22, 1
+  %29 = icmp eq i64 %28, %10
+  br i1 %29, label %30, label %19
+
+30:                                               ; preds = %19, %14
+  %31 = phi <8 x i32> [ %15, %14 ], [ %26, %19 ]
+  %32 = bitcast <8 x i32> %31 to <4 x i64>
+  br label %33
+
+33:                                               ; preds = %30, %4
+  %34 = phi <4 x i64> [ %32, %30 ], [ %1, %4 ]
+  ret <4 x i64> %34
+
+35:                                               ; preds = %35, %12
+  %36 = phi i64 [ 0, %12 ], [ %58, %35 ]
+  %37 = phi <8 x i32> [ %8, %12 ], [ %57, %35 ]
+  %38 = phi i64 [ 0, %12 ], [ %59, %35 ]
+  %39 = getelementptr inbounds <4 x i64>, ptr %3, i64 %36
+  %40 = load <16 x i16>, ptr %39, align 32
+  %41 = tail call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %7, <16 x i16> %40)
+  %42 = add <8 x i32> %41, %37
+  %43 = or i64 %36, 1
+  %44 = getelementptr inbounds <4 x i64>, ptr %3, i64 %43
+  %45 = load <16 x i16>, ptr %44, align 32
+  %46 = tail call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %7, <16 x i16> %45)
+  %47 = add <8 x i32> %46, %42
+  %48 = or i64 %36, 2
+  %49 = getelementptr inbounds <4 x i64>, ptr %3, i64 %48
+  %50 = load <16 x i16>, ptr %49, align 32
+  %51 = tail call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %7, <16 x i16> %50)
+  %52 = add <8 x i32> %51, %47
+  %53 = or i64 %36, 3
+  %54 = getelementptr inbounds <4 x i64>, ptr %3, i64 %53
+  %55 = load <16 x i16>, ptr %54, align 32
+  %56 = tail call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %7, <16 x i16> %55)
+  %57 = add <8 x i32> %56, %52
+  %58 = add nuw nsw i64 %36, 4
+  %59 = add i64 %38, 4
+  %60 = icmp eq i64 %59, %13
+  br i1 %60, label %14, label %35
+}
+
+declare <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>)