diff --git a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
--- a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
+++ b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
@@ -18,17 +18,24 @@
 #include "RISCVTargetMachine.h"
 #include "TargetInfo/RISCVTargetInfo.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/BinaryFormat/ELF.h"
 #include "llvm/CodeGen/AsmPrinter.h"
 #include "llvm/CodeGen/MachineConstantPool.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/MC/MCAsmInfo.h"
+#include "llvm/MC/MCContext.h"
 #include "llvm/MC/MCInst.h"
+#include "llvm/MC/MCInstBuilder.h"
+#include "llvm/MC/MCObjectFileInfo.h"
+#include "llvm/MC/MCSectionELF.h"
 #include "llvm/MC/MCStreamer.h"
 #include "llvm/MC/MCSymbol.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Instrumentation/HWAddressSanitizer.h"
+
 using namespace llvm;
 
 #define DEBUG_TYPE "asm-printer"
@@ -61,6 +68,11 @@
   bool emitPseudoExpansionLowering(MCStreamer &OutStreamer,
                                    const MachineInstr *MI);
 
+  typedef std::tuple<unsigned, uint32_t> HwasanMemaccessTuple;
+  std::map<HwasanMemaccessTuple, MCSymbol *> HwasanMemaccessSymbols;
+  void LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI);
+  void EmitHwasanMemaccessSymbols(Module &M);
+
   // Wrapper needed for tblgenned pseudo lowering.
   bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp) const {
     return lowerRISCVMachineOperandToMCOperand(MO, MCOp, *this);
@@ -99,6 +111,12 @@
     return;
 
   MCInst TmpInst;
+
+  if (MI->getOpcode() == RISCV::HWASAN_CHECK_MEMACCESS_SHORTGRANULES) {
+    LowerHWASAN_CHECK_MEMACCESS(*MI);
+    return;
+  }
+
   if (!lowerRISCVMachineInstrToMCInst(MI, TmpInst, *this))
     EmitToStreamer(*OutStreamer, TmpInst);
 }
@@ -195,6 +213,7 @@
 
   if (TM.getTargetTriple().isOSBinFormatELF())
     RTS.finishAttributeSection();
+  EmitHwasanMemaccessSymbols(M);
 }
 
 void RISCVAsmPrinter::emitAttributes() {
@@ -215,3 +234,270 @@
   RegisterAsmPrinter<RISCVAsmPrinter> X(getTheRISCV32Target());
   RegisterAsmPrinter<RISCVAsmPrinter> Y(getTheRISCV64Target());
 }
