diff --git a/bolt/lib/Target/RISCV/RISCVMCPlusBuilder.cpp b/bolt/lib/Target/RISCV/RISCVMCPlusBuilder.cpp
--- a/bolt/lib/Target/RISCV/RISCVMCPlusBuilder.cpp
+++ b/bolt/lib/Target/RISCV/RISCVMCPlusBuilder.cpp
@@ -110,23 +110,8 @@
            "Invalid instruction");
 
     unsigned SymOpIndex;
-
-    switch (Inst.getOpcode()) {
-    default:
-      llvm_unreachable("not implemented");
-    case RISCV::C_J:
-      SymOpIndex = 0;
-      break;
-    case RISCV::JAL:
-    case RISCV::C_BEQZ:
-    case RISCV::C_BNEZ:
-      SymOpIndex = 1;
-      break;
-    case RISCV::BEQ:
-    case RISCV::BGE:
-      SymOpIndex = 2;
-      break;
-    }
+    if (!getSymbolRefOperandNum(Inst, SymOpIndex))
+      llvm_unreachable("unimplemented branch");
 
     Inst.getOperand(SymOpIndex) = MCOperand::createExpr(
         MCSymbolRefExpr::create(TBB, MCSymbolRefExpr::VK_None, *Ctx));
@@ -237,8 +222,57 @@
     return true;
   }
 
+  /// Store in \p OpNum the index of the operand that holds the symbolic
+  /// target of branch/jump \p Inst. Returns false for unhandled opcodes.
+  bool getSymbolRefOperandNum(const MCInst &Inst, unsigned &OpNum) const {
+    switch (Inst.getOpcode()) {
+    default:
+      return false;
+    case RISCV::C_J:
+      OpNum = 0;
+      return true;
+    case RISCV::JAL:
+    case RISCV::C_BEQZ:
+    case RISCV::C_BNEZ:
+      OpNum = 1;
+      return true;
+    case RISCV::BEQ:
+    case RISCV::BGE:
+    case RISCV::BGEU:
+    case RISCV::BNE:
+    case RISCV::BLT:
+    case RISCV::BLTU:
+      OpNum = 2;
+      return true;
+    }
+  }
+
+  /// Return the symbol referenced by \p Expr, looking through RISCVMCExpr
+  /// wrappers and the left-hand side of binary expressions. Returns nullptr
+  /// when no plain (VK_None) symbol reference is found.
+  const MCSymbol *getTargetSymbol(const MCExpr *Expr) const override {
+    auto *RISCVExpr = dyn_cast<RISCVMCExpr>(Expr);
+    if (RISCVExpr && RISCVExpr->getSubExpr())
+      return getTargetSymbol(RISCVExpr->getSubExpr());
+
+    auto *BinExpr = dyn_cast<MCBinaryExpr>(Expr);
+    if (BinExpr)
+      return getTargetSymbol(BinExpr->getLHS());
+
+    auto *SymExpr = dyn_cast<MCSymbolRefExpr>(Expr);
+    if (SymExpr && SymExpr->getKind() == MCSymbolRefExpr::VK_None)
+      return &SymExpr->getSymbol();
+
+    return nullptr;
+  }
+
   const MCSymbol *getTargetSymbol(const MCInst &Inst,
                                   unsigned OpNum = 0) const override {
+    // With the default OpNum of 0, derive the operand index from the opcode;
+    // opcodes not known to getSymbolRefOperandNum yield nullptr.
+    if (!OpNum && !getSymbolRefOperandNum(Inst, OpNum))
+      return nullptr;
+
     const MCOperand &Op = Inst.getOperand(OpNum);
     if (!Op.isExpr())
       return nullptr;