Extract the shadow-tag check of instrumentMemAccessInline() into a helper.

insertShadowTagCheck() emits the common prefix of the inline check before
InsertBefore:
  * ptrtoint of the pointer;
  * the pointer tag (the bits above PointerTagShift, truncated to i8);
  * the untagged address;
  * a load of the shadow tag;
  * an icmp ne of the two tags.
When MatchAllTag is set, pointers carrying that tag are exempt. The block is
then split with 1:100000 branch weights. The helper returns the terminator of
the mismatch block and the intermediate values in ShadowTagCheckInfo.
instrumentMemAccessInline() uses them to emit the short-granule and trap
sequence.

instrumentMemAccessOutline() now constructs its IRBuilder after the (still
empty) InlineFastPath block instead of before it. The emitted IR is intended
to be unchanged.

diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
--- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
@@ -280,6 +280,13 @@
   void sanitizeFunction(Function &F, FunctionAnalysisManager &FAM);
 
 private:
+  struct ShadowTagCheckInfo {
+    Instruction *TagMismatchTerm = nullptr;
+    Value *PtrLong = nullptr;
+    Value *AddrLong = nullptr;
+    Value *PtrTag = nullptr;
+    Value *MemTag = nullptr;
+  };
   void setSSI(const StackSafetyGlobalInfo *S) { SSI = S; }
 
   void initializeModule();
@@ -296,6 +303,8 @@
   Value *memToShadow(Value *Shadow, IRBuilder<> &IRB);
 
   int64_t getAccessInfo(bool IsWrite, unsigned AccessSizeIndex);
+  ShadowTagCheckInfo insertShadowTagCheck(Value *Ptr,
+                                          Instruction *InsertBefore);
   void instrumentMemAccessOutline(Value *Ptr, bool IsWrite,
                                   unsigned AccessSizeIndex,
                                   Instruction *InsertBefore);
@@ -851,17 +860,45 @@
          (AccessSizeIndex << HWASanAccessInfo::AccessSizeShift);
 }
 