+
+/// Lower a HWASAN_CHECK_MEMACCESS_SHORTGRANULES pseudo into a call to an
+/// outlined check routine named `__hwasan_check_x<N>_<AccessInfo>_short`.
+/// The routine's symbol is created on first use and cached in
+/// HwasanMemaccessSymbols; its body is emitted at the end of the module.
+void RISCVAsmPrinter::LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI) {
+  Register Reg = MI.getOperand(0).getReg();
+  uint32_t AccessInfo = MI.getOperand(1).getImm();
+  MCSymbol *&Sym =
+      HwasanMemaccessSymbols[HwasanMemaccessTuple(Reg, AccessInfo)];
+  if (!Sym) {
+    // FIXME: Make this work on non-ELF.
+    if (!TM.getTargetTriple().isOSBinFormatELF())
+      report_fatal_error("llvm.hwasan.check.memaccess only supported on ELF");
+
+    std::string SymName = "__hwasan_check_x" + utostr(Reg - RISCV::X0) + "_" +
+                          utostr(AccessInfo) + "_short";
+    Sym = OutContext.getOrCreateSymbol(SymName);
+  }
+  auto Res = MCSymbolRefExpr::create(Sym, MCSymbolRefExpr::VK_None, OutContext);
+  auto Expr = RISCVMCExpr::create(Res, RISCVMCExpr::VK_RISCV_CALL, OutContext);
+
+  EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::PseudoCALL).addExpr(Expr));
+}
+
+/// Emit the body of every outlined check routine requested through
+/// LowerHWASAN_CHECK_MEMACCESS. Each routine is placed in its own comdat
+/// `.text.hot` section as a weak, hidden function. Does nothing when no check
+/// was lowered in this module.
+void RISCVAsmPrinter::EmitHwasanMemaccessSymbols(Module &M) {
+  if (HwasanMemaccessSymbols.empty())
+    return;
+
+  const Triple &TT = TM.getTargetTriple();
+  assert(TT.isOSBinFormatELF());
+  std::unique_ptr<MCSubtargetInfo> STI(
+      TM.getTarget().createMCSubtargetInfo(TT.str(), "", ""));
+  assert(STI && "Unable to create subtarget info");
+
+  MCSymbol *HwasanTagMismatchV2Sym =
+      OutContext.getOrCreateSymbol("__hwasan_tag_mismatch_v2");
+
+  const MCSymbolRefExpr *HwasanTagMismatchV2Ref =
+      MCSymbolRefExpr::create(HwasanTagMismatchV2Sym, OutContext);
+
+  for (auto &P : HwasanMemaccessSymbols) {
+    unsigned Reg = std::get<0>(P.first);
+    uint32_t AccessInfo = std::get<1>(P.first);
+    const MCSymbolRefExpr *HwasanTagMismatchRef = HwasanTagMismatchV2Ref;
+    MCSymbol *Sym = P.second;
+
+    unsigned Size =
+        1 << ((AccessInfo >> HWASanAccessInfo::AccessSizeShift) & 0xf);
+    OutStreamer->switchSection(OutContext.getELFSection(
+        ".text.hot", ELF::SHT_PROGBITS,
+        ELF::SHF_EXECINSTR | ELF::SHF_ALLOC | ELF::SHF_GROUP, 0, Sym->getName(),
+        /*IsComdat=*/true));
+
+    OutStreamer->emitSymbolAttribute(Sym, MCSA_ELF_TypeFunction);
+    OutStreamer->emitSymbolAttribute(Sym, MCSA_Weak);
+    OutStreamer->emitSymbolAttribute(Sym, MCSA_Hidden);
+    OutStreamer->emitLabel(Sym);
+
+    // Extract shadow offset from ptr
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::SLLI).addReg(RISCV::X6).addReg(Reg).addImm(8),
+        *STI);
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SRLI)
+                                     .addReg(RISCV::X6)
+                                     .addReg(RISCV::X6)
+                                     .addImm(12),
+                                 *STI);
+    // load shadow tag in X6, X5 contains shadow base
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADD)
+                                     .addReg(RISCV::X6)
+                                     .addReg(RISCV::X5)
+                                     .addReg(RISCV::X6),
+                                 *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::LBU).addReg(RISCV::X6).addReg(RISCV::X6).addImm(0),
+        *STI);
+    // Extract tag from the pointer and compare it with loaded tag from shadow
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::SRLI).addReg(RISCV::X7).addReg(Reg).addImm(56),
+        *STI);
+    MCSymbol *HandleMismatchOrPartialSym = OutContext.createTempSymbol();
+    // X7 contains tag from the pointer, while X6 contains tag from memory
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BNE)
+            .addReg(RISCV::X7)
+            .addReg(RISCV::X6)
+            .addExpr(MCSymbolRefExpr::create(HandleMismatchOrPartialSym,
+                                             OutContext)),
+        *STI);
+    MCSymbol *ReturnSym = OutContext.createTempSymbol();
+    OutStreamer->emitLabel(ReturnSym);
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::JALR)
+                                     .addReg(RISCV::X0)
+                                     .addReg(RISCV::X1)
+                                     .addImm(0),
+                                 *STI);
+    OutStreamer->emitLabel(HandleMismatchOrPartialSym);
+
+    // Tags differ. A shadow value of 16 or more goes straight to the mismatch
+    // path; smaller values are checked further below.
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
+                                     .addReg(RISCV::X28)
+                                     .addReg(RISCV::X0)
+                                     .addImm(16),
+                                 *STI);
+    MCSymbol *HandleMismatchSym = OutContext.createTempSymbol();
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BGEU)
+            .addReg(RISCV::X6)
+            .addReg(RISCV::X28)
+            .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
+        *STI);
+
+    // X28 = (ptr & 0xF) + Size - 1, the offset of the last accessed byte in
+    // its 16-byte granule; mismatch unless it is below the shadow value.
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::ANDI).addReg(RISCV::X28).addReg(Reg).addImm(0xF),
+        *STI);
+
+    if (Size != 1)
+      OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
+                                       .addReg(RISCV::X28)
+                                       .addReg(RISCV::X28)
+                                       .addImm(Size - 1),
+                                   *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BGE)
+            .addReg(RISCV::X28)
+            .addReg(RISCV::X6)
+            .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
+        *STI);
+
+    // Load the last byte of the granule (ptr | 0xF) and return if it matches
+    // the pointer tag in X7.
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::ORI).addReg(RISCV::X6).addReg(Reg).addImm(0xF),
+        *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::LBU).addReg(RISCV::X6).addReg(RISCV::X6).addImm(0),
+        *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BEQ)
+            .addReg(RISCV::X6)
+            .addReg(RISCV::X7)
+            .addExpr(MCSymbolRefExpr::create(ReturnSym, OutContext)),
+        *STI);
+
+    OutStreamer->emitLabel(HandleMismatchSym);
+
+    // | Previous stack frames...        |
+    // +=================================+ <-- [SP + 256]
+    // |              ...                |
+    // |                                 |
+    // | Stack frame space for x12 - x31.|
+    // |                                 |
+    // |              ...                |
+    // +---------------------------------+ <-- [SP + 96]
+    // | Saved x11(arg1), as             |
+    // | __hwasan_check_* clobbers it.   |
+    // +---------------------------------+ <-- [SP + 88]
+    // | Saved x10(arg0), as             |
+    // | __hwasan_check_* clobbers it.   |
+    // +---------------------------------+ <-- [SP + 80]
+    // |                                 |
+    // | Stack frame space for x9.       |
+    // +---------------------------------+ <-- [SP + 72]
+    // |                                 |
+    // | Saved x8(fp), as                |
+    // | __hwasan_check_* clobbers it.   |
+    // +---------------------------------+ <-- [SP + 64]
+    // |              ...                |
+    // |                                 |
+    // | Stack frame space for x2 - x7.  |
+    // |                                 |
+    // |              ...                |
+    // +---------------------------------+ <-- [SP + 16]
+    // | Return address (x1) for caller  |
+    // | of __hwasan_check_*.            |
+    // +---------------------------------+ <-- [SP + 8]
+    // | Reserved place for x0, possibly |
+    // | junk, since we don't save it.   |
+    // +---------------------------------+ <-- [x2 / SP]
+
+    // Adjust sp
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
+                                     .addReg(RISCV::X2)
+                                     .addReg(RISCV::X2)
+                                     .addImm(-256),
+                                 *STI);
+
+    // store x10(arg0) by new sp
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SD)
+                                     .addReg(RISCV::X10)
+                                     .addReg(RISCV::X2)
+                                     .addImm(8 * 10),
+                                 *STI);
+    // store x11(arg1) by new sp
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SD)
+                                     .addReg(RISCV::X11)
+                                     .addReg(RISCV::X2)
+                                     .addImm(8 * 11),
+                                 *STI);
+
+    // store x8(fp) by new sp
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::SD).addReg(RISCV::X8).addReg(RISCV::X2).addImm(8 *
+                                                                            8),
+        *STI);
+    // store x1(ra) by new sp
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::SD).addReg(RISCV::X1).addReg(RISCV::X2).addImm(1 *
+                                                                            8),
+        *STI);
+    if (Reg != RISCV::X10)
+      OutStreamer->emitInstruction(MCInstBuilder(RISCV::OR)
+                                       .addReg(RISCV::X10)
+                                       .addReg(RISCV::X0)
+                                       .addReg(Reg),
+                                   *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::ADDI)
+            .addReg(RISCV::X11)
+            .addReg(RISCV::X0)
+            .addImm(AccessInfo & HWASanAccessInfo::RuntimeMask),
+        *STI);
+
+    // Intentionally load the GOT entry and branch to it, rather than possibly
+    // late binding the function, which may clobber the registers before we have
+    // a chance to save them.
+    RISCVMCExpr::VariantKind VKHi;
+    unsigned SecondOpcode;
+    if (OutContext.getObjectFileInfo()->isPositionIndependent()) {
+      SecondOpcode = RISCV::LD;
+      VKHi = RISCVMCExpr::VK_RISCV_GOT_HI;
+    } else {
+      SecondOpcode = RISCV::ADDI;
+      VKHi = RISCVMCExpr::VK_RISCV_PCREL_HI;
+    }
+    auto ExprHi = RISCVMCExpr::create(HwasanTagMismatchRef, VKHi, OutContext);
+
+    MCSymbol *TmpLabel =
+        OutContext.createTempSymbol("pcrel_hi", /* AlwaysAddSuffix */ true);
+    OutStreamer->emitLabel(TmpLabel);
+    const MCExpr *ExprLo =
+        RISCVMCExpr::create(MCSymbolRefExpr::create(TmpLabel, OutContext),
+                            RISCVMCExpr::VK_RISCV_PCREL_LO, OutContext);
+
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X6).addExpr(ExprHi), *STI);
+    OutStreamer->emitInstruction(MCInstBuilder(SecondOpcode)
+                                     .addReg(RISCV::X6)
+                                     .addReg(RISCV::X6)
+                                     .addExpr(ExprLo),
+                                 *STI);
+
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::JALR)
+                                     .addReg(RISCV::X0)
+                                     .addReg(RISCV::X6)
+                                     .addImm(0),
+                                 *STI);
+  }
+}
diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
--- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
@@ -42,6 +42,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/NoFolder.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/Casting.h"
@@ -579,7 +580,8 @@
   UseShortGranules =
       ClUseShortGranules.getNumOccurrences() ? ClUseShortGranules : NewRuntime;
   OutlinedChecks =
