diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -2066,6 +2066,13 @@
     return false;
   }
 
+  /// Return true if the MachineBasicBlock can safely be split to the cold
+  /// section. On AArch64, certain instructions may cause a block to be unsafe
+  /// to split to the cold section.
+  virtual bool isMBBSafeToSplitToCold(const MachineBasicBlock &MBB) const {
+    return true;
+  }
+
   /// Produce the expression describing the \p MI loading a value into
   /// the physical register \p Reg. This hook should only be used with
   /// \p MIs belonging to VReg-less functions.
diff --git a/llvm/lib/CodeGen/MachineFunctionSplitter.cpp b/llvm/lib/CodeGen/MachineFunctionSplitter.cpp
--- a/llvm/lib/CodeGen/MachineFunctionSplitter.cpp
+++ b/llvm/lib/CodeGen/MachineFunctionSplitter.cpp
@@ -35,6 +35,7 @@
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
 #include "llvm/IR/Function.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Support/CommandLine.h"
@@ -108,6 +109,12 @@
                         const MachineBlockFrequencyInfo *MBFI,
                         ProfileSummaryInfo *PSI) {
   std::optional<uint64_t> Count = MBFI->getBlockProfileCount(&MBB);
+
+  // Temporary hack to cope with AArch64's jump table encoding
+  const TargetInstrInfo &TII = *MBB.getParent()->getSubtarget().getInstrInfo();
+  if (!TII.isMBBSafeToSplitToCold(MBB))
+    return false;
+
   // For instrumentation profiles and sample profiles, we use different ways
   // to judge whether a block is cold and should be split.
   if (PSI->hasInstrumentationProfile() || PSI->hasCSInstrumentationProfile()) {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -325,6 +325,8 @@
   std::optional<RegImmPair> isAddImmediate(const MachineInstr &MI,
                                            Register Reg) const override;
 
+  bool isMBBSafeToSplitToCold(const MachineBasicBlock &MBB) const override;
+
   std::optional<ParamLoadedValue>
   describeLoadedValue(const MachineInstr &MI, Register Reg) const override;
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -8368,6 +8368,35 @@
   return std::nullopt;
 }
 
+bool AArch64InstrInfo::isMBBSafeToSplitToCold(
+    const MachineBasicBlock &MBB) const {
+  // Because jump tables are label-relative instead of table-relative, the
+  // lookup and every block a table refers to must be in the same section, or
+  // else relocation fixup handling fails.
+
+  // MBB is unsafe to split if any jump table of its function lists it.
+  const MachineJumpTableInfo *MJTI = MBB.getParent()->getJumpTableInfo();
+  auto ContainsMBB = [&MBB](const MachineJumpTableEntry &JTE) {
+    return llvm::is_contained(JTE.MBBs, &MBB);
+  };
+  if (MJTI != nullptr && llvm::any_of(MJTI->getJumpTables(), ContainsMBB))
+    return false;
+
+  // MBB is unsafe to split if it contains a jump table lookup itself.
+  auto IsJumpTableLookup = [](const MachineInstr &MI) {
+    switch (MI.getOpcode()) {
+    case TargetOpcode::G_BRJT:
+    case AArch64::JumpTableDest32:
+    case AArch64::JumpTableDest16:
+    case AArch64::JumpTableDest8:
+      return true;
+    default:
+      return false;
+    }
+  };
+  return llvm::none_of(MBB, IsJumpTableLookup);
+}
+
 std::optional<ParamLoadedValue>
 AArch64InstrInfo::describeLoadedValue(const MachineInstr &MI,
                                       Register Reg) const {
diff --git a/llvm/test/CodeGen/AArch64/machine-function-splitter.mir b/llvm/test/CodeGen/AArch64/machine-function-splitter.mir
--- a/llvm/test/CodeGen/AArch64/machine-function-splitter.mir
+++ b/llvm/test/CodeGen/AArch64/machine-function-splitter.mir
@@ -17,6 +17,35 @@
     %7 = tail call i32 @qux()
     ret void
   }
+
+  ; Function Attrs: nounwind
+  define i32 @nosplit_jumptable(i32 %in) #0 !prof !14 !section_prefix !15 {
+    switch i32 %in, label %common.ret [
+      i32 0, label %hot1
+      i32 1, label %hot2
+      i32 2, label %cold1
+      i32 3, label %cold2
+    ], !prof !17
+
+  common.ret:                                       ; preds = %0
+    ret i32 0
+
+  hot1:                                             ; preds = %0
+    %1 = tail call i32 @bar()
+    ret i32 %1
+
+  hot2:                                             ; preds = %0
+    %2 = tail call i32 @baz()
+    ret i32 %2
+
+  cold1:                                            ; preds = %0
+    %3 = tail call i32 @bam()
+    ret i32 %3
+
+  cold2:                                            ; preds = %0
+    %4 = tail call i32 @qux()
+    ret i32 %4
+  }
 
   declare i32 @bar()
 
@@ -47,6 +76,7 @@
   !14 = !{!"function_entry_count", i64 9000}
   !15 = !{!"function_section_prefix", !"hot"}
   !16 = !{!"branch_weights", i32 7000, i32 0}
+  !17 = !{!"branch_weights", i32 1000, i32 4000, i32 4000, i32 0, i32 0}
 
 ...
 ---
@@ -82,3 +112,51 @@
     TCRETURNdi @qux, 0, csr_aarch64_aapcs, implicit $sp
 
 ...
+---
+name: nosplit_jumptable
+machineFunctionInfo:
+  hasRedZone: false
+jumpTable:
+  kind: block-address
+  entries:
+    - id: 0
+      blocks: [ '%bb.3', '%bb.4', '%bb.5', '%bb.6' ]
+body: |
+  ; CHECK-LABEL: name: nosplit_jumptable
+  ; COM: Check that a cold block targeted by a jump table is not split.
+  ; CHECK-NOT: bbsections Cold
+  ; CHECK: TCRETURNdi @qux
+  bb.0 (%ir-block.0):
+    successors: %bb.2(0x0e38e38e), %bb.1(0x71c71c72)
+    liveins: $w0
+
+    dead $wzr = SUBSWri renamable $w0, 3, 0, implicit-def $nzcv
+    Bcc 8, %bb.2, implicit killed $nzcv
+
+  bb.1 (%ir-block.0):
+    successors: %bb.3(0x40000000), %bb.4(0x40000000), %bb.5(0x00000000), %bb.6(0x00000000)
+    liveins: $w0
+
+    renamable $w8 = ORRWrs $wzr, killed renamable $w0, 0, implicit-def $x8
+    $x9 = ADRP target-flags(aarch64-page) %jump-table.0
+    renamable $x9 = ADDXri killed $x9, target-flags(aarch64-pageoff, aarch64-nc) %jump-table.0, 0
+    early-clobber renamable $x10, dead early-clobber renamable $x11 = JumpTableDest32 killed renamable $x9, killed renamable $x8, %jump-table.0
+    BR killed renamable $x10
+
+  bb.3.hot1:
+    TCRETURNdi @bar, 0, csr_aarch64_aapcs, implicit $sp
+
+  bb.4.hot2:
+    TCRETURNdi @baz, 0, csr_aarch64_aapcs, implicit $sp
+
+  bb.2.common.ret:
+    $w0 = ORRWrs $wzr, $wzr, 0
+    RET undef $lr, implicit killed $w0
+
+  bb.5.cold1:
+    TCRETURNdi @bam, 0, csr_aarch64_aapcs, implicit $sp
+
+  bb.6.cold2:
+    TCRETURNdi @qux, 0, csr_aarch64_aapcs, implicit $sp
+
+...