Index: include/llvm/Target/TargetLowering.h
===================================================================
--- include/llvm/Target/TargetLowering.h
+++ include/llvm/Target/TargetLowering.h
@@ -53,6 +53,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Target/TargetCallingConv.h"
 #include "llvm/Target/TargetMachine.h"
+#include "llvm/CodeGen/LiveInterval.h"
 #include <climits>
 #include <map>
 #include <vector>
@@ -3398,6 +3399,26 @@
     return false;
   }
 
+  /// The target can specify whether a callee-saved register should be used
+  /// rather than splitting the live range. Default behaviour is yes.
+  virtual bool useCSRInsteadOfSplit(const LiveInterval &LI) const {
+    return true;
+  }
+
+  /// The number of splits in user blocks which could be allowed to be traded
+  /// for the spill of the CSR in the entry block when deferring the first use
+  /// of the CSR is preferred.
+  virtual unsigned getNumberOfTradableSplitsAgainstCSR() const {
+    return 0;
+  }
+
+  /// The number of spills in user blocks which could be allowed to be traded
+  /// for the spill of the CSR in the entry block when deferring the first use
+  /// of the CSR is preferred.
+  virtual unsigned getNumberOfTradableSpillsAgainstCSR() const {
+    return 0;
+  }
+
   /// Lower TLS global address SDNode for target independent emulated TLS model.
   virtual SDValue LowerToTLSEmulatedModel(const GlobalAddressSDNode *GA,
                                           SelectionDAG &DAG) const;
Index: lib/CodeGen/RegAllocGreedy.cpp
===================================================================
--- lib/CodeGen/RegAllocGreedy.cpp
+++ lib/CodeGen/RegAllocGreedy.cpp
@@ -66,6 +66,7 @@
 #include "llvm/Support/Timer.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetInstrInfo.h"
+#include "llvm/Target/TargetLowering.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetRegisterInfo.h"
 #include "llvm/Target/TargetSubtargetInfo.h"
@@ -148,6 +149,7 @@
   // Shortcuts to some useful interface.
   const TargetInstrInfo *TII;
   const TargetRegisterInfo *TRI;
+  const TargetLowering *TLI;
   RegisterClassInfo RCI;
 
   // analyses
@@ -416,6 +418,7 @@
                                  unsigned PhysReg, unsigned &CostPerUseLimit,
                                  SmallVectorImpl<unsigned> &NewVRegs);
   void initializeCSRCost();
+  BlockFrequency getFirstTimeCSRCost(LiveInterval &VirtReg);
   unsigned tryBlockSplit(LiveInterval&, AllocationOrder&,
                          SmallVectorImpl<unsigned>&);
   unsigned tryInstructionSplit(LiveInterval&, AllocationOrder&,
@@ -1502,6 +1505,7 @@
   }
   Cost += calcGlobalSplitCost(Cand);
