diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -169,6 +169,12 @@
   FMULv4i32_indexed_OP2,
   FMULv8i16_indexed_OP1,
   FMULv8i16_indexed_OP2,
+
+  // RISCV FMADD, FMSUB, FNMSUB patterns
+  FMADD_AX,
+  FMADD_XA,
+  FMSUB,
+  FNMSUB,
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/CodeGen/MachineCombiner.cpp b/llvm/lib/CodeGen/MachineCombiner.cpp
--- a/llvm/lib/CodeGen/MachineCombiner.cpp
+++ b/llvm/lib/CodeGen/MachineCombiner.cpp
@@ -319,6 +319,10 @@
   case MachineCombinerPattern::REASSOC_XMM_AMM_BMM:
   case MachineCombinerPattern::SUBADD_OP1:
   case MachineCombinerPattern::SUBADD_OP2:
+  case MachineCombinerPattern::FMADD_AX:
+  case MachineCombinerPattern::FMADD_XA:
+  case MachineCombinerPattern::FMSUB:
+  case MachineCombinerPattern::FNMSUB:
     return CombinerObjective::MustReduceDepth;
   case MachineCombinerPattern::REASSOC_XY_BCA:
   case MachineCombinerPattern::REASSOC_XY_BAC:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -196,6 +196,12 @@
   finalizeInsInstrs(MachineInstr &Root, MachineCombinerPattern &P,
                     SmallVectorImpl<MachineInstr *> &InsInstrs) const override;
 
+  void genAlternativeCodeSequence(
+      MachineInstr &Root, MachineCombinerPattern Pattern,
+      SmallVectorImpl<MachineInstr *> &InsInstrs,
+      SmallVectorImpl<MachineInstr *> &DelInstrs,
+      DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;
+
 protected:
   const RISCVSubtarget &STI;
 };
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -26,6 +26,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/RegisterScavenging.h"
+#include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/MC/MCInstBuilder.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -1176,6 +1177,17 @@
   }
 }
 
