diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -175,6 +175,9 @@
   FMADD_XA,
   FMSUB,
   FNMSUB,
+
+  // X86 VNNI
+  DPWSSD,
 };
 
 } // end namespace llvm
diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -1223,6 +1223,16 @@
       SmallVectorImpl<MachineInstr *> &DelInstrs,
       DenseMap<unsigned, unsigned> &InstIdxForVirtReg) const;
 
+  /// When calculating the latency of the root instruction, accumulate the
+  /// latency of the sequence to the root latency.
+  /// \param Root - Instruction that could be combined with one of its operands
+  /// \returns true if the combiner should use the latencies of the inserted
+  /// and deleted instruction sequences; false if it should use only the
+  /// latency of the last inserted instruction and of \P Root.
+  virtual bool accumulateInstrSeqToRootLatency(MachineInstr &Root) const {
+    return true;
+  }
+
   /// Attempt to reassociate \P Root and \P Prev according to \P Pattern to
   /// reduce critical path length.
   void reassociateOps(MachineInstr &Root, MachineInstr &Prev,
diff --git a/llvm/lib/CodeGen/MachineCombiner.cpp b/llvm/lib/CodeGen/MachineCombiner.cpp
--- a/llvm/lib/CodeGen/MachineCombiner.cpp
+++ b/llvm/lib/CodeGen/MachineCombiner.cpp
@@ -91,7 +91,8 @@
 
 private:
   bool combineInstructions(MachineBasicBlock *);
-  MachineInstr *getOperandDef(const MachineOperand &MO);
+  MachineInstr *getOperandDef(const MachineOperand &MO,
+                              SmallVectorImpl<MachineInstr *> &InsInstrs);
   bool isTransientMI(const MachineInstr *MI);
   unsigned getDepth(SmallVectorImpl<MachineInstr *> &InsInstrs,
                     DenseMap<unsigned, unsigned> &InstrIdxForVirtReg,
@@ -149,11 +150,27 @@
   MachineFunctionPass::getAnalysisUsage(AU);
 }
 
-MachineInstr *MachineCombiner::getOperandDef(const MachineOperand &MO) {
+MachineInstr *
+MachineCombiner::getOperandDef(const MachineOperand &MO,
+                               SmallVectorImpl<MachineInstr *> &InsInstrs) {
   MachineInstr *DefInstr = nullptr;
   // We need a virtual register definition.
   if (MO.isReg() && MO.getReg().isVirtual())
     DefInstr = MRI->getUniqueVRegDef(MO.getReg());
+  // It is possible that the register is defined in new instructions. Only a
+  // virtual register operand can match one of their definitions, and
+  // MO.getReg() must not be queried on a non-register operand.
+  if (!DefInstr && MO.isReg() && MO.getReg().isVirtual()) {
+    for (MachineInstr *MI : InsInstrs) {
+      for (const MachineOperand &DefMO : MI->operands()) {
+        if (DefMO.isReg() && DefMO.isDef() && DefMO.getReg() == MO.getReg())
+          DefInstr = MI;
+      }
+    }
+  }
   // PHI's have no depth etc.
   if (DefInstr && DefInstr->isPHI())
     DefInstr = nullptr;
@@ -238,7 +255,7 @@
         LatencyOp = TSchedModel.computeOperandLatency(DefInstr, DefIdx,
                                                       InstrPtr, UseIdx);
       } else {
-        MachineInstr *DefInstr = getOperandDef(MO);
+        MachineInstr *DefInstr = getOperandDef(MO, InsInstrs);
         if (DefInstr && (TII->getMachineCombinerTraceStrategy() !=
                              MachineTraceStrategy::TS_Local ||
                          DefInstr->getParent() == &MBB)) {
@@ -404,8 +421,13 @@
 
   // Account for the latency of the inserted and deleted instructions by
   unsigned NewRootLatency, RootLatency;
-  std::tie(NewRootLatency, RootLatency) =
-      getLatenciesForInstrSequences(*Root, InsInstrs, DelInstrs, BlockTrace);
+  if (TII->accumulateInstrSeqToRootLatency(*Root)) {
+    std::tie(NewRootLatency, RootLatency) =
+        getLatenciesForInstrSequences(*Root, InsInstrs, DelInstrs, BlockTrace);
+  } else {
+    NewRootLatency = TSchedModel.computeInstrLatency(InsInstrs.back());
+    RootLatency = TSchedModel.computeInstrLatency(Root);
+  }
 
   unsigned RootSlack = BlockTrace.getInstrSlack(*Root);
   unsigned NewCycleCount = NewRootDepth + NewRootLatency;
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -602,6 +602,26 @@
   std::optional<DestSourcePair>
   isCopyInstrImpl(const MachineInstr &MI) const override;
 
+  bool
+  getMachineCombinerPatterns(MachineInstr &Root,
+                             SmallVectorImpl<MachineCombinerPattern> &Patterns,
+                             bool DoRegPressureReduce) const override;
+
+  void genAlternativeCodeSequence(
+      MachineInstr &Root, MachineCombinerPattern Pattern,
+      SmallVectorImpl<MachineInstr *> &InsInstrs,
+      SmallVectorImpl<MachineInstr *> &DelInstrs,
+      DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;
+
+  /// When calculating the latency of the root instruction, do not accumulate
+  /// the latency of the whole inserted sequence into the root latency.
+  /// \param Root - Instruction that could be combined with one of its operands
+  /// For the X86 split vpdpwssd -> (vpmaddwd + vpaddd), the vpmaddwd is not
+  /// in the critical path, so the new root latency only includes the vpaddd.
+  bool accumulateInstrSeqToRootLatency(MachineInstr &Root) const override {
+    return false;
+  }
+
 private:
   /// This is a helper for convertToThreeAddress for 8 and 16-bit instructions.
   /// We use 32-bit LEA to form 3-address code by promoting to a 32-bit
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -22,6 +22,7 @@
 #include "llvm/CodeGen/LiveIntervals.h"
 #include "llvm/CodeGen/LivePhysRegs.h"
 #include "llvm/CodeGen/LiveVariables.h"
+#include "llvm/CodeGen/MachineCombinerPattern.h"
 #include "llvm/CodeGen/MachineConstantPool.h"
 #include "llvm/CodeGen/MachineDominators.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
@@ -9749,5 +9750,84 @@
   return It;
 }
 
+/// Offer the DPWSSD pattern for ymm vpdpwssd roots; every other opcode is
+/// forwarded to the generic TargetInstrInfo implementation.
+bool X86InstrInfo::getMachineCombinerPatterns(
+    MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
+    bool DoRegPressureReduce) const {
+  switch (Root.getOpcode()) {
+  default:
+    return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
+                                                       DoRegPressureReduce);
+  case X86::VPDPWSSDYrr:
+  case X86::VPDPWSSDYrm:
+    Patterns.push_back(MachineCombinerPattern::DPWSSD);
+    return true;
+  }
+}
+
+/// Build the replacement sequence for \P Pattern. For DPWSSD, a vpdpwssd root
+/// is split into a vpmaddwd feeding a vpaddd; any other pattern is forwarded
+/// to the generic TargetInstrInfo implementation.
+void X86InstrInfo::genAlternativeCodeSequence(
+    MachineInstr &Root, MachineCombinerPattern Pattern,
+    SmallVectorImpl<MachineInstr *> &InsInstrs,
+    SmallVectorImpl<MachineInstr *> &DelInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
+  if (Pattern != MachineCombinerPattern::DPWSSD) {
+    // Reassociate instructions.
+    TargetInstrInfo::genAlternativeCodeSequence(Root, Pattern, InsInstrs,
+                                                DelInstrs, InstrIdxForVirtReg);
+    return;
+  }
+
+  // vpdpwssd ymm2,ymm3,YMMWORD PTR [r8+0x20]
+  // -->
+  // vpmaddwd ymm3,ymm3,YMMWORD PTR [r8+0x20]
+  // vpaddd ymm2,ymm2,ymm3
+  unsigned NewOpc;
+  switch (Root.getOpcode()) {
+  default:
+    // Not an opcode this pattern can split; add nothing to the sequences.
+    return;
+  case X86::VPDPWSSDYrr:
+    NewOpc = X86::VPMADDWDYrr;
+    break;
+  case X86::VPDPWSSDYrm:
+    NewOpc = X86::VPMADDWDYrm;
+    break;
+  }
+
+  MachineFunction *MF = Root.getMF();
+  MachineRegisterInfo &RegInfo = MF->getRegInfo();
+
+  // Create vpmaddwd: clone the root, switch its opcode and drop the tied
+  // operand 1 so that the clone writes a fresh virtual register.
+  const TargetRegisterClass *RC =
+      RegInfo.getRegClass(Root.getOperand(0).getReg());
+  Register NewReg = RegInfo.createVirtualRegister(RC);
+  MachineInstr *VpMadd = MF->CloneMachineInstr(&Root);
+  VpMadd->setDesc(get(NewOpc));
+  VpMadd->untieRegOperand(1);
+  VpMadd->removeOperand(1);
+  VpMadd->getOperand(0).setReg(NewReg);
+
+  // Create vpaddd: add the vpmaddwd result to operand 1 of the root and
+  // write the root's destination register.
+  Register DstReg = Root.getOperand(0).getReg();
+  bool IsKill = Root.getOperand(1).isKill();
+  MachineInstr *VpAdd =
+      BuildMI(*MF, MIMetadata(Root), get(X86::VPADDDYrr), DstReg)
+          .addReg(Root.getOperand(1).getReg(), getKillRegState(IsKill))
+          .addReg(NewReg, getKillRegState(true));
+  // NOTE(review): this records DstReg as defined by InsInstrs[0], yet VpMadd
+  // (index 0) defines NewReg and VpAdd (index 1) defines DstReg - confirm the
+  // intended mapping.
+  InstrIdxForVirtReg.insert(std::make_pair(DstReg, 0));
+  InsInstrs.push_back(VpMadd);
+  InsInstrs.push_back(VpAdd);
+  DelInstrs.push_back(&Root);
+}
+
 #define GET_INSTRINFO_HELPERS
 #include "X86GenInstrInfo.inc"
diff --git a/llvm/test/CodeGen/X86/avxvnni-combine.ll b/llvm/test/CodeGen/X86/avxvnni-combine.ll
--- a/llvm/test/CodeGen/X86/avxvnni-combine.ll
+++ b/llvm/test/CodeGen/X86/avxvnni-combine.ll
@@ -31,9 +31,12 @@
 ; CHECK-NEXT: .p2align 4, 0x90
 ; CHECK-NEXT: .LBB0_8: # =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT: {vex} vpdpwssd -96(%rdi), %ymm1, %ymm0
-; CHECK-NEXT: {vex} vpdpwssd -64(%rdi), %ymm1, %ymm0
-; CHECK-NEXT: {vex} vpdpwssd -32(%rdi), %ymm1, %ymm0
-; CHECK-NEXT: {vex} vpdpwssd (%rdi), %ymm1, %ymm0
+; CHECK-NEXT: vpmaddwd -64(%rdi), %ymm1, %ymm2
+; CHECK-NEXT: vpmaddwd -32(%rdi), %ymm1, %ymm3
+; CHECK-NEXT: vpaddd %ymm2, %ymm0, %ymm0
+; CHECK-NEXT: vpaddd %ymm3, %ymm0, %ymm0
+; CHECK-NEXT: vpmaddwd (%rdi), %ymm1, %ymm2
+; CHECK-NEXT: vpaddd %ymm2, %ymm0, %ymm0
 ; CHECK-NEXT: addq $4, %rcx
 ; CHECK-NEXT: subq $-128, %rdi
 ; CHECK-NEXT: cmpq %rcx, %rdx