diff --git a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
--- a/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
+++ b/llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
@@ -18,17 +18,24 @@
 #include "RISCVTargetMachine.h"
 #include "TargetInfo/RISCVTargetInfo.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/BinaryFormat/ELF.h"
 #include "llvm/CodeGen/AsmPrinter.h"
 #include "llvm/CodeGen/MachineConstantPool.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/MC/MCAsmInfo.h"
+#include "llvm/MC/MCContext.h"
 #include "llvm/MC/MCInst.h"
+#include "llvm/MC/MCInstBuilder.h"
+#include "llvm/MC/MCObjectFileInfo.h"
+#include "llvm/MC/MCSectionELF.h"
 #include "llvm/MC/MCStreamer.h"
 #include "llvm/MC/MCSymbol.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Instrumentation/HWAddressSanitizer.h"
+
 using namespace llvm;
 
 #define DEBUG_TYPE "asm-printer"
@@ -61,6 +68,11 @@
   bool emitPseudoExpansionLowering(MCStreamer &OutStreamer,
                                    const MachineInstr *MI);
 
+  typedef std::tuple<unsigned, uint32_t> HwasanMemaccessTuple;
+  std::map<HwasanMemaccessTuple, MCSymbol *> HwasanMemaccessSymbols;
+  void LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI);
+  void EmitHwasanMemaccessSymbols(Module &M);
+
   // Wrapper needed for tblgenned pseudo lowering.
   bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp) const {
     return lowerRISCVMachineOperandToMCOperand(MO, MCOp, *this);
@@ -99,6 +111,12 @@
     return;
 
   MCInst TmpInst;
+
+  if (MI->getOpcode() == RISCV::HWASAN_CHECK_MEMACCESS_SHORTGRANULES) {
+    LowerHWASAN_CHECK_MEMACCESS(*MI);
+    return;
+  }
+
   if (!lowerRISCVMachineInstrToMCInst(MI, TmpInst, *this))
     EmitToStreamer(*OutStreamer, TmpInst);
 }
@@ -195,6 +213,7 @@
 
   if (TM.getTargetTriple().isOSBinFormatELF())
     RTS.finishAttributeSection();
+  EmitHwasanMemaccessSymbols(M);
 }
 
 void RISCVAsmPrinter::emitAttributes() {
@@ -215,3 +234,275 @@
   RegisterAsmPrinter<RISCVAsmPrinter> X(getTheRISCV32Target());
   RegisterAsmPrinter<RISCVAsmPrinter> Y(getTheRISCV64Target());
 }
