diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp --- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -623,24 +623,33 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) { IRBuilder<> IRB(*C); - const std::string EndingStr = Recover ? "_noabort" : ""; const std::string MatchAllStr = UseMatchAllCallback ? "_match_all" : ""; + FunctionType *HwasanMemoryAccessCallbackSizedFnTy, + *HwasanMemoryAccessCallbackFnTy, *HWAsanMemTransferFnTy, + *HWAsanMemsetFnTy; + if (UseMatchAllCallback) { + HwasanMemoryAccessCallbackSizedFnTy = + FunctionType::get(VoidTy, {IntptrTy, IntptrTy, Int8Ty}, false); + HwasanMemoryAccessCallbackFnTy = + FunctionType::get(VoidTy, {IntptrTy, Int8Ty}, false); + HWAsanMemTransferFnTy = FunctionType::get( + Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy, Int8Ty}, false); + HWAsanMemsetFnTy = FunctionType::get( + Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy, Int8Ty}, false); + } else { + HwasanMemoryAccessCallbackSizedFnTy = + FunctionType::get(VoidTy, {IntptrTy, IntptrTy}, false); + HwasanMemoryAccessCallbackFnTy = + FunctionType::get(VoidTy, {IntptrTy}, false); + HWAsanMemTransferFnTy = + FunctionType::get(Int8PtrTy, {Int8PtrTy, Int8PtrTy, IntptrTy}, false); + HWAsanMemsetFnTy = + FunctionType::get(Int8PtrTy, {Int8PtrTy, Int32Ty, IntptrTy}, false); + } + for (size_t AccessIsWrite = 0; AccessIsWrite <= 1; AccessIsWrite++) { const std::string TypeStr = AccessIsWrite ? "store" : "load"; - - FunctionType *HwasanMemoryAccessCallbackSizedFnTy, - *HwasanMemoryAccessCallbackFnTy; - if (UseMatchAllCallback) { - HwasanMemoryAccessCallbackSizedFnTy = - FunctionType::get(VoidTy, {IntptrTy, IntptrTy, Int8Ty}, false); - HwasanMemoryAccessCallbackFnTy = - FunctionType::get(VoidTy, {IntptrTy, Int8Ty}, false); - } else { - HwasanMemoryAccessCallbackSizedFnTy = - FunctionType::get(VoidTy, {IntptrTy, IntptrTy}, false); - HwasanMemoryAccessCallbackFnTy = - FunctionType::get(VoidTy, {IntptrTy}, false); - } + const std::string EndingStr = Recover ? "_noabort" : ""; HwasanMemoryAccessCallbackSized[AccessIsWrite] = M.getOrInsertFunction( ClMemoryAccessCallbackPrefix + TypeStr + "N" + MatchAllStr + EndingStr, @@ -656,33 +665,11 @@ } } - HwasanTagMemoryFunc = M.getOrInsertFunction("__hwasan_tag_memory", VoidTy, - Int8PtrTy, Int8Ty, IntptrTy); - HwasanGenerateTagFunc = - M.getOrInsertFunction("__hwasan_generate_tag", Int8Ty); - - HwasanRecordFrameRecordFunc = - M.getOrInsertFunction("__hwasan_add_frame_record", VoidTy, Int64Ty); - - ShadowGlobal = - M.getOrInsertGlobal("__hwasan_shadow", ArrayType::get(Int8Ty, 0)); - const std::string MemIntrinCallbackPrefix = (CompileKernel && !ClKasanMemIntrinCallbackPrefix) ? std::string("") : ClMemoryAccessCallbackPrefix; - const size_t HWAsanMemTransferArgSize = UseMatchAllCallback ? 4 : 3; - llvm::Type *HWAsanMemTransferArgTys[] = {Int8PtrTy, Int8PtrTy, IntptrTy, - Int8Ty}; - FunctionType *HWAsanMemTransferFnTy = FunctionType::get( - Int8PtrTy, - ArrayRef(&HWAsanMemTransferArgTys[0], HWAsanMemTransferArgSize), false); - const size_t HWAsanMemsetArgSize = UseMatchAllCallback ? 4 : 3; - llvm::Type *HWAsanMemsetArgTys[] = {Int8PtrTy, Int32Ty, IntptrTy, Int8Ty}; - FunctionType *HWAsanMemsetFnTy = FunctionType::get( - Int8PtrTy, ArrayRef(&HWAsanMemsetArgTys[0], HWAsanMemsetArgSize), false); - HWAsanMemmove = M.getOrInsertFunction( MemIntrinCallbackPrefix + "memmove" + MatchAllStr, HWAsanMemTransferFnTy); HWAsanMemcpy = M.getOrInsertFunction( @@ -690,6 +677,17 @@ HWAsanMemset = M.getOrInsertFunction( MemIntrinCallbackPrefix + "memset" + MatchAllStr, HWAsanMemsetFnTy); + HwasanTagMemoryFunc = M.getOrInsertFunction("__hwasan_tag_memory", VoidTy, + Int8PtrTy, Int8Ty, IntptrTy); + HwasanGenerateTagFunc = + M.getOrInsertFunction("__hwasan_generate_tag", Int8Ty); + + HwasanRecordFrameRecordFunc = + M.getOrInsertFunction("__hwasan_add_frame_record", VoidTy, Int64Ty); + + ShadowGlobal = + M.getOrInsertGlobal("__hwasan_shadow", ArrayType::get(Int8Ty, 0)); + HWAsanHandleVfork = M.getOrInsertFunction("__hwasan_handle_vfork", VoidTy, IntptrTy); }