diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -86,10 +86,8 @@ void emitStartOfAsmFile(Module &M) override; void emitJumpTableInfo() override; - void emitJumpTableEntry(const MachineJumpTableInfo *MJTI, - const MachineBasicBlock *MBB, unsigned JTI); - void LowerJumpTableDestSmall(MCStreamer &OutStreamer, const MachineInstr &MI); + void LowerJumpTableDest(MCStreamer &OutStreamer, const MachineInstr &MI); void LowerSTACKMAP(MCStreamer &OutStreamer, StackMaps &SM, const MachineInstr &MI); @@ -793,33 +791,24 @@ emitAlignment(Align(Size)); OutStreamer->emitLabel(GetJTISymbol(JTI)); - for (auto *JTBB : JTBBs) - emitJumpTableEntry(MJTI, JTBB, JTI); - } -} + const MCSymbol *BaseSym = AArch64FI->getJumpTableEntryPCRelSymbol(JTI); + const MCExpr *Base = MCSymbolRefExpr::create(BaseSym, OutContext); -void AArch64AsmPrinter::emitJumpTableEntry(const MachineJumpTableInfo *MJTI, - const MachineBasicBlock *MBB, - unsigned JTI) { - const MCExpr *Value = MCSymbolRefExpr::create(MBB->getSymbol(), OutContext); - auto AFI = MF->getInfo<AArch64FunctionInfo>(); - unsigned Size = AFI->getJumpTableEntrySize(JTI); + for (auto *JTBB : JTBBs) { + const MCExpr *Value = + MCSymbolRefExpr::create(JTBB->getSymbol(), OutContext); - if (Size == 4) { - // .word LBB - LJTI - const TargetLowering *TLI = MF->getSubtarget().getTargetLowering(); - const MCExpr *Base = TLI->getPICJumpTableRelocBaseExpr(MF, JTI, OutContext); - Value = MCBinaryExpr::createSub(Value, Base, OutContext); - } else { - // .byte (LBB - LBB) >> 2 (or .hword) - const MCSymbol *BaseSym = AFI->getJumpTableEntryPCRelSymbol(JTI); - const MCExpr *Base = MCSymbolRefExpr::create(BaseSym, OutContext); - Value = MCBinaryExpr::createSub(Value, Base, OutContext); - Value = MCBinaryExpr::createLShr( - Value, MCConstantExpr::create(2, OutContext), OutContext); - } + // Each entry is: + //
.byte/.hword/.word (LBB - Lbase)>>2 + // The shift must be arithmetic: 32-bit entries are signed (read with + // ldrsw) and are negative when the destination precedes the base label. + Value = MCBinaryExpr::createSub(Value, Base, OutContext); + Value = MCBinaryExpr::createAShr( + Value, MCConstantExpr::create(2, OutContext), OutContext); - OutStreamer->emitValue(Value, Size); + OutStreamer->emitValue(Value, Size); + } + } } /// Small jump tables contain an unsigned byte or half, representing the offset @@ -831,8 +820,8 @@ /// adr xDest, .LBB0_0 /// ldrb wScratch, [xTable, xEntry] (with "lsl #1" for ldrh). /// add xDest, xDest, xScratch, lsl #2 -void AArch64AsmPrinter::LowerJumpTableDestSmall(llvm::MCStreamer &OutStreamer, - const llvm::MachineInstr &MI) { +void AArch64AsmPrinter::LowerJumpTableDest(llvm::MCStreamer &OutStreamer, + const llvm::MachineInstr &MI) { Register DestReg = MI.getOperand(0).getReg(); Register ScratchReg = MI.getOperand(1).getReg(); Register ScratchRegW = @@ -840,25 +829,43 @@ Register TableReg = MI.getOperand(2).getReg(); Register EntryReg = MI.getOperand(3).getReg(); int JTIdx = MI.getOperand(4).getIndex(); - bool IsByteEntry = MI.getOpcode() == AArch64::JumpTableDest8; + int Size = AArch64FI->getJumpTableEntrySize(JTIdx); // This has to be first because the compression pass based its reachability // calculations on the start of the JumpTableDest instruction. auto Label = MF->getInfo<AArch64FunctionInfo>()->getJumpTableEntryPCRelSymbol(JTIdx); + + // If we don't already have a symbol to use as the base, use the ADR + // instruction itself. + if (!Label) { + Label = MF->getContext().createTempSymbol(); + AArch64FI->setJumpTableEntryInfo(JTIdx, Size, Label); + OutStreamer.emitLabel(Label); + } EmitToStreamer(OutStreamer, MCInstBuilder(AArch64::ADR) .addReg(DestReg) .addExpr(MCSymbolRefExpr::create( Label, MF->getContext()))); // Load the number of instruction-steps to offset from the label. - unsigned LdrOpcode = IsByteEntry ?
AArch64::LDRBBroX : AArch64::LDRHHroX; + unsigned LdrOpcode; + switch (Size) { + case 1: LdrOpcode = AArch64::LDRBBroX; break; + case 2: LdrOpcode = AArch64::LDRHHroX; break; + case 4: LdrOpcode = AArch64::LDRSWroX; break; + default: + llvm_unreachable("Unknown jump table size"); + } + EmitToStreamer(OutStreamer, MCInstBuilder(LdrOpcode) - .addReg(ScratchRegW) + .addReg(Size == 4 ? ScratchReg : ScratchRegW) .addReg(TableReg) .addReg(EntryReg) .addImm(0) - .addImm(IsByteEntry ? 0 : 1)); + // The last operand is a 0/1 "shift" flag, not a shift amount; the + // amount (#1 for ldrh, #2 for ldrsw) is implied by the access size. + .addImm(Size == 1 ? 0 : 1)); // Multiply the steps by 4 and add to the already materialized base label // address. @@ -1187,30 +1194,10 @@ return; } - case AArch64::JumpTableDest32: { - // We want: - // ldrsw xScratch, [xTable, xEntry, lsl #2] - // add xDest, xTable, xScratch - unsigned DestReg = MI->getOperand(0).getReg(), - ScratchReg = MI->getOperand(1).getReg(), - TableReg = MI->getOperand(2).getReg(), - EntryReg = MI->getOperand(3).getReg(); - EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::LDRSWroX) - .addReg(ScratchReg) - .addReg(TableReg) - .addReg(EntryReg) - .addImm(0) - .addImm(1)); - EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXrs) - .addReg(DestReg) - .addReg(TableReg) - .addReg(ScratchReg) - .addImm(0)); - return; - } + case AArch64::JumpTableDest32: case AArch64::JumpTableDest16: case AArch64::JumpTableDest8: - LowerJumpTableDestSmall(*OutStreamer, *MI); + LowerJumpTableDest(*OutStreamer, *MI); return; case AArch64::FMOVH0: diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -6219,6 +6219,9 @@ SDValue Entry = Op.getOperand(2); int JTI = cast<JumpTableSDNode>(JT.getNode())->getIndex(); + auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>(); + AFI->setJumpTableEntryInfo(JTI, 4, nullptr); + SDNode *Dest = DAG.getMachineNode(AArch64::JumpTableDest32, DL, MVT::i64, MVT::i64, JT, Entry, DAG.getTargetJumpTable(JTI,
MVT::i32)); diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h --- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h +++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h @@ -285,15 +285,15 @@ void setSRetReturnReg(unsigned Reg) { SRetReturnReg = Reg; } unsigned getJumpTableEntrySize(int Idx) const { - auto It = JumpTableEntryInfo.find(Idx); - if (It != JumpTableEntryInfo.end()) - return It->second.first; - return 4; + return JumpTableEntryInfo[Idx].first; } MCSymbol *getJumpTableEntryPCRelSymbol(int Idx) const { - return JumpTableEntryInfo.find(Idx)->second.second; + return JumpTableEntryInfo[Idx].second; } void setJumpTableEntryInfo(int Idx, unsigned Size, MCSymbol *PCRelSym) { + // Only grow the table: entries for higher indices may already be recorded. + if (static_cast<unsigned>(Idx) >= JumpTableEntryInfo.size()) + JumpTableEntryInfo.resize(Idx + 1); JumpTableEntryInfo[Idx] = std::make_pair(Size, PCRelSym); } @@ -354,7 +354,7 @@ MILOHContainer LOHContainerSet; SetOfInstructions LOHRelated; - DenseMap<int, std::pair<unsigned, MCSymbol *>> JumpTableEntryInfo; + SmallVector<std::pair<unsigned, MCSymbol *>, 2> JumpTableEntryInfo; }; namespace yaml { diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp --- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp @@ -2934,6 +2934,8 @@ Register TargetReg = MRI.createVirtualRegister(&AArch64::GPR64RegClass); Register ScratchReg = MRI.createVirtualRegister(&AArch64::GPR64spRegClass); + + MF->getInfo<AArch64FunctionInfo>()->setJumpTableEntryInfo(JTI, 4, nullptr); auto JumpTableInst = MIB.buildInstr(AArch64::JumpTableDest32, {TargetReg, ScratchReg}, {JTAddr, Index}) .addJumpTableIndex(JTI); diff --git a/llvm/test/CodeGen/AArch64/jump-table-exynos.ll b/llvm/test/CodeGen/AArch64/jump-table-exynos.ll --- a/llvm/test/CodeGen/AArch64/jump-table-exynos.ll +++ b/llvm/test/CodeGen/AArch64/jump-table-exynos.ll @@ -11,7 +11,17 @@ i32 4, label %lbl4 ] ; CHECK-LABEL: test_jumptable: -; CHECK-NOT: ldrb +; CHECK: adrp
[[JTPAGE:x[0-9]+]], .LJTI0_0 +; CHECK: add x[[JT:[0-9]+]], [[JTPAGE]], {{#?}}:lo12:.LJTI0_0 +; CHECK: [[PCREL_LBL:\.Ltmp[0-9]+]]: +; CHECK-NEXT: adr [[PCBASE:x[0-9]+]], [[PCREL_LBL]] +; CHECK: ldrsw x[[OFFSET:[0-9]+]], [x[[JT]], {{x[0-9]+}}, lsl #2] +; CHECK: add [[DEST:x[0-9]+]], [[PCBASE]], x[[OFFSET]], lsl #2 +; CHECK: br [[DEST]] + + +; CHECK: .LJTI0_0: +; CHECK-NEXT: .word (.LBB{{.*}}-[[PCREL_LBL]])>>2 def: ret i32 0 diff --git a/llvm/test/CodeGen/AArch64/win64-jumptable.ll b/llvm/test/CodeGen/AArch64/win64-jumptable.ll --- a/llvm/test/CodeGen/AArch64/win64-jumptable.ll +++ b/llvm/test/CodeGen/AArch64/win64-jumptable.ll @@ -40,10 +40,10 @@ ; CHECK-NEXT: .seh_endfunclet ; CHECK-NEXT: .p2align 2 ; CHECK-NEXT: .LJTI0_0: -; CHECK: .word .LBB0_2-.LJTI0_0 -; CHECK: .word .LBB0_3-.LJTI0_0 -; CHECK: .word .LBB0_4-.LJTI0_0 -; CHECK: .word .LBB0_5-.LJTI0_0 +; CHECK: .word (.LBB0_2-.Ltmp0)>>2 +; CHECK: .word (.LBB0_3-.Ltmp0)>>2 +; CHECK: .word (.LBB0_4-.Ltmp0)>>2 +; CHECK: .word (.LBB0_5-.Ltmp0)>>2 ; CHECK: .section .xdata,"dr" ; CHECK: .seh_handlerdata ; CHECK: .text