+
+// Lower a HWASAN_CHECK_MEMACCESS_SHORTGRANULES pseudo into a call to an
+// outlined check routine. One routine is created per (pointer register,
+// access info) pair; its body is emitted later by EmitHwasanMemaccessSymbols.
+void RISCVAsmPrinter::LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI) {
+  Register Reg = MI.getOperand(0).getReg();
+  uint32_t AccessInfo = MI.getOperand(1).getImm();
+  MCSymbol *&Sym =
+      HwasanMemaccessSymbols[HwasanMemaccessTuple(Reg, AccessInfo)];
+  if (!Sym) {
+    // FIXME: Make this work on non-ELF.
+    if (!TM.getTargetTriple().isOSBinFormatELF())
+      report_fatal_error("llvm.hwasan.check.memaccess only supported on ELF");
+    // The outlined routine uses SD/LD and reads the tag from bits [63:56] of
+    // the pointer, so it can only be emitted for RV64.
+    if (!TM.getTargetTriple().isArch64Bit())
+      report_fatal_error("llvm.hwasan.check.memaccess only supported on RV64");
+
+    std::string SymName = "__hwasan_check_x" + utostr(Reg - RISCV::X0) + "_" +
+                          utostr(AccessInfo) + "_short";
+    Sym = OutContext.getOrCreateSymbol(SymName);
+  }
+  auto Res = MCSymbolRefExpr::create(Sym, MCSymbolRefExpr::VK_None, OutContext);
+  auto Expr = RISCVMCExpr::create(Res, RISCVMCExpr::VK_RISCV_CALL, OutContext);
+
+  EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::PseudoCALL).addExpr(Expr));
+}
+
+// Emit the body of every outlined check routine requested by
+// LowerHWASAN_CHECK_MEMACCESS. Each routine is placed in its own COMDAT group
+// in .text.hot and marked weak and hidden. It expects the pointer in the
+// register named by the symbol and the shadow base in X5 (t0); the check
+// itself clobbers X6, X7 and X28.
+void RISCVAsmPrinter::EmitHwasanMemaccessSymbols(Module &M) {
+  if (HwasanMemaccessSymbols.empty())
+    return;
+
+  assert(TM.getTargetTriple().isOSBinFormatELF());
+
+  MCSymbol *HwasanTagMismatchV2Sym =
+      OutContext.getOrCreateSymbol("__hwasan_tag_mismatch_v2");
+
+  const MCSymbolRefExpr *HwasanTagMismatchV2Ref =
+      MCSymbolRefExpr::create(HwasanTagMismatchV2Sym, OutContext);
+
+  for (auto &P : HwasanMemaccessSymbols) {
+    unsigned Reg = std::get<0>(P.first);
+    uint32_t AccessInfo = std::get<1>(P.first);
+    const MCSymbolRefExpr *HwasanTagMismatchRef = HwasanTagMismatchV2Ref;
+    MCSymbol *Sym = P.second;
+
+    unsigned Size =
+        1 << ((AccessInfo >> HWASanAccessInfo::AccessSizeShift) & 0xf);
+    OutStreamer->switchSection(OutContext.getELFSection(
+        ".text.hot", ELF::SHT_PROGBITS,
+        ELF::SHF_EXECINSTR | ELF::SHF_ALLOC | ELF::SHF_GROUP, 0, Sym->getName(),
+        /*IsComdat=*/true));
+
+    OutStreamer->emitSymbolAttribute(Sym, MCSA_ELF_TypeFunction);
+    OutStreamer->emitSymbolAttribute(Sym, MCSA_Weak);
+    OutStreamer->emitSymbolAttribute(Sym, MCSA_Hidden);
+    OutStreamer->emitLabel(Sym);
+
+    // X6 = pointer bits [55:4]: the untagged address scaled down by the
+    // 16-byte granule size, i.e. the offset of its shadow byte.
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::SLLI).addReg(RISCV::X6).addReg(Reg).addImm(8),
+        *STI);
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SRLI)
+                                     .addReg(RISCV::X6)
+                                     .addReg(RISCV::X6)
+                                     .addImm(12),
+                                 *STI);
+    // Load the shadow (memory) tag into X6; X5 holds the shadow base.
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADD)
+                                     .addReg(RISCV::X6)
+                                     .addReg(RISCV::X5)
+                                     .addReg(RISCV::X6),
+                                 *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::LBU).addReg(RISCV::X6).addReg(RISCV::X6).addImm(0),
+        *STI);
+    // Extract the pointer tag (top byte of Reg) into X7.
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::SRLI).addReg(RISCV::X7).addReg(Reg).addImm(56),
+        *STI);
+    MCSymbol *HandleMismatchOrPartialSym = OutContext.createTempSymbol();
+    // X7 holds the pointer tag and X6 the tag loaded from shadow memory; if
+    // they differ, take the slower short-granule path below.
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BNE)
+            .addReg(RISCV::X7)
+            .addReg(RISCV::X6)
+            .addExpr(MCSymbolRefExpr::create(HandleMismatchOrPartialSym,
+                                             OutContext)),
+        *STI);
+    MCSymbol *ReturnSym = OutContext.createTempSymbol();
+    OutStreamer->emitLabel(ReturnSym);
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::JALR)
+                                     .addReg(RISCV::X0)
+                                     .addReg(RISCV::X1)
+                                     .addImm(0),
+                                 *STI);
+    OutStreamer->emitLabel(HandleMismatchOrPartialSym);
+
+    // Shadow values >= 16 can only be a genuine mismatch; smaller values are
+    // treated as the number of accessible bytes in a short granule.
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
+                                     .addReg(RISCV::X28)
+                                     .addReg(RISCV::X0)
+                                     .addImm(16),
+                                 *STI);
+    MCSymbol *HandleMismatchSym = OutContext.createTempSymbol();
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BGEU)
+            .addReg(RISCV::X6)
+            .addReg(RISCV::X28)
+            .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
+        *STI);
+
+    // X28 = offset of the last accessed byte within the granule,
+    // (ptr & 0xF) + Size - 1; it must be below the short-granule size.
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::ANDI).addReg(RISCV::X28).addReg(Reg).addImm(0xF),
+        *STI);
+
+    if (Size != 1)
+      OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
+                                       .addReg(RISCV::X28)
+                                       .addReg(RISCV::X28)
+                                       .addImm(Size - 1),
+                                   *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BGE)
+            .addReg(RISCV::X28)
+            .addReg(RISCV::X6)
+            .addExpr(MCSymbolRefExpr::create(HandleMismatchSym, OutContext)),
+        *STI);
+
+    // Load the last byte of the 16-byte granule (ptr | 0xF) and return if it
+    // equals the pointer tag in X7.
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::ORI).addReg(RISCV::X6).addReg(Reg).addImm(0xF),
+        *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::LBU).addReg(RISCV::X6).addReg(RISCV::X6).addImm(0),
+        *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::BEQ)
+            .addReg(RISCV::X6)
+            .addReg(RISCV::X7)
+            .addExpr(MCSymbolRefExpr::create(ReturnSym, OutContext)),
+        *STI);
+
+    OutStreamer->emitLabel(HandleMismatchSym);
+
+    // | Previous stack frames...        |
+    // +=================================+ <-- [SP + 256]
+    // |                ...              |
+    // |                                 |
+    // | Stack frame space for x12 - x31.|
+    // |                                 |
+    // |                ...              |
+    // +---------------------------------+ <-- [SP + 96]
+    // | Saved x11(arg1), as             |
+    // | __hwasan_check_* clobbers it.   |
+    // +---------------------------------+ <-- [SP + 88]
+    // | Saved x10(arg0), as             |
+    // | __hwasan_check_* clobbers it.   |
+    // +---------------------------------+ <-- [SP + 80]
+    // |                                 |
+    // | Stack frame space for x9.       |
+    // +---------------------------------+ <-- [SP + 72]
+    // |                                 |
+    // | Saved x8(fp), as                |
+    // | __hwasan_check_* clobbers it.   |
+    // +---------------------------------+ <-- [SP + 64]
+    // |                ...              |
+    // |                                 |
+    // | Stack frame space for x2 - x7.  |
+    // |                                 |
+    // |                ...              |
+    // +---------------------------------+ <-- [SP + 16]
+    // | Return address (x1) for caller  |
+    // | of __hwasan_check_*.            |
+    // +---------------------------------+ <-- [SP + 8]
+    // | Reserved place for x0, possibly |
+    // | junk, since we don't save it.   |
+    // +---------------------------------+ <-- [x2 / SP]
+
+    // Allocate the 256-byte frame shown above.
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::ADDI)
+                                     .addReg(RISCV::X2)
+                                     .addReg(RISCV::X2)
+                                     .addImm(-256),
+                                 *STI);
+
+    // store x10(arg0) by new sp
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SD)
+                                     .addReg(RISCV::X10)
+                                     .addReg(RISCV::X2)
+                                     .addImm(8 * 10),
+                                 *STI);
+    // store x11(arg1) by new sp
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SD)
+                                     .addReg(RISCV::X11)
+                                     .addReg(RISCV::X2)
+                                     .addImm(8 * 11),
+                                 *STI);
+
+    // store x8(fp) by new sp
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SD)
+                                     .addReg(RISCV::X8)
+                                     .addReg(RISCV::X2)
+                                     .addImm(8 * 8),
+                                 *STI);
+    // store x1(ra) by new sp
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::SD)
+                                     .addReg(RISCV::X1)
+                                     .addReg(RISCV::X2)
+                                     .addImm(1 * 8),
+                                 *STI);
+    // x10 = the checked pointer, x11 = runtime access info.
+    if (Reg != RISCV::X10)
+      OutStreamer->emitInstruction(MCInstBuilder(RISCV::OR)
+                                       .addReg(RISCV::X10)
+                                       .addReg(RISCV::X0)
+                                       .addReg(Reg),
+                                   *STI);
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::ADDI)
+            .addReg(RISCV::X11)
+            .addReg(RISCV::X0)
+            .addImm(AccessInfo & HWASanAccessInfo::RuntimeMask),
+        *STI);
+
+    // Materialize the address of __hwasan_tag_mismatch_v2 (via its GOT entry
+    // when PIC) and branch to it, rather than possibly late binding the
+    // function, which may clobber registers before the runtime can save them.
+    RISCVMCExpr::VariantKind VKHi;
+    unsigned SecondOpcode;
+    if (OutContext.getObjectFileInfo()->isPositionIndependent()) {
+      SecondOpcode = RISCV::LD;
+      VKHi = RISCVMCExpr::VK_RISCV_GOT_HI;
+    } else {
+      SecondOpcode = RISCV::ADDI;
+      VKHi = RISCVMCExpr::VK_RISCV_PCREL_HI;
+    }
+    auto ExprHi = RISCVMCExpr::create(HwasanTagMismatchRef, VKHi, OutContext);
+
+    MCSymbol *TmpLabel =
+        OutContext.createTempSymbol("pcrel_hi", /* AlwaysAddSuffix */ true);
+    OutStreamer->emitLabel(TmpLabel);
+    const MCExpr *ExprLo =
+        RISCVMCExpr::create(MCSymbolRefExpr::create(TmpLabel, OutContext),
+                            RISCVMCExpr::VK_RISCV_PCREL_LO, OutContext);
+
+    OutStreamer->emitInstruction(
+        MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X6).addExpr(ExprHi), *STI);
+    OutStreamer->emitInstruction(MCInstBuilder(SecondOpcode)
+                                     .addReg(RISCV::X6)
+                                     .addReg(RISCV::X6)
+                                     .addExpr(ExprLo),
+                                 *STI);
+
+    OutStreamer->emitInstruction(MCInstBuilder(RISCV::JALR)
+                                     .addReg(RISCV::X0)
+                                     .addReg(RISCV::X6)
+                                     .addImm(0),
+                                 *STI);
+  }
+}
diff --git a/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll b/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/hwasan-check-memaccess.ll
@@ -0,0 +1,60 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 < %s | FileCheck --check-prefixes=CHECK,NOPIC %s
+; RUN: llc -mtriple=riscv64 --relocation-model=pic < %s | FileCheck --check-prefixes=CHECK,PIC %s
+
+define i8* @f2(i8* %x0, i8* %x1) {
+; CHECK-LABEL: f2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi sp, sp, -16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_offset ra, -8
+; CHECK-NEXT:    mv t0, a1
+; CHECK-NEXT:    call __hwasan_check_x10_2_short
+; CHECK-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    addi sp, sp, 16
+; CHECK-NEXT:    ret
+  call void @llvm.hwasan.check.memaccess.shortgranules(i8* %x1, i8* %x0, i32 2)
+  ret i8* %x0
+}
+
+declare void @llvm.hwasan.check.memaccess.shortgranules(i8*, i8*, i32)
+
+; CHECK:      .section .text.hot,"axG",@progbits,__hwasan_check_x10_2_short,comdat
+; CHECK-NEXT: .type __hwasan_check_x10_2_short,@function
+; CHECK-NEXT: .weak __hwasan_check_x10_2_short
+; CHECK-NEXT: .hidden __hwasan_check_x10_2_short
+; CHECK-NEXT: __hwasan_check_x10_2_short:
+; CHECK-NEXT: slli t1, a0, 8
+; CHECK-NEXT: srli t1, t1, 12
+; CHECK-NEXT: add t1, t0, t1
+; CHECK-NEXT: lbu t1, 0(t1)
+; CHECK-NEXT: srli t2, a0, 56
+; CHECK-NEXT: bne t2, t1, .Ltmp0
+; CHECK-NEXT: .Ltmp1:
+; CHECK-NEXT: ret
+; CHECK-NEXT: .Ltmp0:
+; CHECK-NEXT: li t3, 16
+; CHECK-NEXT: bgeu t1, t3, .Ltmp2
+; CHECK-NEXT: andi t3, a0, 15
+; CHECK-NEXT: addi t3, t3, 3
+; CHECK-NEXT: bge t3, t1, .Ltmp2
+; CHECK-NEXT: ori t1, a0, 15
+; CHECK-NEXT: lbu t1, 0(t1)
+; CHECK-NEXT: beq t1, t2, .Ltmp1
+; CHECK-NEXT: .Ltmp2:
+; CHECK-NEXT: addi sp, sp, -256
+; CHECK-NEXT: sd a0, 80(sp)
+; CHECK-NEXT: sd a1, 88(sp)
+; CHECK-NEXT: sd s0, 64(sp)
+; CHECK-NEXT: sd ra, 8(sp)
+; CHECK-NEXT: li a1, 2
+; CHECK-NEXT: .Lpcrel_hi0:
+; NOPIC-NEXT: auipc t1, %pcrel_hi(__hwasan_tag_mismatch_v2)
+; NOPIC-NEXT: addi t1, t1, %pcrel_lo(.Lpcrel_hi0)
+; PIC-NEXT: auipc t1, %got_pcrel_hi(__hwasan_tag_mismatch_v2)
+; PIC-NEXT: ld t1, %pcrel_lo(.Lpcrel_hi0)(t1)
+; CHECK-NEXT: jr t1
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; NOPIC: {{.*}}
+; PIC: {{.*}}