diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h --- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h +++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h @@ -2076,6 +2076,13 @@ /// splitting. The criteria for if a function can be split may vary by target. virtual bool isFunctionSafeToSplit(const MachineFunction &MF) const; + /// Return true if \p MBB can safely be moved to the cold section by + /// function splitting. Targets may override this to veto blocks whose + /// placement is constrained (e.g. by jump table encoding). Default: true. + virtual bool isMBBSafeToSplitToCold(const MachineBasicBlock &MBB) const { + return true; + } + /// Produce the expression describing the \p MI loading a value into /// the physical register \p Reg. This hook should only be used with /// \p MIs belonging to VReg-less functions. diff --git a/llvm/lib/CodeGen/MachineFunctionSplitter.cpp b/llvm/lib/CodeGen/MachineFunctionSplitter.cpp --- a/llvm/lib/CodeGen/MachineFunctionSplitter.cpp +++ b/llvm/lib/CodeGen/MachineFunctionSplitter.cpp @@ -109,6 +109,12 @@ const MachineBlockFrequencyInfo *MBFI, ProfileSummaryInfo *PSI) { std::optional Count = MBFI->getBlockProfileCount(&MBB); + + // Temporary hack: the target may veto splitting MBB (AArch64 jump tables). + const TargetInstrInfo &TII = *MBB.getParent()->getSubtarget().getInstrInfo(); + if (!TII.isMBBSafeToSplitToCold(MBB)) + return false; + // For instrumentation profiles and sample profiles, we use different ways // to judge whether a block is cold and should be split. 
if (PSI->hasInstrumentationProfile() || PSI->hasCSInstrumentationProfile()) { diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h @@ -338,6 +338,8 @@ bool isFunctionSafeToSplit(const MachineFunction &MF) const override; + bool isMBBSafeToSplitToCold(const MachineBasicBlock &MBB) const override; + std::optional describeLoadedValue(const MachineInstr &MI, Register Reg) const override; diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -8406,6 +8406,36 @@ return TargetInstrInfo::isFunctionSafeToSplit(MF); } +bool AArch64InstrInfo::isMBBSafeToSplitToCold( + const MachineBasicBlock &MBB) const { + // Because jump tables are label-relative instead of table-relative, they all + // must be in the same section or relocation fixup handling will fail. + + // Reject MBB if it is the destination of any jump table entry. + const MachineJumpTableInfo *MJTI = MBB.getParent()->getJumpTableInfo(); + auto containsMBB = [&MBB](const MachineJumpTableEntry &JTE) { + return llvm::is_contained(JTE.MBBs, &MBB); + }; + if (MJTI != nullptr && llvm::any_of(MJTI->getJumpTables(), containsMBB)) + return false; + + // Reject MBB if it contains a jump table lookup or branch (incl. G_BRJT). + for (const MachineInstr &MI : MBB) { + switch (MI.getOpcode()) { + case TargetOpcode::G_BRJT: + case AArch64::JumpTableDest32: + case AArch64::JumpTableDest16: + case AArch64::JumpTableDest8: + return false; + default: + continue; + } + } + + // MBB isn't a special case, so it's safe to be split to the cold section. 
+ return true; +} + std::optional AArch64InstrInfo::describeLoadedValue(const MachineInstr &MI, Register Reg) const { diff --git a/llvm/test/CodeGen/Generic/machine-function-splitter.ll b/llvm/test/CodeGen/Generic/machine-function-splitter.ll --- a/llvm/test/CodeGen/Generic/machine-function-splitter.ll +++ b/llvm/test/CodeGen/Generic/machine-function-splitter.ll @@ -517,6 +517,91 @@ ret i32 %tmp2 } +define i32 @foo18(i32 %in) !prof !14 !section_prefix !15 { +;; Check that cold blocks targeted by a jump table are split on X86 +;; but are not split on AArch64. +; MFS-DEFAULTS-LABEL: foo18 +; MFS-DEFAULTS: .section .text.split.foo18 +; MFS-DEFAULTS-NEXT: foo18.cold: +; MFS-DEFAULTS-SAME: %common.ret +; MFS-DEFAULTS-X86-DAG: jmp qux +; MFS-DEFAULTS-X86-DAG: jmp bam +; MFS-DEFAULTS-AARCH64-NOT: b bar +; MFS-DEFAULTS-AARCH64-NOT: b baz +; MFS-DEFAULTS-AARCH64-NOT: b qux +; MFS-DEFAULTS-AARCH64-NOT: b bam + + switch i32 %in, label %common.ret [ + i32 0, label %hot1 + i32 1, label %hot2 + i32 2, label %cold1 + i32 3, label %cold2 + ], !prof !28 + +common.ret: ; preds = %0 + ret i32 0 + +hot1: ; preds = %0 + %1 = tail call i32 @bar() + ret i32 %1 + +hot2: ; preds = %0 + %2 = tail call i32 @baz() + ret i32 %2 + +cold1: ; preds = %0 + %3 = tail call i32 @bam() + ret i32 %3 + +cold2: ; preds = %0 + %4 = tail call i32 @qux() + ret i32 %4 +} + +define i32 @foo19(i32 %in) !prof !14 !section_prefix !15 { +;; Check that a cold block containing a jump table dispatch is split +;; on X86 but is not split on AArch64. 
+; MFS-DEFAULTS-LABEL: foo19 +; MFS-DEFAULTS: .section .text.split.foo19 +; MFS-DEFAULTS-NEXT: foo19.cold: +; MFS-DEFAULTS-X86: .LJTI18_0 +; MFS-DEFAULTS-AARCH64-NOT: .LJTI18_0 +; MFS-DEFAULTS: .section .rodata +; MFS-DEFAULTS: .LJTI18_0 + %cmp = icmp sgt i32 %in, 3 + br i1 %cmp, label %hot, label %cold_switch, !prof !17 + +hot: ; preds = %0 +ret i32 1 + +cold_switch: ; preds = %0 + switch i32 %in, label %common.ret [ + i32 0, label %hot1 + i32 1, label %hot2 + i32 2, label %cold1 + i32 3, label %cold2 + ], !prof !28 + +common.ret: ; preds = %cold_switch + ret i32 0 + +hot1: ; preds = %cold_switch + %1 = tail call i32 @bar() + ret i32 %1 + +hot2: ; preds = %cold_switch + %2 = tail call i32 @baz() + ret i32 %2 + +cold1: ; preds = %cold_switch + %3 = tail call i32 @bam() + ret i32 %3 + +cold2: ; preds = %cold_switch + %4 = tail call i32 @qux() + ret i32 %4 +} + declare i32 @bar() declare i32 @baz() declare i32 @bam() @@ -557,3 +642,4 @@ !25 = !{!"branch_weights", i32 0, i32 7000} !26 = !{!"branch_weights", i32 1000, i32 6000} !27 = !{!"function_entry_count", i64 10000} +!28 = !{!"branch_weights", i32 0, i32 4000, i32 4000, i32 0, i32 0}