+static bool isFSUB(unsigned Opc) {
+  switch (Opc) {
+  default:
+    return false;
+  case RISCV::FSUB_H:
+  case RISCV::FSUB_S:
+  case RISCV::FSUB_D:
+    return true;
+  }
+}
+
 static bool isFMUL(unsigned Opc) {
   switch (Opc) {
   default:
@@ -1211,6 +1223,33 @@
   return RISCV::hasEqualFRM(Root, *MI);
 }
 
+static bool canCombineFPFusedMultiply(const MachineInstr &Root,
+                                      const MachineOperand &MO,
+                                      bool DoRegPressureReduce) {
+  if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
+    return false;
+  const MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
+  MachineInstr *MI = MRI.getVRegDef(MO.getReg());
+  if (!MI || !isFMUL(MI->getOpcode()))
+    return false;
+
+  if (!Root.getFlag(MachineInstr::MIFlag::FmContract) ||
+      !MI->getFlag(MachineInstr::MIFlag::FmContract))
+    return false;
+
+  // Try combining even if fmul has more than one use as it eliminates
+  // dependency between fadd(fsub) and fmul. However, it can extend liveranges
+  // for fmul operands, so reject the transformation in register pressure
+  // reduction mode.
+  if (DoRegPressureReduce && !MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
+    return false;
+
+  // Do not combine instructions from different basic blocks.
+  if (Root.getParent() != MI->getParent())
+    return false;
+  return RISCV::hasEqualFRM(Root, *MI);
+}
+
 static bool
 getFPReassocPatterns(MachineInstr &Root,
                      SmallVectorImpl<MachineCombinerPattern> &Patterns) {
@@ -1228,25 +1267,148 @@
   return Added;
 }
 
-static bool getFPPatterns(MachineInstr &Root,
-                          SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+static bool
+getFPFusedMultiplyPatterns(MachineInstr &Root,
+                           SmallVectorImpl<MachineCombinerPattern> &Patterns,
+                           bool DoRegPressureReduce) {
   unsigned Opc = Root.getOpcode();
-  if (isAssociativeAndCommutativeFPOpcode(Opc))
-    return getFPReassocPatterns(Root, Patterns);
-  return false;
+  bool IsFAdd = isFADD(Opc);
+  if (!IsFAdd && !isFSUB(Opc))
+    return false;
+  bool Added = false;
+  if (canCombineFPFusedMultiply(Root, Root.getOperand(1),
+                                DoRegPressureReduce)) {
+    Patterns.push_back(IsFAdd ? MachineCombinerPattern::FMADD_AX
+                              : MachineCombinerPattern::FMSUB);
+    Added = true;
+  }
+  if (canCombineFPFusedMultiply(Root, Root.getOperand(2),
+                                DoRegPressureReduce)) {
+    Patterns.push_back(IsFAdd ? MachineCombinerPattern::FMADD_XA
+                              : MachineCombinerPattern::FNMSUB);
+    Added = true;
+  }
+  return Added;
+}
+
+static bool getFPPatterns(MachineInstr &Root,
+                          SmallVectorImpl<MachineCombinerPattern> &Patterns,
+                          bool DoRegPressureReduce) {
+  bool Added = getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
+  if (isAssociativeAndCommutativeFPOpcode(Root.getOpcode()))
+    Added |= getFPReassocPatterns(Root, Patterns);
+  return Added;
 }
 
 bool RISCVInstrInfo::getMachineCombinerPatterns(
     MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
     bool DoRegPressureReduce) const {
-  if (getFPPatterns(Root, Patterns))
+  if (getFPPatterns(Root, Patterns, DoRegPressureReduce))
     return true;
 
   return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                      DoRegPressureReduce);
 }
 
+static unsigned getFPFusedMultiplyOpcode(unsigned RootOpc,
+                                         MachineCombinerPattern Pattern) {
+  switch (RootOpc) {
+  default:
+    llvm_unreachable("Unexpected opcode");
+  case RISCV::FADD_H:
+    return RISCV::FMADD_H;
+  case RISCV::FADD_S:
+    return RISCV::FMADD_S;
+  case RISCV::FADD_D:
+    return RISCV::FMADD_D;
+  case RISCV::FSUB_H:
+    return Pattern == MachineCombinerPattern::FMSUB ? RISCV::FMSUB_H
+                                                    : RISCV::FNMSUB_H;
+  case RISCV::FSUB_S:
+    return Pattern == MachineCombinerPattern::FMSUB ? RISCV::FMSUB_S
+                                                    : RISCV::FNMSUB_S;
+  case RISCV::FSUB_D:
+    return Pattern == MachineCombinerPattern::FMSUB ? RISCV::FMSUB_D
+                                                    : RISCV::FNMSUB_D;
+  }
+}
+
+static unsigned getAddendOperandIdx(MachineCombinerPattern Pattern) {
+  switch (Pattern) {
+  default:
+    llvm_unreachable("Unexpected pattern");
+  case MachineCombinerPattern::FMADD_AX:
+  case MachineCombinerPattern::FMSUB:
+    return 2;
+  case MachineCombinerPattern::FMADD_XA:
+  case MachineCombinerPattern::FNMSUB:
+    return 1;
+  }
+}
+
+static void combineFPFusedMultiply(MachineInstr &Root, MachineInstr &Prev,
+                                   MachineCombinerPattern Pattern,
+                                   SmallVectorImpl<MachineInstr *> &InsInstrs,
+                                   SmallVectorImpl<MachineInstr *> &DelInstrs) {
+  MachineFunction *MF = Root.getMF();
+  MachineRegisterInfo &MRI = MF->getRegInfo();
+  const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
+
+  MachineOperand &Mul1 = Prev.getOperand(1);
+  MachineOperand &Mul2 = Prev.getOperand(2);
+  MachineOperand &Dst = Root.getOperand(0);
+  MachineOperand &Addend = Root.getOperand(getAddendOperandIdx(Pattern));
+
+  Register DstReg = Dst.getReg();
+  unsigned FusedOpc = getFPFusedMultiplyOpcode(Root.getOpcode(), Pattern);
+  auto IntersectedFlags = Root.getFlags() & Prev.getFlags();
+  DebugLoc MergedLoc =
+      DILocation::getMergedLocation(Root.getDebugLoc(), Prev.getDebugLoc());
+
+  MachineInstrBuilder MIB =
+      BuildMI(*MF, MergedLoc, TII->get(FusedOpc), DstReg)
+          .addReg(Mul1.getReg(), getKillRegState(Mul1.isKill()))
+          .addReg(Mul2.getReg(), getKillRegState(Mul2.isKill()))
+          .addReg(Addend.getReg(), getKillRegState(Addend.isKill()))
+          .setMIFlags(IntersectedFlags);
+
+  // Mul operands are not killed anymore.
+  Mul1.setIsKill(false);
+  Mul2.setIsKill(false);
+
+  InsInstrs.push_back(MIB);
+  if (MRI.hasOneNonDBGUse(Prev.getOperand(0).getReg()))
+    DelInstrs.push_back(&Prev);
+  DelInstrs.push_back(&Root);
+}
+
+void RISCVInstrInfo::genAlternativeCodeSequence(
+    MachineInstr &Root, MachineCombinerPattern Pattern,
+    SmallVectorImpl<MachineInstr *> &InsInstrs,
+    SmallVectorImpl<MachineInstr *> &DelInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
+  MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
+  switch (Pattern) {
+  default:
+    TargetInstrInfo::genAlternativeCodeSequence(Root, Pattern, InsInstrs,
+                                                DelInstrs, InstrIdxForVirtReg);
+    return;
+  case MachineCombinerPattern::FMADD_AX:
+  case MachineCombinerPattern::FMSUB: {
+    MachineInstr &Prev = *MRI.getVRegDef(Root.getOperand(1).getReg());
+    combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
+    return;
+  }
+  case MachineCombinerPattern::FMADD_XA:
+  case MachineCombinerPattern::FNMSUB: {
+    MachineInstr &Prev = *MRI.getVRegDef(Root.getOperand(2).getReg());
+    combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
+    return;
+  }
+  }
+}
+
 bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
                                        StringRef &ErrInfo) const {
   MCInstrDesc const &Desc = MI.getDesc();
diff --git a/llvm/test/CodeGen/RISCV/machine-combiner-mir.ll b/llvm/test/CodeGen/RISCV/machine-combiner-mir.ll
--- a/llvm/test/CodeGen/RISCV/machine-combiner-mir.ll
+++ b/llvm/test/CodeGen/RISCV/machine-combiner-mir.ll
@@ -95,8 +95,8 @@
   ; CHECK-NEXT: [[COPY1:%[0-9]+]]:fpr64 = COPY $f11_d
   ; CHECK-NEXT: [[COPY2:%[0-9]+]]:fpr64 = COPY $f10_d
   ; CHECK-NEXT: [[FMUL_D:%[0-9]+]]:fpr64 = contract nofpexcept FMUL_D [[COPY2]], [[COPY1]], 7, implicit $frm
-  ; CHECK-NEXT: [[FADD_D:%[0-9]+]]:fpr64 = contract nofpexcept FADD_D [[FMUL_D]], [[COPY]], 7, implicit $frm
-  ; CHECK-NEXT: [[FDIV_D:%[0-9]+]]:fpr64 = nofpexcept FDIV_D killed [[FADD_D]], [[FMUL_D]], 7, implicit $frm
+  ; CHECK-NEXT: [[FMADD_D:%[0-9]+]]:fpr64 = contract nofpexcept FMADD_D [[COPY2]], [[COPY1]], [[COPY]], 7, implicit $frm
+  ; CHECK-NEXT: [[FDIV_D:%[0-9]+]]:fpr64 = nofpexcept FDIV_D killed [[FMADD_D]], [[FMUL_D]], 7, implicit $frm
   ; CHECK-NEXT: $f10_d = COPY [[FDIV_D]]
   ; CHECK-NEXT: PseudoRET implicit $f10_d
   %t0 = fmul contract double %a0, %a1
diff --git a/llvm/test/CodeGen/RISCV/machine-combiner.ll b/llvm/test/CodeGen/RISCV/machine-combiner.ll
--- a/llvm/test/CodeGen/RISCV/machine-combiner.ll
+++ b/llvm/test/CodeGen/RISCV/machine-combiner.ll
@@ -188,10 +188,9 @@
 define double @test_fmadd1(double %a0, double %a1, double %a2, double %a3) {
 ; CHECK-LABEL: test_fmadd1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    fmul.d ft0, fa0, fa1
-; CHECK-NEXT:    fadd.d ft1, ft0, fa2
-; CHECK-NEXT:    fadd.d ft0, fa3, ft0
-; CHECK-NEXT:    fadd.d fa0, ft1, ft0
+; CHECK-NEXT:    fmadd.d ft0, fa0, fa1, fa2
+; CHECK-NEXT:    fmadd.d ft1, fa0, fa1, fa3
+; CHECK-NEXT:    fadd.d fa0, ft0, ft1
 ; CHECK-NEXT:    ret
   %t0 = fmul contract double %a0, %a1
   %t1 = fadd contract double %t0, %a2
@@ -204,7 +203,7 @@
 ; CHECK-LABEL: test_fmadd2:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    fmul.d ft0, fa0, fa1
-; CHECK-NEXT:    fadd.d ft1, ft0, fa2
+; CHECK-NEXT:    fmadd.d ft1, fa0, fa1, fa2
 ; CHECK-NEXT:    fdiv.d fa0, ft1, ft0
 ; CHECK-NEXT:    ret
   %t0 = fmul contract double %a0, %a1
@@ -217,7 +216,7 @@
 ; CHECK-LABEL: test_fmsub:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    fmul.d ft0, fa0, fa1
-; CHECK-NEXT:    fsub.d ft1, ft0, fa2
+; CHECK-NEXT:    fmsub.d ft1, fa0, fa1, fa2
 ; CHECK-NEXT:    fdiv.d fa0, ft1, ft0
 ; CHECK-NEXT:    ret
   %t0 = fmul contract double %a0, %a1
@@ -230,7 +229,7 @@
 ; CHECK-LABEL: test_fnmsub:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    fmul.d ft0, fa0, fa1
-; CHECK-NEXT:    fsub.d ft1, fa2, ft0
+; CHECK-NEXT:    fnmsub.d ft1, fa0, fa1, fa2
 ; CHECK-NEXT:    fdiv.d fa0, ft1, ft0
 ; CHECK-NEXT:    ret
   %t0 = fmul contract double %a0, %a1