Patch summary (text before the first "diff --git" header is not part of any hunk):

  * TargetInstrInfo::hasReassociableSibling is declared virtual.
  * RISCVInstrInfo overrides hasReassociableSibling: it first defers to the
    TargetInstrInfo implementation, then additionally requires
    RISCV::hasEqualFRM between Inst and the instruction defining operand 2
    (when Commuted is set) or operand 1 (otherwise).
  * RISCVInstrInfo overrides isAssociativeAndCommutative: true only for
    opcodes accepted by isFADD/isFMUL that carry both the FmReassoc and
    FmNsz flags.
  * The file-local helpers isAssociativeAndCommutativeFPOpcode,
    canReassociate and getFPReassocPatterns are removed; getFPPatterns now
    returns only the result of getFPFusedMultiplyPatterns.

diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -1180,7 +1180,8 @@
                                        const MachineBasicBlock *MBB) const;
 
   /// Return true when \P Inst has reassociable sibling.
-  bool hasReassociableSibling(const MachineInstr &Inst, bool &Commuted) const;
+  virtual bool hasReassociableSibling(const MachineInstr &Inst,
+                                      bool &Commuted) const;
 
   /// When getMachineCombinerPatterns() finds patterns, this function generates
   /// the instructions that could replace the original code sequence. The client
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -202,6 +202,11 @@
       SmallVectorImpl<MachineInstr *> &DelInstrs,
       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;
 
+  bool hasReassociableSibling(const MachineInstr &Inst,
+                              bool &Commuted) const override;
+
+  bool isAssociativeAndCommutative(const MachineInstr &Inst) const override;
+
 protected:
   const RISCVSubtarget &STI;
 };
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1199,28 +1199,25 @@
   }
 }
 
-static bool isAssociativeAndCommutativeFPOpcode(unsigned Opc) {
-  return isFADD(Opc) || isFMUL(Opc);
-}
-
-static bool canReassociate(MachineInstr &Root, MachineOperand &MO) {
-  if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
-    return false;
-  MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
-  MachineInstr *MI = MRI.getVRegDef(MO.getReg());
-  if (!MI || !MRI.hasOneNonDBGUse(MO.getReg()))
+bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
+                                            bool &Commuted) const {
+  if (!TargetInstrInfo::hasReassociableSibling(Inst, Commuted))
     return false;
 
-  if (MI->getOpcode() != Root.getOpcode())
-    return false;
+  const MachineRegisterInfo &MRI = Inst.getMF()->getRegInfo();
+  unsigned OperandIdx = Commuted ? 2 : 1;
+  MachineInstr &Sibling = *MRI.getVRegDef(Inst.getOperand(OperandIdx).getReg());
 
-  if (!Root.getFlag(MachineInstr::MIFlag::FmReassoc) ||
-      !Root.getFlag(MachineInstr::MIFlag::FmNsz) ||
-      !MI->getFlag(MachineInstr::MIFlag::FmReassoc) ||
-      !MI->getFlag(MachineInstr::MIFlag::FmNsz))
-    return false;
+  return RISCV::hasEqualFRM(Inst, Sibling);
+}
 
-  return RISCV::hasEqualFRM(Root, *MI);
+bool RISCVInstrInfo::isAssociativeAndCommutative(
+    const MachineInstr &Inst) const {
+  unsigned Opc = Inst.getOpcode();
+  if (isFADD(Opc) || isFMUL(Opc))
+    return Inst.getFlag(MachineInstr::MIFlag::FmReassoc) &&
+           Inst.getFlag(MachineInstr::MIFlag::FmNsz);
+  return false;
 }
 
 static bool canCombineFPFusedMultiply(const MachineInstr &Root,
@@ -1250,23 +1247,6 @@
   return RISCV::hasEqualFRM(Root, *MI);
 }
 
-static bool
-getFPReassocPatterns(MachineInstr &Root,
-                     SmallVectorImpl<MachineCombinerPattern> &Patterns) {
-  bool Added = false;
-  if (canReassociate(Root, Root.getOperand(1))) {
-    Patterns.push_back(MachineCombinerPattern::REASSOC_AX_BY);
-    Patterns.push_back(MachineCombinerPattern::REASSOC_XA_BY);
-    Added = true;
-  }
-  if (canReassociate(Root, Root.getOperand(2))) {
-    Patterns.push_back(MachineCombinerPattern::REASSOC_AX_YB);
-    Patterns.push_back(MachineCombinerPattern::REASSOC_XA_YB);
-    Added = true;
-  }
-  return Added;
-}
-
 static bool
 getFPFusedMultiplyPatterns(MachineInstr &Root,
                            SmallVectorImpl<MachineCombinerPattern> &Patterns,
@@ -1294,10 +1274,7 @@
 static bool getFPPatterns(MachineInstr &Root,
                           SmallVectorImpl<MachineCombinerPattern> &Patterns,
                           bool DoRegPressureReduce) {
-  bool Added = getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
-  if (isAssociativeAndCommutativeFPOpcode(Root.getOpcode()))
-    Added |= getFPReassocPatterns(Root, Patterns);
-  return Added;
+  return getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
 }
 
 bool RISCVInstrInfo::getMachineCombinerPatterns(