diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -218,6 +218,10 @@
   finalizeInsInstrs(MachineInstr &Root, MachineCombinerPattern &P,
                     SmallVectorImpl<MachineInstr *> &InsInstrs) const override;
 
+  bool shouldReduceRegisterPressure(
+      const MachineBasicBlock *MBB,
+      const RegisterClassInfo *RegClassInfo) const override;
+
   void genAlternativeCodeSequence(
       MachineInstr &Root, MachineCombinerPattern Pattern,
       SmallVectorImpl<MachineInstr *> &InsInstrs,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -26,6 +26,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/MachineTraceMetrics.h"
+#include "llvm/CodeGen/RegisterPressure.h"
 #include "llvm/CodeGen/RegisterScavenging.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/MC/MCInstBuilder.h"
@@ -34,6 +35,8 @@
 
 using namespace llvm;
 
+#define DEBUG_TYPE "riscv-instr-info"
+
 #define GEN_CHECK_COMPRESS_INSTR
 #include "RISCVGenCompressInstEmitter.inc"
 
@@ -1368,6 +1371,47 @@
   }
 }
 
+/// Return true if the maximum GPR pressure-set pressure anywhere in \p MBB
+/// exceeds that set's limit; other pressure sets are not considered.
+bool RISCVInstrInfo::shouldReduceRegisterPressure(
+    const MachineBasicBlock *MBB, const RegisterClassInfo *RegClassInfo) const {
+  const TargetRegisterInfo *TRI = STI.getRegisterInfo();
+  const MachineFunction *MF = MBB->getParent();
+  const MachineRegisterInfo *MRI = &MF->getRegInfo();
+
+  // Walk the block bottom-up with a RegPressureTracker to find the maximum
+  // pressure reached by each pressure set anywhere in the block.
+  RegionPressure Pressure;
+  RegPressureTracker RPTracker(Pressure);
+  RPTracker.init(MF, RegClassInfo, /*lis=*/nullptr, MBB, MBB->end(),
+                 /*TrackLaneMasks=*/false, /*TrackUntiedDefs=*/true);
+
+  for (const MachineInstr &MI : reverse(*MBB)) {
+    if (MI.isDebugOrPseudoInstr())
+      continue;
+    RegisterOperands RegOpers;
+    RegOpers.collect(MI, *TRI, *MRI, /*TrackLaneMasks=*/false,
+                     /*IgnoreDead=*/false);
+    RPTracker.recedeSkipDebugValues();
+    assert(&*RPTracker.getPos() == &MI && "RPTracker sync error!");
+    RPTracker.recede(RegOpers);
+  }
+
+  // Close the region to finalize the live-ins.
+  RPTracker.closeRegion();
+
+  // Read the block's pressure once and reuse it for both the debug output
+  // and the result, instead of re-walking the block for each use.
+  unsigned GPRPressure =
+      RPTracker.getPressure().MaxSetPressure[RISCV::RegisterPressureSets::GPR];
+  unsigned GPRLimit =
+      TRI->getRegPressureSetLimit(*MF, RISCV::RegisterPressureSets::GPR);
+
+  LLVM_DEBUG(dbgs() << "Register Pressure: " << GPRPressure << "::" << GPRLimit
+                    << "\n");
+  return GPRPressure > GPRLimit;
+}
+
 static bool isFADD(unsigned Opc) {
   switch (Opc) {
   default:
@@ -1563,6 +1607,11 @@
 
   if (getFPPatterns(Root, Patterns, DoRegPressureReduce))
     return true;
+
+  // In register-pressure-reduction mode only the FP fused-multiply patterns
+  // above are offered; the TargetInstrInfo patterns below are not queried.
+  if (DoRegPressureReduce)
+    return false;
 
   return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                      DoRegPressureReduce);