+HWAddressSanitizer::ShadowTagCheckInfo
+HWAddressSanitizer::insertShadowTagCheck(Value *Ptr,
+                                         Instruction *InsertBefore) {
+  ShadowTagCheckInfo R;
+
+  IRBuilder<> IRB(InsertBefore);
+
+  R.PtrLong = IRB.CreatePointerCast(Ptr, IntptrTy);
+  R.PtrTag =
+      IRB.CreateTrunc(IRB.CreateLShr(R.PtrLong, PointerTagShift), Int8Ty);
+  R.AddrLong = untagPointer(IRB, R.PtrLong);
+  Value *Shadow = memToShadow(R.AddrLong, IRB);
+  R.MemTag = IRB.CreateLoad(Int8Ty, Shadow);
+  Value *TagMismatch = IRB.CreateICmpNE(R.PtrTag, R.MemTag);
+
+  if (MatchAllTag.has_value()) {
+    Value *TagNotIgnored = IRB.CreateICmpNE(
+        R.PtrTag, ConstantInt::get(R.PtrTag->getType(), *MatchAllTag));
+    TagMismatch = IRB.CreateAnd(TagMismatch, TagNotIgnored);
+  }
+
+  R.TagMismatchTerm =
+      SplitBlockAndInsertIfThen(TagMismatch, InsertBefore, false,
+                                MDBuilder(*C).createBranchWeights(1, 100000));
+
+  return R;
+}
+
 void HWAddressSanitizer::instrumentMemAccessOutline(Value *Ptr, bool IsWrite,
                                                     unsigned AccessSizeIndex,
                                                     Instruction *InsertBefore) {
   assert(!UsePageAliases);
   const int64_t AccessInfo = getAccessInfo(IsWrite, AccessSizeIndex);
-  IRBuilder<> IRB(InsertBefore);
 
   if (InlineFastPath) {
     // TODO.
   }
 
+  IRBuilder<> IRB(InsertBefore);
   Module *M = IRB.GetInsertBlock()->getParent()->getParent();
   Ptr = IRB.CreateBitCast(Ptr, Int8PtrTy);
   IRB.CreateCall(Intrinsic::getDeclaration(
@@ -876,49 +913,32 @@
                                                    Instruction *InsertBefore) {
   assert(!UsePageAliases);
   const int64_t AccessInfo = getAccessInfo(IsWrite, AccessSizeIndex);
-  IRBuilder<> IRB(InsertBefore);
-
-  Value *PtrLong = IRB.CreatePointerCast(Ptr, IntptrTy);
-  Value *PtrTag =
-      IRB.CreateTrunc(IRB.CreateLShr(PtrLong, PointerTagShift), Int8Ty);
-  Value *AddrLong = untagPointer(IRB, PtrLong);
-  Value *Shadow = memToShadow(AddrLong, IRB);
-  Value *MemTag = IRB.CreateLoad(Int8Ty, Shadow);
-  Value *TagMismatch = IRB.CreateICmpNE(PtrTag, MemTag);
 
-  if (MatchAllTag.has_value()) {
-    Value *TagNotIgnored = IRB.CreateICmpNE(
-        PtrTag, ConstantInt::get(PtrTag->getType(), *MatchAllTag));
-    TagMismatch = IRB.CreateAnd(TagMismatch, TagNotIgnored);
-  }
-
-  Instruction *TagMismatchTerm =
-      SplitBlockAndInsertIfThen(TagMismatch, InsertBefore, false,
-                                MDBuilder(*C).createBranchWeights(1, 100000));
+  ShadowTagCheckInfo TCI = insertShadowTagCheck(Ptr, InsertBefore);
 
-  IRB.SetInsertPoint(TagMismatchTerm);
+  IRBuilder<> IRB(TCI.TagMismatchTerm);
   Value *OutOfShortGranuleTagRange =
-      IRB.CreateICmpUGT(MemTag, ConstantInt::get(Int8Ty, 15));
+      IRB.CreateICmpUGT(TCI.MemTag, ConstantInt::get(Int8Ty, 15));
   Instruction *CheckFailTerm = SplitBlockAndInsertIfThen(
-      OutOfShortGranuleTagRange, TagMismatchTerm, !Recover,
+      OutOfShortGranuleTagRange, TCI.TagMismatchTerm, !Recover,
       MDBuilder(*C).createBranchWeights(1, 100000));
 
-  IRB.SetInsertPoint(TagMismatchTerm);
-  Value *PtrLowBits = IRB.CreateTrunc(IRB.CreateAnd(PtrLong, 15), Int8Ty);
+  IRB.SetInsertPoint(TCI.TagMismatchTerm);
+  Value *PtrLowBits = IRB.CreateTrunc(IRB.CreateAnd(TCI.PtrLong, 15), Int8Ty);
   PtrLowBits = IRB.CreateAdd(
       PtrLowBits, ConstantInt::get(Int8Ty, (1 << AccessSizeIndex) - 1));
-  Value *PtrLowBitsOOB = IRB.CreateICmpUGE(PtrLowBits, MemTag);
-  SplitBlockAndInsertIfThen(PtrLowBitsOOB, TagMismatchTerm, false,
+  Value *PtrLowBitsOOB = IRB.CreateICmpUGE(PtrLowBits, TCI.MemTag);
+  SplitBlockAndInsertIfThen(PtrLowBitsOOB, TCI.TagMismatchTerm, false,
                             MDBuilder(*C).createBranchWeights(1, 100000),
                             (DomTreeUpdater *)nullptr, nullptr,
                             CheckFailTerm->getParent());
 
-  IRB.SetInsertPoint(TagMismatchTerm);
-  Value *InlineTagAddr = IRB.CreateOr(AddrLong, 15);
+  IRB.SetInsertPoint(TCI.TagMismatchTerm);
+  Value *InlineTagAddr = IRB.CreateOr(TCI.AddrLong, 15);
   InlineTagAddr = IRB.CreateIntToPtr(InlineTagAddr, Int8PtrTy);
   Value *InlineTag = IRB.CreateLoad(Int8Ty, InlineTagAddr);
-  Value *InlineTagMismatch = IRB.CreateICmpNE(PtrTag, InlineTag);
-  SplitBlockAndInsertIfThen(InlineTagMismatch, TagMismatchTerm, false,
+  Value *InlineTagMismatch = IRB.CreateICmpNE(TCI.PtrTag, InlineTag);
+  SplitBlockAndInsertIfThen(InlineTagMismatch, TCI.TagMismatchTerm, false,
                             MDBuilder(*C).createBranchWeights(1, 100000),
                             (DomTreeUpdater *)nullptr, nullptr,
                             CheckFailTerm->getParent());
@@ -929,7 +949,7 @@
   case Triple::x86_64:
     // The signal handler will find the data address in rdi.
     Asm = InlineAsm::get(
-        FunctionType::get(VoidTy, {PtrLong->getType()}, false),
+        FunctionType::get(VoidTy, {TCI.PtrLong->getType()}, false),
         "int3\nnopl " +
             itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)) +
             "(%rax)",
@@ -940,7 +960,7 @@
   case Triple::aarch64_be:
     // The signal handler will find the data address in x0.
     Asm = InlineAsm::get(
-        FunctionType::get(VoidTy, {PtrLong->getType()}, false),
+        FunctionType::get(VoidTy, {TCI.PtrLong->getType()}, false),
         "brk #" + itostr(0x900 + (AccessInfo & HWASanAccessInfo::RuntimeMask)),
         "{x0}",
         /*hasSideEffects=*/true);
@@ -948,7 +968,7 @@
   case Triple::riscv64:
     // The signal handler will find the data address in x10.
     Asm = InlineAsm::get(
-        FunctionType::get(VoidTy, {PtrLong->getType()}, false),
+        FunctionType::get(VoidTy, {TCI.PtrLong->getType()}, false),
         "ebreak\naddiw x0, x11, " +
             itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)),
         "{x10}",
@@ -957,10 +977,10 @@
   default:
     report_fatal_error("unsupported architecture");
   }
-  IRB.CreateCall(Asm, PtrLong);
+  IRB.CreateCall(Asm, TCI.PtrLong);
   if (Recover)
     cast<BranchInst>(CheckFailTerm)
-        ->setSuccessor(0, TagMismatchTerm->getParent());
+        ->setSuccessor(0, TCI.TagMismatchTerm->getParent());
 }
 
 bool HWAddressSanitizer::ignoreMemIntrinsic(MemIntrinsic *MI) {