Index: include/llvm/Target/TargetInstrInfo.h
===================================================================
--- include/llvm/Target/TargetInstrInfo.h
+++ include/llvm/Target/TargetInstrInfo.h
@@ -630,6 +630,36 @@
     return false;
   }
 
+  /// Return true if branching from the head block \p MBB, which dominates
+  /// \p TBB, is more profitable than predicating the two blocks.
+  ///
+  /// e.g:   MBB
+  ///        / |
+  ///       /  |
+  ///      |  TBB
+  ///      |  / |
+  ///      | /  |
+  ///     Tail  |
+  ///      |    |
+  ///     ...  ...
+  ///
+  /// The head block MBB is terminated by a br.cond instruction, and TBB
+  /// contains a compare + br.cond. Tail must be a successor of both.
+  /// @param MBB Head block terminated by a br.cond instruction.
+  /// @param TBB Block containing the compare instruction and br.cond.
+  /// @param CmpMI The compare instruction in TBB to analyze.
+  /// @param CmpBBDepth Instruction depths for all trace instructions above TBB.
+  /// @param HeadDepth Instruction depths for all trace instructions above MBB.
+  /// @param Penalty Branch misprediction penalty.
+  /// @param Probability Probability of the path taken from head block MBB to TBB.
+  virtual bool isProfitableToBranch(MachineBasicBlock *MBB,
+                                    MachineBasicBlock *TBB, MachineInstr *CmpMI,
+                                    unsigned CmpBBDepth, unsigned HeadDepth,
+                                    unsigned Penalty,
+                                    BranchProbability Probability) const {
+    return false;
+  }
+
   /// Return true if it is possible to insert a select
   /// instruction that chooses between TrueReg and FalseReg based on the
   /// condition code in Cond.
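As a usage illustration (not part of the patch): a minimal sketch of how a backend could override this hook. The target name MyTarget and the 50% bias threshold are made up for the example; the actual AArch64 override appears further below.

    bool MyTargetInstrInfo::isProfitableToBranch(
        MachineBasicBlock *MBB, MachineBasicBlock *TBB, MachineInstr *CmpMI,
        unsigned CmpBBDepth, unsigned HeadDepth, unsigned Penalty,
        BranchProbability Probability) const {
      // Keep the branch only when it is biased toward TBB and predication
      // would stretch the critical path by more than half the expected
      // misprediction cost.
      unsigned DelayLimit = Probability.scale(Penalty) / 2;
      return Probability >= BranchProbability(1, 2) &&
             CmpBBDepth > HeadDepth + DelayLimit;
    }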
Index: lib/Target/AArch64/AArch64ConditionalCompares.cpp
===================================================================
--- lib/Target/AArch64/AArch64ConditionalCompares.cpp
+++ lib/Target/AArch64/AArch64ConditionalCompares.cpp
@@ -723,6 +723,7 @@
 class AArch64ConditionalCompares : public MachineFunctionPass {
   const TargetInstrInfo *TII;
   const TargetRegisterInfo *TRI;
+  const MachineBranchProbabilityInfo *MBPI;
   MCSchedModel SchedModel;
   // Does the processed function have the Oz attribute?
   bool MinSize;
@@ -854,6 +855,14 @@
       Trace.getInstrCycles(*CmpConv.CmpBB->getFirstTerminator()).Depth;
   DEBUG(dbgs() << "Head depth: " << HeadDepth
                << "\nCmpBB depth: " << CmpBBDepth << '\n');
+  BranchProbability Prediction =
+      MBPI->getEdgeProbability(CmpConv.Head, CmpConv.CmpBB);
+  if (TII->isProfitableToBranch(CmpConv.Head, CmpConv.CmpBB, CmpConv.CmpMI,
+                                CmpBBDepth, HeadDepth,
+                                SchedModel.MispredictPenalty, Prediction)) {
+    DEBUG(dbgs() << "Branching is more profitable than predication.\n");
+    return false;
+  }
   if (CmpBBDepth > HeadDepth + DelayLimit) {
     DEBUG(dbgs() << "Branch delay would be larger than " << DelayLimit
                  << " cycles.\n");
@@ -893,6 +902,7 @@
                << "********** Function: " << MF.getName() << '\n');
   TII = MF.getSubtarget().getInstrInfo();
   TRI = MF.getSubtarget().getRegisterInfo();
+  MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
   SchedModel = MF.getSubtarget().getSchedModel();
   MRI = &MF.getRegInfo();
   DomTree = &getAnalysis<MachineDominatorTree>();
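Note that MBPI is obtained via getAnalysis<MachineBranchProbabilityInfo>(), which assumes the pass already declares that analysis as required. If getAnalysisUsage does not do so yet, it would additionally need the following line (not shown in this diff):

    AU.addRequired<MachineBranchProbabilityInfo>();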
Index: lib/Target/AArch64/AArch64InstrInfo.h
===================================================================
--- lib/Target/AArch64/AArch64InstrInfo.h
+++ lib/Target/AArch64/AArch64InstrInfo.h
@@ -184,6 +184,11 @@
   bool expandPostRAPseudo(MachineBasicBlock::iterator MI) const override;
 
+  bool isProfitableToBranch(MachineBasicBlock *MBB, MachineBasicBlock *TBB,
+                            MachineInstr *CmpMI, unsigned CmpBBDepth,
+                            unsigned HeadDepth, unsigned Penalty,
+                            BranchProbability Probability) const override;
+
   std::pair<unsigned, unsigned>
   decomposeMachineOperandsTargetFlags(unsigned TF) const override;
   ArrayRef<std::pair<unsigned, const char *>>
Index: lib/Target/AArch64/AArch64InstrInfo.cpp
===================================================================
--- lib/Target/AArch64/AArch64InstrInfo.cpp
+++ lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -20,6 +20,7 @@
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/PseudoSourceValue.h"
 #include "llvm/MC/MCInst.h"
+#include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/TargetRegistry.h"
@@ -3063,3 +3064,98 @@
       {MO_CONSTPOOL, "aarch64-constant-pool"}};
   return makeArrayRef(TargetFlags);
 }
+
+bool AArch64InstrInfo::isProfitableToBranch(
+    MachineBasicBlock *MBB, MachineBasicBlock *TBB, MachineInstr *CmpMI,
+    unsigned CmpBBDepth, unsigned HeadDepth, unsigned BranchMissPenalty,
+    BranchProbability Probability) const {
+  // Branching is profitable on the Kryo subtarget, particularly when the
+  // branch is easily predictable, so only run this heuristic for Kryo.
+  if (!Subtarget.isKryo())
+    return false;
+
+  // Heuristic: it is not profitable if the branch is hard to predict.
+  // Set a limit on the branch probability we will accept.
+  unsigned ScalingFactor = 100;
+  unsigned BranchCost = Probability.scale(ScalingFactor);
+  if (BranchCost < 75)
+    return false;
+
+  // Heuristic: if the compare instruction is a compare-and-branch, keeping
+  // the branch can pessimize the code by replacing a ccmp with cbz/cbnz.
+  switch (CmpMI->getOpcode()) {
+  default:
+    break;
+  case AArch64::CBZW:
+  case AArch64::CBZX:
+  case AArch64::CBNZW:
+  case AArch64::CBNZX:
+    return false;
+  }
+
+  // Heuristic: if the dominating head block contains hazardous instructions,
+  // the actual latency is variable. Instructions that read the PC, such as
+  // ADRP, ADR and load-literal, can affect direct branches, so it might be
+  // cheaper to speculate.
+  unsigned HazardCost = 0;
+  unsigned CondCost = 0;
+  int FI = INT_MIN;
+  for (auto &MI : *MBB) {
+    if (isLoadFromStackSlot(&MI, FI)) {
+      ++HazardCost;
+    } else {
+      switch (MI.getOpcode()) {
+      default:
+        break;
+      case AArch64::CBZW:
+      case AArch64::CBZX:
+      case AArch64::CBNZW:
+      case AArch64::CBNZX:
+      case AArch64::FDIVSrr:
+      case AArch64::FDIVDrr:
+      case AArch64::FSQRTSr:
+      case AArch64::FSQRTDr:
+      case AArch64::SDIVWr:
+      case AArch64::SDIVXr:
+      case AArch64::UDIVWr:
+      case AArch64::UDIVXr:
+      case AArch64::ADRP:
+      case AArch64::ADR:
+      case AArch64::LDRSWui:
+      case AArch64::MOVaddr:
+      case AArch64::MOVaddrJT:
+      case AArch64::MOVaddrCP:
+      case AArch64::MOVaddrBA:
+      case AArch64::MOVaddrTLS:
+      case AArch64::MOVaddrEXT:
+        ++HazardCost;
+        break;
+      case AArch64::CSINVWr:
+      case AArch64::CSINVXr:
+      case AArch64::CSINCWr:
+      case AArch64::CSINCXr:
+      case AArch64::CSELWr:
+      case AArch64::CSELXr:
+      case AArch64::CSNEGWr:
+      case AArch64::CSNEGXr:
+      case AArch64::FCSELSrrr:
+      case AArch64::FCSELDrrr:
+        ++CondCost;
+        break;
+      }
+    }
+  }
+  if (HazardCost > CondCost)
+    return false;
+
+  // Heuristic: the compare conversion delays the execution of the branch
+  // instruction because we must wait for the inputs to the second compare as
+  // well. The branch has no dependent instructions, but delaying it increases
+  // the cost of a misprediction.
+  //
+  // Using the branch probability, set a limit on the delay we will accept.
+  unsigned DelayLimit = Probability.scale(BranchMissPenalty) / 2;
+  if (CmpBBDepth > HeadDepth + DelayLimit)
+    return true;
+  return false;
+}
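To make the two thresholds concrete, a worked example with illustrative numbers: an edge probability of 4/5 gives BranchCost = Probability.scale(100) ≈ 80, which passes the BranchCost >= 75 predictability check. If the scheduling model reported a misprediction penalty of, say, 14 cycles, then DelayLimit = Probability.scale(14) / 2 ≈ 11 / 2 = 5, so the hook reports branching as profitable only when CmpBBDepth exceeds HeadDepth by more than 5 cycles.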
Index: test/CodeGen/AArch64/aarch64-branch-heuristics.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/aarch64-branch-heuristics.ll
@@ -0,0 +1,62 @@
+; RUN: llc < %s -mcpu=kryo -verify-machineinstrs -aarch64-ccmp | FileCheck %s
+target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64--linux-gnu"
+
+%struct.arc = type { i64, %struct.node*, %struct.node*, i32, %struct.arc*, %struct.arc*, i64, i64 }
+%struct.node = type { i64, i32, %struct.node*, %struct.node*, %struct.node*, %struct.node*, %struct.arc*, %struct.arc*, %struct.arc*, %struct.arc*, i64, i64, i32, i32 }
+%struct.basket = type { %struct.arc*, i64, i64 }
+
+; CHECK: primal_bea_mpp
+; CHECK: %if.then34
+; CHECK: cmp x{{[0-9]+}}, #1
+; CHECK-NEXT: b.ge
+; CHECK: %if.then34.if.else.exit
+; CHECK: cmp w{{[0-9]+}}, #2
+; CHECK-NEXT: b.ne
+; CHECK-NOT: ccmp
+; Function Attrs: nounwind
+define void @primal_bea_mpp() #0 {
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.inc, %entry
+  %arc = phi %struct.arc* [ %add.ptr60, %for.inc ], [ undef, %entry ]
+  %ident32 = getelementptr inbounds %struct.arc, %struct.arc* %arc, i64 0, i32 3
+  %ident32.load = load i32, i32* %ident32, align 8
+  %cmp33 = icmp sgt i32 %ident32.load, 0
+  br i1 %cmp33, label %if.then34, label %for.inc
+
+if.then34:                                        ; preds = %for.body
+  %cost35 = getelementptr inbounds %struct.arc, %struct.arc* %arc, i64 0, i32 0
+  %0 = load i64, i64* %cost35, align 8
+  %tail36 = getelementptr inbounds %struct.arc, %struct.arc* %arc, i64 0, i32 1
+  %1 = load %struct.node*, %struct.node** %tail36, align 8
+  %potential37 = getelementptr inbounds %struct.node, %struct.node* %1, i64 0, i32 0
+  %2 = load i64, i64* %potential37, align 8
+  %sub38 = sub nsw i64 %0, %2
+  %head39 = getelementptr inbounds %struct.arc, %struct.arc* %arc, i64 0, i32 2
+  %3 = load %struct.node*, %struct.node** %head39, align 8
+  %potential40 = getelementptr inbounds %struct.node, %struct.node* %3, i64 0, i32 0
+  %4 = load i64, i64* %potential40, align 8
+  %add41 = add nsw i64 %4, %sub38
+  %cmp.i = icmp sgt i64 %add41, 0
+  br i1 %cmp.i, label %land.lhs.true.i, label %if.then34.if.else.exit
+
+land.lhs.true.i:                                  ; preds = %if.then34
+  %cmp1.i = icmp eq i32 %ident32.load, 1
+  br i1 %cmp1.i, label %if.then43, label %for.inc
+
+if.then34.if.else.exit:                           ; preds = %if.then34
+  %cmp2.i = icmp sgt i64 %add41, 0
+  %cmp4.i = icmp eq i32 %ident32.load, 2
+  %cmp4.i. = and i1 %cmp4.i, %cmp2.i
+  br i1 %cmp4.i., label %if.then43, label %for.inc
+
+if.then43:                                        ; preds = %if.then34.if.else.exit, %land.lhs.true.i
+  %abs_cost56 = getelementptr inbounds %struct.basket, %struct.basket* undef, i64 0, i32 2
+  br label %for.inc
+
+for.inc:                                          ; preds = %if.then43, %if.then34.if.else.exit, %land.lhs.true.i, %for.body
+  %add.ptr60 = getelementptr inbounds %struct.arc, %struct.arc* %arc, i64 undef
+  br label %for.body
+}
+
+attributes #0 = { nounwind }
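Assuming a built LLVM tree, the new test can be run on its own with llvm-lit, e.g.:

    $ ./bin/llvm-lit -v test/CodeGen/AArch64/aarch64-branch-heuristics.ll

The CHECK lines assert that on Kryo each of the two compares is emitted as a separate cmp + conditional branch and that no ccmp instruction is formed, i.e. that the new hook vetoed the conditional-compare conversion here.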