diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h --- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h +++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h @@ -62,6 +62,7 @@ class TargetSchedModel; class TargetSubtargetInfo; enum class MachineCombinerPattern; +enum class MachineTraceStrategy; template <class T> class SmallVectorImpl; @@ -1252,6 +1253,9 @@ /// Return true when a target supports MachineCombiner. virtual bool useMachineCombiner() const { return false; } + /// Return a strategy that MachineCombiner must use when creating traces. + virtual MachineTraceStrategy getMachineCombinerTraceStrategy() const; + /// Return true if the given SDNode can be copied during scheduling /// even if it has glue. virtual bool canCopyGluedNodeDuringSchedule(SDNode *N) const { return false; } diff --git a/llvm/lib/CodeGen/MachineCombiner.cpp b/llvm/lib/CodeGen/MachineCombiner.cpp --- a/llvm/lib/CodeGen/MachineCombiner.cpp +++ b/llvm/lib/CodeGen/MachineCombiner.cpp @@ -96,7 +96,8 @@ bool isTransientMI(const MachineInstr *MI); unsigned getDepth(SmallVectorImpl<MachineInstr *> &InsInstrs, DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, - MachineTraceMetrics::Trace BlockTrace); + MachineTraceMetrics::Trace BlockTrace, + const MachineBasicBlock &MBB); unsigned getLatency(MachineInstr *Root, MachineInstr *NewRoot, MachineTraceMetrics::Trace BlockTrace); bool @@ -208,7 +209,8 @@ unsigned MachineCombiner::getDepth(SmallVectorImpl<MachineInstr *> &InsInstrs, DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, - MachineTraceMetrics::Trace BlockTrace) { + MachineTraceMetrics::Trace BlockTrace, + const MachineBasicBlock &MBB) { SmallVector<unsigned, 16> InstrDepth; assert(TSchedModel.hasInstrSchedModelOrItineraries() && "Missing machine model\n"); @@ -241,7 +243,9 @@ InstrPtr, UseIdx); } else { MachineInstr *DefInstr = getOperandDef(MO); - if (DefInstr) { + if (DefInstr && (TII->getMachineCombinerTraceStrategy() != + MachineTraceStrategy::TS_Local || + DefInstr->getParent() == &MBB)) { DepthOp = 
BlockTrace.getInstrCycles(*DefInstr).Depth; if (!isTransientMI(DefInstr)) LatencyOp = TSchedModel.computeOperandLatency( @@ -383,7 +387,8 @@ assert(TSchedModel.hasInstrSchedModelOrItineraries() && "Missing machine model\n"); // Get depth and latency of NewRoot and Root. - unsigned NewRootDepth = getDepth(InsInstrs, InstrIdxForVirtReg, BlockTrace); + unsigned NewRootDepth = + getDepth(InsInstrs, InstrIdxForVirtReg, BlockTrace, *MBB); unsigned RootDepth = BlockTrace.getInstrCycles(*Root).Depth; LLVM_DEBUG(dbgs() << " Dependence data for " << *Root << "\tNewRootDepth: " @@ -594,7 +599,7 @@ // Check if the block is in a loop. const MachineLoop *ML = MLI->getLoopFor(MBB); if (!TraceEnsemble) - TraceEnsemble = Traces->getEnsemble(MachineTraceStrategy::TS_MinInstrCount); + TraceEnsemble = Traces->getEnsemble(TII->getMachineCombinerTraceStrategy()); SparseSet<LiveRegUnit> RegUnits; RegUnits.setUniverse(TRI->getNumRegUnits()); diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp --- a/llvm/lib/CodeGen/TargetInstrInfo.cpp +++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp @@ -19,6 +19,7 @@ #include "llvm/CodeGen/MachineMemOperand.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/MachineScheduler.h" +#include "llvm/CodeGen/MachineTraceMetrics.h" #include "llvm/CodeGen/PseudoSourceValue.h" #include "llvm/CodeGen/ScoreboardHazardRecognizer.h" #include "llvm/CodeGen/StackMaps.h" @@ -1047,6 +1048,10 @@ reassociateOps(Root, *Prev, Pattern, InsInstrs, DelInstrs, InstIdxForVirtReg); } +MachineTraceStrategy TargetInstrInfo::getMachineCombinerTraceStrategy() const { + return MachineTraceStrategy::TS_MinInstrCount; +} + bool TargetInstrInfo::isReallyTriviallyReMaterializableGeneric( const MachineInstr &MI) const { const MachineFunction &MF = *MI.getMF(); diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -194,6 
+194,8 @@ bool useMachineCombiner() const override { return true; } + MachineTraceStrategy getMachineCombinerTraceStrategy() const override; + void setSpecialOperandAttr(MachineInstr &OldMI1, MachineInstr &OldMI2, MachineInstr &NewMI1, MachineInstr &NewMI2) const override; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -25,6 +25,7 @@ #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/MachineTraceMetrics.h" #include "llvm/CodeGen/RegisterScavenging.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/MC/MCInstBuilder.h" @@ -44,6 +45,16 @@ "riscv-prefer-whole-register-move", cl::init(false), cl::Hidden, cl::desc("Prefer whole register move for vector registers.")); +static cl::opt<MachineTraceStrategy> ForceMachineCombinerStrategy( + "riscv-force-machine-combiner-strategy", cl::Hidden, + cl::desc("Force machine combiner to use a specific strategy for machine " + "trace metrics evaluation."), + cl::init(MachineTraceStrategy::TS_NumStrategies), + cl::values(clEnumValN(MachineTraceStrategy::TS_Local, "local", + "Local strategy."), + clEnumValN(MachineTraceStrategy::TS_MinInstrCount, "min-instr", + "MinInstrCount strategy."))); + namespace llvm::RISCVVPseudosTable { using namespace RISCV; @@ -1263,6 +1274,17 @@ return std::nullopt; } +MachineTraceStrategy RISCVInstrInfo::getMachineCombinerTraceStrategy() const { + if (ForceMachineCombinerStrategy == MachineTraceStrategy::TS_NumStrategies) + // The option is unused. Choose MinInstrCount strategy only for out of order + // cores. + return STI.getSchedModel().isOutOfOrder() + ? MachineTraceStrategy::TS_MinInstrCount + : MachineTraceStrategy::TS_Local; + // The strategy was forced by the option. 
+ return ForceMachineCombinerStrategy; +} + void RISCVInstrInfo::setSpecialOperandAttr(MachineInstr &OldMI1, MachineInstr &OldMI2, MachineInstr &NewMI1, diff --git a/llvm/test/CodeGen/RISCV/machine-combiner.ll b/llvm/test/CodeGen/RISCV/machine-combiner.ll --- a/llvm/test/CodeGen/RISCV/machine-combiner.ll +++ b/llvm/test/CodeGen/RISCV/machine-combiner.ll @@ -1,7 +1,11 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py ; RUN: llc -mtriple=riscv64 -mattr=+d -verify-machineinstrs -mcpu=sifive-u74 \ -; RUN: -O1 -riscv-enable-machine-combiner=true < %s | \ -; RUN: FileCheck %s +; RUN: -O1 -riscv-enable-machine-combiner=true -riscv-force-machine-combiner-strategy=local < %s | \ +; RUN: FileCheck %s --check-prefixes=CHECK,CHECK_LOCAL + +; RUN: llc -mtriple=riscv64 -mattr=+d -verify-machineinstrs -mcpu=sifive-u74 \ +; RUN: -O1 -riscv-enable-machine-combiner=true -riscv-force-machine-combiner-strategy=min-instr < %s | \ +; RUN: FileCheck %s --check-prefixes=CHECK,CHECK_GLOBAL define double @test_reassoc_fadd1(double %a0, double %a1, double %a2, double %a3) { ; CHECK-LABEL: test_reassoc_fadd1: @@ -393,3 +397,43 @@ %t2 = fsub nsz reassoc double %a3, %t1 ret double %t2 } + +define double @test_fmadd_strategy(double %a0, double %a1, double %a2, double %a3, i64 %flag) { +; CHECK_LOCAL-LABEL: test_fmadd_strategy: +; CHECK_LOCAL: # %bb.0: # %entry +; CHECK_LOCAL-NEXT: fmv.d ft0, fa0 +; CHECK_LOCAL-NEXT: fsub.d ft1, fa0, fa1 +; CHECK_LOCAL-NEXT: fmul.d fa0, ft1, fa2 +; CHECK_LOCAL-NEXT: andi a0, a0, 1 +; CHECK_LOCAL-NEXT: beqz a0, .LBB16_2 +; CHECK_LOCAL-NEXT: # %bb.1: # %entry +; CHECK_LOCAL-NEXT: fmul.d ft1, ft0, fa1 +; CHECK_LOCAL-NEXT: fmadd.d ft0, ft0, fa1, fa0 +; CHECK_LOCAL-NEXT: fsub.d fa0, ft0, ft1 +; CHECK_LOCAL-NEXT: .LBB16_2: # %entry +; CHECK_LOCAL-NEXT: ret +; +; CHECK_GLOBAL-LABEL: test_fmadd_strategy: +; CHECK_GLOBAL: # %bb.0: # %entry +; CHECK_GLOBAL-NEXT: fmv.d ft0, fa0 +; CHECK_GLOBAL-NEXT: fsub.d ft1, fa0, fa1 +; CHECK_GLOBAL-NEXT: 
fmul.d fa0, ft1, fa2 +; CHECK_GLOBAL-NEXT: andi a0, a0, 1 +; CHECK_GLOBAL-NEXT: beqz a0, .LBB16_2 +; CHECK_GLOBAL-NEXT: # %bb.1: # %entry +; CHECK_GLOBAL-NEXT: fmul.d ft0, ft0, fa1 +; CHECK_GLOBAL-NEXT: fadd.d ft1, ft0, fa0 +; CHECK_GLOBAL-NEXT: fsub.d fa0, ft1, ft0 +; CHECK_GLOBAL-NEXT: .LBB16_2: # %entry +; CHECK_GLOBAL-NEXT: ret +entry: + %sub = fsub contract double %a0, %a1 + %mul = fmul contract double %sub, %a2 + %and = and i64 %flag, 1 + %tobool.not = icmp eq i64 %and, 0 + %mul2 = fmul contract double %a0, %a1 + %add = fadd contract double %mul2, %mul + %sub3 = fsub contract double %add, %mul2 + %retval.0 = select i1 %tobool.not, double %mul, double %sub3 + ret double %retval.0 +}