diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -386,6 +386,10 @@
   CZERO_EQZ, // vt.maskc for XVentanaCondOps.
   CZERO_NEZ, // vt.maskcn for XVentanaCondOps.
 
+  /// Software guarded BRIND node. Operand 0 is the chain operand and
+  /// operand 1 is the target address.
+  SW_GUARDED_BRIND,
+
   // FP to 32 bit int conversions for RV64. These are used to keep track of the
   // result being sign extended to 64 bit. These saturate out of range inputs.
   STRICT_FCVT_W_RV64 = ISD::FIRST_TARGET_STRICTFP_OPCODE,
@@ -802,6 +806,9 @@
 
   bool supportKCFIBundles() const override { return true; }
 
+  SDValue expandIndirectJTBranch(const SDLoc &dl, SDValue Value, SDValue Addr,
+                                 SelectionDAG &DAG) const override;
+
   MachineInstr *EmitKCFICheck(MachineBasicBlock &MBB,
                               MachineBasicBlock::instr_iterator &MBBI,
                               const TargetInstrInfo *TII) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17934,6 +17934,18 @@
   return getTargetMMOFlags(NodeX) == getTargetMMOFlags(NodeY);
 }
 
+SDValue RISCVTargetLowering::expandIndirectJTBranch(const SDLoc &dl,
+                                                    SDValue Value, SDValue Addr,
+                                                    SelectionDAG &DAG) const {
+  if (Subtarget.hasStdExtZicfilp()) {
+    // When Zicfilp is enabled, we need to use a software guarded branch for
+    // the jump table branch.
+    return DAG.getNode(RISCVISD::SW_GUARDED_BRIND, dl, MVT::Other, Value, Addr);
+  }
+
+  return TargetLowering::expandIndirectJTBranch(dl, Value, Addr, DAG);
+}
+
 namespace llvm::RISCVVIntrinsicsTable {
 
 #define GET_RISCVVIntrinsicsTable_IMPL
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -44,6 +44,7 @@
 def SDT_RISCVIntShiftDOpW : SDTypeProfile<1, 3, [
   SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisVT<0, i64>, SDTCisVT<3, i64>
 ]>;
+def SDT_RISCVSWGuardedBrind : SDTypeProfile<0, -1, [SDTCisVT<0, iPTR>]>;
 
 // Target-independent nodes, but with target-specific formats.
 def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_CallSeqStart,
@@ -67,6 +68,8 @@
 def riscv_tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
                         [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
                          SDNPVariadic]>;
+def riscv_sw_guarded_brind : SDNode<"RISCVISD::SW_GUARDED_BRIND",
+                                    SDT_RISCVSWGuardedBrind, [SDNPHasChain]>;
 def riscv_sllw : SDNode<"RISCVISD::SLLW", SDT_RISCVIntBinOpW>;
 def riscv_sraw : SDNode<"RISCVISD::SRAW", SDT_RISCVIntBinOpW>;
 def riscv_srlw : SDNode<"RISCVISD::SRLW", SDT_RISCVIntBinOpW>;
@@ -1554,6 +1557,13 @@
 def PseudoBRIND : Pseudo<(outs), (ins GPRJALR:$rs1, simm12:$imm12), []>,
                   PseudoInstExpansion<(JALR X0, GPR:$rs1, simm12:$imm12)>;
 
+let Predicates = [HasStdExtZicfilp] in {
+let isBarrier = 1, isBranch = 1, isIndirectBranch = 1, isTerminator = 1 in
+def PseudoBRINDX7 : Pseudo<(outs), (ins GPRX7:$rs1, simm12:$imm12), []>,
+                    PseudoInstExpansion<(JALR X0, GPR:$rs1, simm12:$imm12)>;
+
+}
+
 def : Pat<(brind GPRJALR:$rs1), (PseudoBRIND GPRJALR:$rs1, 0)>;
 def : Pat<(brind (add GPRJALR:$rs1, simm12:$imm12)),
           (PseudoBRIND GPRJALR:$rs1, simm12:$imm12)>;
@@ -1945,6 +1955,12 @@
                 (AddiPairImmSmall AddiPair:$rs2))>;
 }
 
+let Predicates = [HasStdExtZicfilp] in {
+def : Pat<(riscv_sw_guarded_brind GPRX7:$rs1), (PseudoBRINDX7 GPRX7:$rs1, 0)>;
+def : Pat<(riscv_sw_guarded_brind (add GPRX7:$rs1, simm12:$imm12)),
+          (PseudoBRINDX7 GPRX7:$rs1, simm12:$imm12)>;
+}
+
 //===----------------------------------------------------------------------===//
 // Standard extensions
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td
@@ -142,6 +142,8 @@
 
 def GPRNoX0X2 : GPRRegisterClass<(sub GPR, X0, X2)>;
 
+def GPRX7 : GPRRegisterClass<(add X7)>;
+
 // Don't use X1 or X5 for JALR since that is a hint to pop the return address
 // stack on some microarchitectures. Also remove the reserved registers X0, X2,
 // X3, and X4 as it reduces the number of register classes that get synthesized
diff --git a/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll b/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/jumptable-swguarded.ll
@@ -0,0 +1,105 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple riscv32 -mattr=+experimental-zicfilp < %s | FileCheck %s
+; RUN: llc -mtriple riscv64 -mattr=+experimental-zicfilp < %s | FileCheck %s
+; RUN: llc -mtriple riscv32 < %s | FileCheck %s --check-prefix=NO-ZICFILP
+; RUN: llc -mtriple riscv64 < %s | FileCheck %s --check-prefix=NO-ZICFILP
+
+; Test using t2 (x7) to achieve a software guarded branch for jump table dispatch.
+define void @above_threshold(i32 signext %in, ptr %out) nounwind {
+; CHECK-LABEL: above_threshold:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    addi a0, a0, -1
+; CHECK-NEXT:    li a2, 5
+; CHECK-NEXT:    bltu a2, a0, .LBB0_9
+; CHECK-NEXT:  # %bb.1: # %entry
+; CHECK-NEXT:    slli a0, a0, 2
+; CHECK-NEXT:    lui a2, %hi(.LJTI0_0)
+; CHECK-NEXT:    addi a2, a2, %lo(.LJTI0_0)
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    lw t2, 0(a0)
+; CHECK-NEXT:    jr t2
+; CHECK-NEXT:  .LBB0_2: # %bb1
+; CHECK-NEXT:    li a0, 4
+; CHECK-NEXT:    j .LBB0_8
+; CHECK-NEXT:  .LBB0_3: # %bb2
+; CHECK-NEXT:    li a0, 3
+; CHECK-NEXT:    j .LBB0_8
+; CHECK-NEXT:  .LBB0_4: # %bb3
+; CHECK-NEXT:    li a0, 2
+; CHECK-NEXT:    j .LBB0_8
+; CHECK-NEXT:  .LBB0_5: # %bb4
+; CHECK-NEXT:    li a0, 1
+; CHECK-NEXT:    j .LBB0_8
+; CHECK-NEXT:  .LBB0_6: # %bb5
+; CHECK-NEXT:    li a0, 100
+; CHECK-NEXT:    j .LBB0_8
+; CHECK-NEXT:  .LBB0_7: # %bb6
+; CHECK-NEXT:    li a0, 200
+; CHECK-NEXT:  .LBB0_8: # %exit
+; CHECK-NEXT:    sw a0, 0(a1)
+; CHECK-NEXT:  .LBB0_9: # %exit
+; CHECK-NEXT:    ret
+;
+; NO-ZICFILP-LABEL: above_threshold:
+; NO-ZICFILP:       # %bb.0: # %entry
+; NO-ZICFILP-NEXT:    addi a0, a0, -1
+; NO-ZICFILP-NEXT:    li a2, 5
+; NO-ZICFILP-NEXT:    bltu a2, a0, .LBB0_9
+; NO-ZICFILP-NEXT:  # %bb.1: # %entry
+; NO-ZICFILP-NEXT:    slli a0, a0, 2
+; NO-ZICFILP-NEXT:    lui a2, %hi(.LJTI0_0)
+; NO-ZICFILP-NEXT:    addi a2, a2, %lo(.LJTI0_0)
+; NO-ZICFILP-NEXT:    add a0, a0, a2
+; NO-ZICFILP-NEXT:    lw a0, 0(a0)
+; NO-ZICFILP-NEXT:    jr a0
+; NO-ZICFILP-NEXT:  .LBB0_2: # %bb1
+; NO-ZICFILP-NEXT:    li a0, 4
+; NO-ZICFILP-NEXT:    j .LBB0_8
+; NO-ZICFILP-NEXT:  .LBB0_3: # %bb2
+; NO-ZICFILP-NEXT:    li a0, 3
+; NO-ZICFILP-NEXT:    j .LBB0_8
+; NO-ZICFILP-NEXT:  .LBB0_4: # %bb3
+; NO-ZICFILP-NEXT:    li a0, 2
+; NO-ZICFILP-NEXT:    j .LBB0_8
+; NO-ZICFILP-NEXT:  .LBB0_5: # %bb4
+; NO-ZICFILP-NEXT:    li a0, 1
+; NO-ZICFILP-NEXT:    j .LBB0_8
+; NO-ZICFILP-NEXT:  .LBB0_6: # %bb5
+; NO-ZICFILP-NEXT:    li a0, 100
+; NO-ZICFILP-NEXT:    j .LBB0_8
+; NO-ZICFILP-NEXT:  .LBB0_7: # %bb6
+; NO-ZICFILP-NEXT:    li a0, 200
+; NO-ZICFILP-NEXT:  .LBB0_8: # %exit
+; NO-ZICFILP-NEXT:    sw a0, 0(a1)
+; NO-ZICFILP-NEXT:  .LBB0_9: # %exit
+; NO-ZICFILP-NEXT:    ret
+entry:
+  switch i32 %in, label %exit [
+    i32 1, label %bb1
+    i32 2, label %bb2
+    i32 3, label %bb3
+    i32 4, label %bb4
+    i32 5, label %bb5
+    i32 6, label %bb6
+  ]
+bb1:
+  store i32 4, ptr %out
+  br label %exit
+bb2:
+  store i32 3, ptr %out
+  br label %exit
+bb3:
+  store i32 2, ptr %out
+  br label %exit
+bb4:
+  store i32 1, ptr %out
+  br label %exit
+bb5:
+  store i32 100, ptr %out
+  br label %exit
+bb6:
+  store i32 200, ptr %out
+  br label %exit
+exit:
+  ret void
+}
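
Note for reviewers on why t2 (x7) is used: under the Zicfilp landing-pad scheme,
an indirect jump is normally expected to land on an lpad instruction, but the
basic blocks reached through a jump table carry no landing pads. The extension
treats a JALR whose source register is x7 (t2) as a software guarded branch that
is exempt from landing-pad checking, which is why the lowering routes jump-table
dispatch through the GPRX7 register class and PseudoBRINDX7. The snippet below is
illustrative only and simply restates the dispatch sequences already present in
the test checks above; it is not part of the patch:

    # with +experimental-zicfilp        # without zicfilp
    lw   t2, 0(a0)                      lw   a0, 0(a0)
    jr   t2                             jr   a0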