Index: include/clang/Basic/Features.def
===================================================================
--- include/clang/Basic/Features.def
+++ include/clang/Basic/Features.def
@@ -72,7 +72,9 @@
 FEATURE(enumerator_attributes, true)
 FEATURE(nullability, true)
 FEATURE(nullability_on_arrays, true)
-FEATURE(memory_sanitizer, LangOpts.Sanitize.has(SanitizerKind::Memory))
+FEATURE(memory_sanitizer,
+        LangOpts.Sanitize.hasOneOf(SanitizerKind::Memory |
+                                   SanitizerKind::KernelMemory))
 FEATURE(thread_sanitizer, LangOpts.Sanitize.has(SanitizerKind::Thread))
 FEATURE(dataflow_sanitizer, LangOpts.Sanitize.has(SanitizerKind::DataFlow))
 FEATURE(efficiency_sanitizer,
Index: include/clang/Basic/Sanitizers.def
===================================================================
--- include/clang/Basic/Sanitizers.def
+++ include/clang/Basic/Sanitizers.def
@@ -53,6 +53,9 @@
 // MemorySanitizer
 SANITIZER("memory", Memory)
 
+// Kernel MemorySanitizer (KMSAN)
+SANITIZER("kernel-memory", KernelMemory)
+
 // libFuzzer
 SANITIZER("fuzzer", Fuzzer)
 
Index: lib/CodeGen/BackendUtil.cpp
===================================================================
--- lib/CodeGen/BackendUtil.cpp
+++ lib/CodeGen/BackendUtil.cpp
@@ -265,14 +265,15 @@
       /*CompileKernel*/ true, /*Recover*/ true));
 }
 
-static void addMemorySanitizerPass(const PassManagerBuilder &Builder,
-                                   legacy::PassManagerBase &PM) {
+static void addGeneralOptsForMemorySanitizer(const PassManagerBuilder &Builder,
+                                             legacy::PassManagerBase &PM,
+                                             bool CompileKernel) {
   const PassManagerBuilderWrapper &BuilderWrapper =
       static_cast<const PassManagerBuilderWrapper&>(Builder);
   const CodeGenOptions &CGOpts = BuilderWrapper.getCGOpts();
   int TrackOrigins = CGOpts.SanitizeMemoryTrackOrigins;
   bool Recover = CGOpts.SanitizeRecover.has(SanitizerKind::Memory);
-  PM.add(createMemorySanitizerPass(TrackOrigins, Recover));
+  PM.add(createMemorySanitizerPass(TrackOrigins, Recover, CompileKernel));
 
   // MemorySanitizer inserts complex instrumentation that mostly follows
   // the logic of the original code, but operates on "shadow" values.
@@ -287,6 +288,16 @@
   }
 }
 
+static void addMemorySanitizerPass(const PassManagerBuilder &Builder,
+                                   legacy::PassManagerBase &PM) {
+  addGeneralOptsForMemorySanitizer(Builder, PM, /*CompileKernel*/ false);
+}
+
+static void addKernelMemorySanitizerPass(const PassManagerBuilder &Builder,
+                                         legacy::PassManagerBase &PM) {
+  addGeneralOptsForMemorySanitizer(Builder, PM, /*CompileKernel*/ true);
+}
+
 static void addThreadSanitizerPass(const PassManagerBuilder &Builder,
                                    legacy::PassManagerBase &PM) {
   PM.add(createThreadSanitizerPass());
@@ -613,6 +624,13 @@
                            addMemorySanitizerPass);
   }
 
+  if (LangOpts.Sanitize.has(SanitizerKind::KernelMemory)) {
+    PMBuilder.addExtension(PassManagerBuilder::EP_OptimizerLast,
+                           addKernelMemorySanitizerPass);
+    PMBuilder.addExtension(PassManagerBuilder::EP_EnabledOnOptLevel0,
+                           addKernelMemorySanitizerPass);
+  }
+
   if (LangOpts.Sanitize.has(SanitizerKind::Thread)) {
     PMBuilder.addExtension(PassManagerBuilder::EP_OptimizerLast,
                            addThreadSanitizerPass);
Index: lib/CodeGen/CGDeclCXX.cpp
===================================================================
--- lib/CodeGen/CGDeclCXX.cpp
+++ lib/CodeGen/CGDeclCXX.cpp
@@ -347,6 +347,10 @@
       !isInSanitizerBlacklist(SanitizerKind::Memory, Fn, Loc))
     Fn->addFnAttr(llvm::Attribute::SanitizeMemory);
 
