diff --git a/compiler-rt/lib/hwasan/hwasan_interface_internal.h b/compiler-rt/lib/hwasan/hwasan_interface_internal.h --- a/compiler-rt/lib/hwasan/hwasan_interface_internal.h +++ b/compiler-rt/lib/hwasan/hwasan_interface_internal.h @@ -235,6 +235,13 @@ SANITIZER_INTERFACE_ATTRIBUTE void *__hwasan_memmove(void *dest, const void *src, uptr n); +SANITIZER_INTERFACE_ATTRIBUTE +void *__hwasan_memcpy_match_all(void *dst, const void *src, uptr size, u8); +SANITIZER_INTERFACE_ATTRIBUTE +void *__hwasan_memset_match_all(void *s, int c, uptr n, u8); +SANITIZER_INTERFACE_ATTRIBUTE +void *__hwasan_memmove_match_all(void *dest, const void *src, uptr n, u8); + SANITIZER_INTERFACE_ATTRIBUTE void __hwasan_set_error_report_callback(void (*callback)(const char *)); } // extern "C"
diff --git a/compiler-rt/lib/hwasan/hwasan_memintrinsics.cpp b/compiler-rt/lib/hwasan/hwasan_memintrinsics.cpp --- a/compiler-rt/lib/hwasan/hwasan_memintrinsics.cpp +++ b/compiler-rt/lib/hwasan/hwasan_memintrinsics.cpp @@ -42,3 +42,36 @@ reinterpret_cast<uptr>(from), size); return memmove(to, from, size); } + +// Match-all variants of __hwasan_mem{set,cpy,move}, called by code +// instrumented with a match-all tag (e.g. -hwasan-match-all-tag=N). An +// argument whose pointer tag equals match_all_tag is not checked. +void *__hwasan_memset_match_all(void *block, int c, uptr size, + u8 match_all_tag) { + if (GetTagFromPointer(reinterpret_cast<uptr>(block)) != match_all_tag) + CheckAddressSized<ErrorAction::Recover, AccessType::Store>( + reinterpret_cast<uptr>(block), size); + return memset(block, c, size); +} + +void *__hwasan_memcpy_match_all(void *to, const void *from, uptr size, + u8 match_all_tag) { + if (GetTagFromPointer(reinterpret_cast<uptr>(to)) != match_all_tag) + CheckAddressSized<ErrorAction::Recover, AccessType::Store>( + reinterpret_cast<uptr>(to), size); + if (GetTagFromPointer(reinterpret_cast<uptr>(from)) != match_all_tag) + CheckAddressSized<ErrorAction::Recover, AccessType::Load>( + reinterpret_cast<uptr>(from), size); + return memcpy(to, from, size); +} + +void *__hwasan_memmove_match_all(void *to, const void *from, uptr size, + u8 match_all_tag) { + if (GetTagFromPointer(reinterpret_cast<uptr>(to)) != match_all_tag) + CheckAddressSized<ErrorAction::Recover, AccessType::Store>( + reinterpret_cast<uptr>(to), size); + if (GetTagFromPointer(reinterpret_cast<uptr>(from)) != match_all_tag) + CheckAddressSized<ErrorAction::Recover, AccessType::Load>( + reinterpret_cast<uptr>(from), size); + return memmove(to, from, size); +}
diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp --- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -623,10 +623,10 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) { IRBuilder<> IRB(*C); + const std::string EndingStr = Recover ? "_noabort" : ""; + const std::string MatchAllStr = UseMatchAllCallback ? "_match_all" : ""; for (size_t AccessIsWrite = 0; AccessIsWrite <= 1; AccessIsWrite++) { const std::string TypeStr = AccessIsWrite ? "store" : "load"; - const std::string EndingStr = Recover ? "_noabort" : ""; - const std::string MatchAllStr = UseMatchAllCallback ? "_match_all" : ""; const size_t CallbackSizedArgSize = UseMatchAllCallback ? 3 : 2; llvm::Type *CallbackSizedArgTys[] = {IntptrTy, IntptrTy, Int8Ty}; @@ -665,14 +665,24 @@ (CompileKernel && !ClKasanMemIntrinCallbackPrefix) ? std::string("") : ClMemoryAccessCallbackPrefix; - HWAsanMemmove = - M.getOrInsertFunction(MemIntrinCallbackPrefix + "memmove", Int8PtrTy, - Int8PtrTy, Int8PtrTy, IntptrTy); - HWAsanMemcpy = - M.getOrInsertFunction(MemIntrinCallbackPrefix + "memcpy", Int8PtrTy, - Int8PtrTy, Int8PtrTy, IntptrTy); - HWAsanMemset = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memset", - Int8PtrTy, Int8PtrTy, Int32Ty, IntptrTy); + + const size_t HWAsanMemTransferArgSize = UseMatchAllCallback ? 4 : 3; + llvm::Type *HWAsanMemTransferArgTys[] = {Int8PtrTy, Int8PtrTy, IntptrTy, + Int8Ty}; + FunctionType *HWAsanMemTransferFnTy = FunctionType::get( + Int8PtrTy, + ArrayRef(&HWAsanMemTransferArgTys[0], HWAsanMemTransferArgSize), false); + const size_t HWAsanMemsetArgSize = UseMatchAllCallback ? 4 : 3; + llvm::Type *HWAsanMemsetArgTys[] = {Int8PtrTy, Int32Ty, IntptrTy, Int8Ty}; + FunctionType *HWAsanMemsetFnTy = FunctionType::get( + Int8PtrTy, ArrayRef(&HWAsanMemsetArgTys[0], HWAsanMemsetArgSize), false); + + HWAsanMemmove = M.getOrInsertFunction( + MemIntrinCallbackPrefix + "memmove" + MatchAllStr, HWAsanMemTransferFnTy); + HWAsanMemcpy = M.getOrInsertFunction( + MemIntrinCallbackPrefix + "memcpy" + MatchAllStr, HWAsanMemTransferFnTy); + HWAsanMemset = M.getOrInsertFunction( + MemIntrinCallbackPrefix + "memset" + MatchAllStr, HWAsanMemsetFnTy); HWAsanHandleVfork = M.getOrInsertFunction("__hwasan_handle_vfork", VoidTy, IntptrTy); @@ -943,15 +953,22 @@ void HWAddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { IRBuilder<> IRB(MI); if (isa<MemTransferInst>(MI)) { - IRB.CreateCall(isa<MemMoveInst>(MI) ? HWAsanMemmove : HWAsanMemcpy, - {IRB.CreatePointerCast(MI->getOperand(0), Int8PtrTy), - IRB.CreatePointerCast(MI->getOperand(1), Int8PtrTy), - IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); + SmallVector<Value *, 4> Args{ + IRB.CreatePointerCast(MI->getOperand(0), Int8PtrTy), + IRB.CreatePointerCast(MI->getOperand(1), Int8PtrTy), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}; + // The *_match_all runtime entry points take the tag as a trailing i8. + if (UseMatchAllCallback) + Args.emplace_back(ConstantInt::get(Int8Ty, *MatchAllTag)); + IRB.CreateCall(isa<MemMoveInst>(MI) ? HWAsanMemmove : HWAsanMemcpy, Args); } else if (isa<MemSetInst>(MI)) { - IRB.CreateCall(HWAsanMemset, - {IRB.CreatePointerCast(MI->getOperand(0), Int8PtrTy), - IRB.CreateIntCast(MI->getOperand(1), Int32Ty, false), - IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); + SmallVector<Value *, 4> Args{ + IRB.CreatePointerCast(MI->getOperand(0), Int8PtrTy), + IRB.CreateIntCast(MI->getOperand(1), Int32Ty, false), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}; + if (UseMatchAllCallback) + Args.emplace_back(ConstantInt::get(Int8Ty, *MatchAllTag)); + IRB.CreateCall(HWAsanMemset, Args); } MI->eraseFromParent(); }
diff --git a/llvm/test/Instrumentation/HWAddressSanitizer/mem-intrinsics.ll b/llvm/test/Instrumentation/HWAddressSanitizer/mem-intrinsics.ll --- a/llvm/test/Instrumentation/HWAddressSanitizer/mem-intrinsics.ll +++ b/llvm/test/Instrumentation/HWAddressSanitizer/mem-intrinsics.ll @@ -1,6 +1,7 @@ ; RUN: opt -S -passes=hwasan -hwasan-use-stack-safety=0 %s | FileCheck --check-prefixes=CHECK,CHECK-PREFIX %s ; RUN: opt -S -passes=hwasan -hwasan-kernel -hwasan-use-stack-safety=0 %s | FileCheck --check-prefixes=CHECK,CHECK-NOPREFIX %s ; RUN: opt -S -passes=hwasan -hwasan-kernel -hwasan-kernel-mem-intrinsic-prefix -hwasan-use-stack-safety=0 %s | FileCheck --check-prefixes=CHECK,CHECK-PREFIX %s +; RUN: opt -S -passes=hwasan -hwasan-use-stack-safety=0 -hwasan-match-all-tag=0 %s | FileCheck --check-prefixes=CHECK,CHECK-MATCH-ALL-TAG %s target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" target triple = "x86_64-unknown-linux-gnu" @@ -15,19 +16,22 @@ store i32 0, ptr %retval, align 4 call void @llvm.memset.p0.i64(ptr align 1 %Q, i8 0, i64 10, i1 false) -; CHECK-PREFIX: call ptr @__hwasan_memset -; CHECK-NOPREFIX: call ptr @memset +; CHECK-PREFIX: call ptr @__hwasan_memset( +; CHECK-NOPREFIX: call ptr @memset( +; CHECK-MATCH-ALL-TAG: call ptr @__hwasan_memset_match_all(ptr %Q.hwasan, i32 0, i64 10, i8 0) %add.ptr = getelementptr inbounds i8, ptr %Q, i64 5 call void @llvm.memmove.p0.p0.i64(ptr align 1 %Q, ptr align 1 %add.ptr, i64 5, i1 false) -; CHECK-PREFIX: call ptr @__hwasan_memmove -; CHECK-NOPREFIX: call ptr @memmove +; CHECK-PREFIX: call ptr @__hwasan_memmove( +; CHECK-NOPREFIX: call ptr @memmove( +; CHECK-MATCH-ALL-TAG: call ptr @__hwasan_memmove_match_all(ptr %Q.hwasan, ptr %add.ptr, i64 5, i8 0) call void @llvm.memcpy.p0.p0.i64(ptr align 1 %P, ptr align 1 %Q, i64 10, i1 false) -; CHECK-PREFIX: call ptr @__hwasan_memcpy -; CHECK-NOPREFIX: call ptr @memcpy +; CHECK-PREFIX: call ptr @__hwasan_memcpy( +; CHECK-NOPREFIX: call ptr @memcpy( +; CHECK-MATCH-ALL-TAG: call ptr @__hwasan_memcpy_match_all(ptr %P.hwasan, ptr %Q.hwasan, i64 10, i8 0) ret i32 0 }