+
   DEBUG({
     dbgs() << ", total = ";
     MBFI->printBlockFreq(dbgs(), Cost) << " with bundles";
@@ -2334,9 +2338,9 @@
                                      SmallVectorImpl<unsigned> &NewVRegs) {
   if (getStage(VirtReg) == RS_Spill && VirtReg.isSpillable()) {
     // We choose spill over using the CSR for the first time if the spill cost
-    // is lower than CSRCost.
+    // is lower than the CSR cost.
     SA->analyze(&VirtReg);
-    if (calcSpillCost() >= CSRCost)
+    if (calcSpillCost() >= getFirstTimeCSRCost(VirtReg))
       return PhysReg;
 
     // We are going to spill, set CostPerUseLimit to 1 to make sure that
@@ -2346,10 +2350,10 @@
   }
   if (getStage(VirtReg) < RS_Split) {
     // We choose pre-splitting over using the CSR for the first time if
-    // the cost of splitting is lower than CSRCost.
+    // the cost of splitting is lower than the CSR cost.
     SA->analyze(&VirtReg);
     unsigned NumCands = 0;
-    BlockFrequency BestCost = CSRCost; // Don't modify CSRCost.
+    BlockFrequency BestCost = getFirstTimeCSRCost(VirtReg);
     unsigned BestCand = calculateRegionSplitCost(VirtReg, Order, BestCost,
                                                  NumCands, true /*IgnoreCSR*/);
     if (BestCand == NoCand)
@@ -2368,6 +2372,37 @@
     SetOfBrokenHints.remove(&LI);
 }
 
+// Increase the cost for the first use of a callee-saved register if the live
+// range of the value spans basic blocks in which we'd prefer not to use one.
+// This will often defer use of a CSR and give shrink-wrapping an opportunity
+// to sink/hoist the save/restore from entry/exit blocks respectively.
+BlockFrequency RAGreedy::getFirstTimeCSRCost(LiveInterval &VirtReg) {
+  BlockFrequency BestCost = CSRCost;
+  if (TLI->useCSRInsteadOfSplit(VirtReg))
+    return BestCost;
+
+  // Conservatively, we try to increase the CSR cost only when all blocks in
+  // the live range have no call.
+  ArrayRef<SplitAnalysis::BlockInfo> UseBlocks = SA->getUseBlocks();
+  for (int i = 0, e = UseBlocks.size(); i < e; ++i)
+    for (auto &MI : *UseBlocks[i].MBB)
+      if (MI.isCall())
+        return BestCost;
+
+  // Now, we prefer deferring the first use of the CSR. Try to increase the CSR
+  // cost by multiplying the frequency of the entry block with the number of
+  // tradable splits or spills.
+  uint64_t EntryFreq = MBFI->getEntryFreq();
+  if (getStage(VirtReg) == RS_Spill && VirtReg.isSpillable())
+    return std::max(BestCost.getFrequency(),
+                    EntryFreq * TLI->getNumberOfTradableSpillsAgainstCSR());
+  else if (getStage(VirtReg) < RS_Split)
+    return std::max(BestCost.getFrequency(),
+                    EntryFreq * TLI->getNumberOfTradableSplitsAgainstCSR());
+  else
+    llvm_unreachable("Unexpected stage to find the first time CSR cost.");
+}
+
 void RAGreedy::initializeCSRCost() {
   // We use the larger one out of the command-line option and the value report
   // by TRI.
@@ -2568,8 +2603,8 @@
   // When NewVRegs is not empty, we may have made decisions such as evicting
   // a virtual register, go with the earlier decisions and use the physical
   // register.
-  if (CSRCost.getFrequency() && isUnusedCalleeSavedReg(PhysReg) &&
-      NewVRegs.empty()) {
+  if ((CSRCost.getFrequency() || !TLI->useCSRInsteadOfSplit(VirtReg)) &&
+      isUnusedCalleeSavedReg(PhysReg) && NewVRegs.empty()) {
     unsigned CSRReg = tryAssignCSRFirstTime(VirtReg, Order, PhysReg,
                                             CostPerUseLimit, NewVRegs);
     if (CSRReg || !NewVRegs.empty())
@@ -2723,6 +2758,7 @@
   MF = &mf;
   TRI = MF->getSubtarget().getRegisterInfo();
   TII = MF->getSubtarget().getInstrInfo();
+  TLI = MF->getSubtarget().getTargetLowering();
   RCI.runOnMachineFunction(mf);
 
   EnableLocalReassign = EnableLocalReassignment ||
@@ -2747,7 +2783,6 @@
   AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
 
   initializeCSRCost();
-  calculateSpillWeightsAndHints(*LIS, mf, VRM, *Loops, *MBFI);
 
   DEBUG(LIS->dump());
Index: lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.h
+++ lib/Target/AArch64/AArch64ISelLowering.h
@@ -628,6 +628,13 @@
                           bool isVarArg) const override;
 
   bool shouldNormalizeToSelectSequence(LLVMContext &, EVT) const override;
+
+
+  bool useCSRInsteadOfSplit(const LiveInterval &LI) const override;
+
+  unsigned getNumberOfTradableSplitsAgainstCSR() const override;
+
+  unsigned getNumberOfTradableSpillsAgainstCSR() const override;
 };
 
 namespace AArch64 {
Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -10778,3 +10778,21 @@
   return 3 * getPointerTy(DL).getSizeInBits() + 2 * 32;
 }
 
+
+// If the live interval can be spilled, we'd prefer to do so.
+bool AArch64TargetLowering::useCSRInsteadOfSplit(const LiveInterval &LI) const {
+  return !LI.isSpillable();
+}
+
+// The number of splits in user blocks which can be traded against the spill of
+// the CSR in the entry block when deferring the first use of CSR is preferred.
+unsigned AArch64TargetLowering::getNumberOfTradableSplitsAgainstCSR() const {
+  return 32;
+}
+
+// The number of spills in user blocks which can be traded against the spill of
+// the CSR in the entry block when deferring the first use of CSR is preferred.
+unsigned AArch64TargetLowering::getNumberOfTradableSpillsAgainstCSR() const {
+  return 1;
+}
+
Index: test/CodeGen/AArch64/csr-split.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/csr-split.ll
@@ -0,0 +1,38 @@
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-linaro-linux-gnueabi"
+
+; After splitting, instead of assigning a CSR, shrink-wrapping is enabled and
+; the prologue is moved from entry to BB#1.
+
+define i32 @test1(i32 %x, i32 %y, i32* nocapture %P) local_unnamed_addr {
+; CHECK-LABEL: test1:
+; CHECK-LABEL: BB#0:
+; CHECK-NOT: stp
+; CHECK-LABEL: BB#1:
+; CHECK: stp x{{[0-9]+}}, x{{[0-9]+}}, [sp, {{.*}}]
+
+entry:
+  %idxprom = sext i32 %x to i64
+  %arrayidx = getelementptr inbounds i32, i32* %P, i64 %idxprom
+  %0 = load i32, i32* %arrayidx, align 4
+  %idxprom1 = sext i32 %y to i64
+  %arrayidx2 = getelementptr inbounds i32, i32* %P, i64 %idxprom1
+  %1 = load i32, i32* %arrayidx2, align 4
+  %add = add nsw i32 %1, %0
+  %cmp = icmp eq i32 %add, 0
+  br i1 %cmp, label %cleanup, label %if.end
+
+if.end:                                           ; preds = %entry
+  store i32 %add, i32* %P, align 4
+  tail call void @func()
+  br label %cleanup
+
+cleanup:                                          ; preds = %entry, %if.end
+  %retval.0 = phi i32 [ %x, %if.end ], [ 0, %entry ]
+  ret i32 %retval.0
+}
+
+declare void @func()