+  if (getLangOpts().Sanitize.has(SanitizerKind::KernelMemory) &&
+      !isInSanitizerBlacklist(SanitizerKind::KernelMemory, Fn, Loc))
+    Fn->addFnAttr(llvm::Attribute::SanitizeMemory);
+
   if (getLangOpts().Sanitize.has(SanitizerKind::SafeStack) &&
       !isInSanitizerBlacklist(SanitizerKind::SafeStack, Fn, Loc))
     Fn->addFnAttr(llvm::Attribute::SafeStack);
Index: lib/CodeGen/CodeGenFunction.cpp
===================================================================
--- lib/CodeGen/CodeGenFunction.cpp
+++ lib/CodeGen/CodeGenFunction.cpp
@@ -866,7 +866,7 @@
     Fn->addFnAttr(llvm::Attribute::SanitizeHWAddress);
   if (SanOpts.has(SanitizerKind::Thread))
     Fn->addFnAttr(llvm::Attribute::SanitizeThread);
-  if (SanOpts.has(SanitizerKind::Memory))
+  if (SanOpts.hasOneOf(SanitizerKind::Memory | SanitizerKind::KernelMemory))
     Fn->addFnAttr(llvm::Attribute::SanitizeMemory);
   if (SanOpts.has(SanitizerKind::SafeStack))
     Fn->addFnAttr(llvm::Attribute::SafeStack);
Index: lib/Driver/SanitizerArgs.cpp
===================================================================
--- lib/Driver/SanitizerArgs.cpp
+++ lib/Driver/SanitizerArgs.cpp
@@ -34,8 +34,9 @@
   RequiresPIE = DataFlow | HWAddress | Scudo,
   NeedsUnwindTables = Address | HWAddress | Thread | Memory | DataFlow,
   SupportsCoverage = Address | HWAddress | KernelAddress | KernelHWAddress |
-                     Memory | Leak | Undefined | Integer | ImplicitConversion |
-                     Nullability | DataFlow | Fuzzer | FuzzerNoLink,
+                     Memory | KernelMemory | Leak | Undefined | Integer |
+                     ImplicitConversion | Nullability | DataFlow | Fuzzer |
+                     FuzzerNoLink,
   RecoverableByDefault = Undefined | Integer | ImplicitConversion | Nullability,
   Unrecoverable = Unreachable | Return,
   AlwaysRecoverable = KernelAddress | KernelHWAddress,
@@ -380,8 +381,10 @@
                                           SafeStack),
       std::make_pair(KernelHWAddress, Address | HWAddress | Leak | Thread |
                                           Memory | KernelAddress | Efficiency |
-                                          SafeStack | ShadowCallStack)};
-
+                                          SafeStack | ShadowCallStack),
+      std::make_pair(KernelMemory, Address | HWAddress | Leak | Thread |
+                                       Memory | KernelAddress | Efficiency |
+                                       Scudo | SafeStack)};
   // Enable toolchain specific default sanitizers if not explicitly disabled.
   SanitizerMask Default = TC.getDefaultSanitizers() & ~AllRemove;
 
Index: lib/Driver/ToolChains/Linux.cpp
===================================================================
--- lib/Driver/ToolChains/Linux.cpp
+++ lib/Driver/ToolChains/Linux.cpp
@@ -934,6 +934,8 @@
     Res |= SanitizerKind::Leak;
   if (IsX86_64 || IsMIPS64 || IsAArch64 || IsPowerPC64)
     Res |= SanitizerKind::Thread;
+  if (IsX86_64)
+    Res |= SanitizerKind::KernelMemory;
   if (IsX86_64 || IsMIPS64)
     Res |= SanitizerKind::Efficiency;
   if (IsX86 || IsX86_64)
Index: lib/Frontend/CompilerInvocation.cpp
===================================================================
--- lib/Frontend/CompilerInvocation.cpp
+++ lib/Frontend/CompilerInvocation.cpp
@@ -3064,7 +3064,9 @@
   // names.
   Res.getCodeGenOpts().DiscardValueNames &=
       !LangOpts.Sanitize.has(SanitizerKind::Address) &&
-      !LangOpts.Sanitize.has(SanitizerKind::Memory);
+      !LangOpts.Sanitize.has(SanitizerKind::KernelAddress) &&
+      !LangOpts.Sanitize.has(SanitizerKind::Memory) &&
+      !LangOpts.Sanitize.has(SanitizerKind::KernelMemory);
 
   ParsePreprocessorArgs(Res.getPreprocessorOpts(), Args, Diags,
                         Res.getFrontendOpts().ProgramAction);