-      TargetTriple.isAArch64() && TargetTriple.isOSBinFormatELF() &&
+      (TargetTriple.isAArch64() || TargetTriple.isRISCV64()) &&
+      TargetTriple.isOSBinFormatELF() &&
       (ClInlineAllChecks.getNumOccurrences() ? !ClInlineAllChecks : !Recover);
 
   if (ClMatchAllTag.getNumOccurrences()) {
@@ -602,7 +604,10 @@
     bool InstrumentGlobals =
         ClGlobals.getNumOccurrences() ? ClGlobals : NewRuntime;
 
-    if (InstrumentGlobals && !UsePageAliases)
+    // Currently we do not instrument globals for RISC-V. The reason is that
+    // the existing memory models do not allow us to use tagged pointers in
+    // la/lla expressions.
+    if (InstrumentGlobals && !UsePageAliases && !TargetTriple.isRISCV64())
       instrumentGlobals();
 
     bool InstrumentPersonalityFunctions =
@@ -791,7 +796,8 @@
 }
 
 void HWAddressSanitizer::untagPointerOperand(Instruction *I, Value *Addr) {
-  if (TargetTriple.isAArch64() || TargetTriple.getArch() == Triple::x86_64)
+  if (TargetTriple.isAArch64() || TargetTriple.getArch() == Triple::x86_64 ||
+      TargetTriple.isRISCV64())
     return;
 
   IRBuilder<> IRB(I);
@@ -828,8 +834,9 @@
   IRBuilder<> IRB(InsertBefore);
   Module *M = IRB.GetInsertBlock()->getParent()->getParent();
   Ptr = IRB.CreateBitCast(Ptr, Int8PtrTy);
+  // In case of RISC-V always use shortgranules
   IRB.CreateCall(Intrinsic::getDeclaration(
-                     M, UseShortGranules
+                     M, UseShortGranules || TargetTriple.isRISCV64()
                             ? Intrinsic::hwasan_check_memaccess_shortgranules
                             : Intrinsic::hwasan_check_memaccess),
                  {ShadowBase, Ptr, ConstantInt::get(Int32Ty, AccessInfo)});
@@ -909,6 +916,15 @@
         "{x0}",
         /*hasSideEffects=*/true);
     break;
+  case Triple::riscv64:
+    // The signal handler will find the data address in x10.
+    Asm = InlineAsm::get(
+        FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false),
+        "ebreak\naddiw x0, x0, " +
+            itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)),
+        "{x10}",
+        /*hasSideEffects=*/true);
+    break;
   default:
     report_fatal_error("unsupported architecture");
   }