diff --git a/clang/include/clang/Driver/Types.def b/clang/include/clang/Driver/Types.def
--- a/clang/include/clang/Driver/Types.def
+++ b/clang/include/clang/Driver/Types.def
@@ -64,6 +64,9 @@
 TYPE("c++-header",               CXXHeader,    PP_CXXHeader,    "hh",     phases::Preprocess, phases::Precompile)
 TYPE("objective-c++-header-cpp-output", PP_ObjCXXHeader, INVALID,"mii",   phases::Precompile)
 TYPE("objective-c++-header",     ObjCXXHeader, PP_ObjCXXHeader, "h",      phases::Preprocess, phases::Precompile)
+TYPE("cuda-header-cpp-output",   PP_CUDAHeader, INVALID,        "cuhi",   phases::Precompile)
+TYPE("cuda-header",              CUDAHeader,   PP_CUDAHeader,   "cuh",    phases::Preprocess, phases::Precompile)
+TYPE("cuda-header",              CUDAHeader_DEVICE, PP_CUDAHeader, "cuh", phases::Preprocess, phases::Precompile)
 TYPE("c++-module",               CXXModule,    PP_CXXModule,    "cppm",   phases::Preprocess, phases::Precompile, phases::Compile, phases::Backend, phases::Assemble, phases::Link)
 TYPE("c++-module-cpp-output",    PP_CXXModule, INVALID,         "iim",    phases::Precompile, phases::Compile, phases::Backend, phases::Assemble, phases::Link)
 
diff --git a/clang/lib/Driver/Driver.cpp b/clang/lib/Driver/Driver.cpp
--- a/clang/lib/Driver/Driver.cpp
+++ b/clang/lib/Driver/Driver.cpp
@@ -2463,6 +2463,7 @@
         // If the host input is not CUDA or HIP, we don't need to bother about
         // this input.
         if (!(IA->getType() == types::TY_CUDA ||
+              IA->getType() == types::TY_CUDAHeader ||
               IA->getType() == types::TY_HIP ||
               IA->getType() == types::TY_PP_HIP)) {
           // The builder will ignore this input.
@@ -2477,8 +2478,13 @@
           return ABRT_Success;
 
         // Replicate inputs for each GPU architecture.
-        auto Ty = IA->getType() == types::TY_HIP ? types::TY_HIP_DEVICE
-                                                 : types::TY_CUDA_DEVICE;
+        // HIP sources and CUDA headers have dedicated device-side types; all
+        // other inputs accepted above are replicated as TY_CUDA_DEVICE.
+        types::ID Ty = types::TY_CUDA_DEVICE;
+        if (IA->getType() == types::TY_HIP)
+          Ty = types::TY_HIP_DEVICE;
+        else if (IA->getType() == types::TY_CUDAHeader)
+          Ty = types::TY_CUDAHeader_DEVICE;
         for (unsigned I = 0, E = GpuArchList.size(); I != E; ++I) {
           CudaDeviceActions.push_back(
               C.MakeAction<InputAction>(IA->getInputArg(), Ty));
diff --git a/clang/lib/Driver/Types.cpp b/clang/lib/Driver/Types.cpp
--- a/clang/lib/Driver/Types.cpp
+++ b/clang/lib/Driver/Types.cpp
@@ -103,7 +103,8 @@
       TY_Plist,         TY_RewrittenObjC, TY_RewrittenLegacyObjC,
       TY_Remap,         TY_PCH,           TY_Object,
       TY_Image,         TY_dSYM,          TY_Dependencies,
-      TY_CUDA_FATBIN,   TY_HIP_FATBIN};
+      TY_CUDA_FATBIN,   TY_HIP_FATBIN,    TY_CUDAHeader_DEVICE,
+      TY_PP_CUDAHeader};
   return !llvm::is_contained(kStaticLangageTypes, Id);
 }
 
@@ -129,6 +130,9 @@
   case TY_CL:
   case TY_CUDA: case TY_PP_CUDA:
   case TY_CUDA_DEVICE:
+  case TY_CUDAHeader:
+  case TY_PP_CUDAHeader:
+  case TY_CUDAHeader_DEVICE:
   case TY_HIP:
   case TY_PP_HIP:
   case TY_HIP_DEVICE:
@@ -171,6 +175,9 @@
   case TY_ObjCXXHeader: case TY_PP_ObjCXXHeader:
   case TY_CXXModule: case TY_PP_CXXModule:
   case TY_CUDA: case TY_PP_CUDA: case TY_CUDA_DEVICE:
+  case TY_CUDAHeader:
+  case TY_PP_CUDAHeader:
+  case TY_CUDAHeader_DEVICE:
   case TY_HIP:
   case TY_PP_HIP:
   case TY_HIP_DEVICE:
@@ -199,6 +206,9 @@
   case TY_CUDA:
   case TY_PP_CUDA:
   case TY_CUDA_DEVICE:
+  case TY_CUDAHeader:
+  case TY_PP_CUDAHeader:
+  case TY_CUDAHeader_DEVICE:
     return true;
   }
 }
@@ -249,6 +259,7 @@
            .Case("cl", TY_CL)
            .Case("cp", TY_CXX)
            .Case("cu", TY_CUDA)
+           .Case("cuh", TY_CUDAHeader)
            .Case("hh", TY_CXXHeader)
            .Case("ii", TY_PP_CXX)
            .Case("ll", TY_LLVM_IR)
@@ -265,6 +276,7 @@
            .Case("c++", TY_CXX)
            .Case("C++", TY_CXX)
            .Case("cui", TY_PP_CUDA)
+           .Case("cuhi", TY_PP_CUDAHeader)
            .Case("cxx", TY_CXX)
            .Case("CXX", TY_CXX)
            .Case("F90", TY_Fortran)
diff --git a/clang/lib/Sema/SemaCodeComplete.cpp b/clang/lib/Sema/SemaCodeComplete.cpp
--- a/clang/lib/Sema/SemaCodeComplete.cpp
+++ b/clang/lib/Sema/SemaCodeComplete.cpp
@@ -9317,6 +9317,7 @@
           if (!(Filename.endswith_lower(".h") ||
                 Filename.endswith_lower(".hh") ||
                 Filename.endswith_lower(".hpp") ||
+                Filename.endswith_lower(".cuh") ||
                 Filename.endswith_lower(".inc")))
             break;
         }
diff --git a/clang/lib/Tooling/InterpolatingCompilationDatabase.cpp b/clang/lib/Tooling/InterpolatingCompilationDatabase.cpp
--- a/clang/lib/Tooling/InterpolatingCompilationDatabase.cpp
+++ b/clang/lib/Tooling/InterpolatingCompilationDatabase.cpp
@@ -116,6 +116,8 @@
     return types::TY_ObjCXX;
   case types::TY_CUDA:
   case types::TY_CUDA_DEVICE:
+  case types::TY_CUDAHeader:
+  case types::TY_CUDAHeader_DEVICE:
     return types::TY_CUDA;
   default:
     return types::TY_INVALID;