Index: include/clang/Basic/LangOptions.def =================================================================== --- include/clang/Basic/LangOptions.def +++ include/clang/Basic/LangOptions.def @@ -174,6 +174,7 @@ LANGOPT(CUDADisableTargetCallChecks, 1, 0, "Disable checks for call targets (host, device, etc.)") LANGOPT(CUDATargetOverloads, 1, 0, "Enable function overloads based on CUDA target attributes") LANGOPT(CUDAAllowVariadicFunctions, 1, 0, "Allow variadic functions in CUDA device code") +LANGOPT(CUDAAllowStdComplex, 1, 0, "Allow calls to functions in <complex>, other than operator>> and operator<<, from device code") LANGOPT(AssumeSaneOperatorNew , 1, 1, "implicit __attribute__((malloc)) for C++'s new operators") LANGOPT(SizedDeallocation , 1, 0, "enable sized deallocation functions") Index: include/clang/Driver/CC1Options.td =================================================================== --- include/clang/Driver/CC1Options.td +++ include/clang/Driver/CC1Options.td @@ -697,6 +697,8 @@ HelpText<"Enable function overloads based on CUDA target attributes.">; def fcuda_allow_variadic_functions : Flag<["-"], "fcuda-allow-variadic-functions">, HelpText<"Allow variadic functions in CUDA device code.">; +def fcuda_allow_std_complex : Flag<["-"], "fcuda-allow-std-complex">, + HelpText<"Allow calls to functions in <complex>, other than operator>> and operator<<, from device code.">; //===----------------------------------------------------------------------===// // OpenMP Options Index: include/clang/Driver/Options.td =================================================================== --- include/clang/Driver/Options.td +++ include/clang/Driver/Options.td @@ -380,6 +380,8 @@ HelpText<"Do host-side CUDA compilation only">; def cuda_noopt_device_debug : Flag<["--"], "cuda-noopt-device-debug">, HelpText<"Enable device-side debug info generation. 
Disables ptxas optimizations.">; +def cuda_allow_std_complex : Flag<["--"], "cuda-allow-std-complex">, + HelpText<"Allow CUDA device code to use definitions from <complex>, other than operator>> and operator<<.">; def cuda_path_EQ : Joined<["--"], "cuda-path=">, Group<i_Group>, HelpText<"CUDA installation path">; def dA : Flag<["-"], "dA">, Group<d_Group>; Index: include/clang/Sema/Sema.h =================================================================== --- include/clang/Sema/Sema.h +++ include/clang/Sema/Sema.h @@ -8914,6 +8914,9 @@ /// (E.2.3.1 in CUDA 7.5 Programming guide). bool isEmptyCudaConstructor(SourceLocation Loc, CXXConstructorDecl *CD); + /// \return true if \p FD should be marked implicitly host+device. + bool declShouldBeCUDAHostDevice(const FunctionDecl &FD); + /// \name Code completion //@{ /// \brief Describes the context in which code completion occurs. Index: lib/Driver/Tools.cpp =================================================================== --- lib/Driver/Tools.cpp +++ lib/Driver/Tools.cpp @@ -3594,6 +3594,8 @@ CmdArgs.push_back(Args.MakeArgString(AuxToolChain->getTriple().str())); CmdArgs.push_back("-fcuda-target-overloads"); CmdArgs.push_back("-fcuda-disable-target-call-checks"); + if (Args.hasArg(options::OPT_cuda_allow_std_complex)) + CmdArgs.push_back("-fcuda-allow-std-complex"); } if (Triple.isOSWindows() && (Triple.getArch() == llvm::Triple::arm || Index: lib/Frontend/CompilerInvocation.cpp =================================================================== --- lib/Frontend/CompilerInvocation.cpp +++ lib/Frontend/CompilerInvocation.cpp @@ -1576,6 +1576,9 @@ if (Args.hasArg(OPT_fcuda_allow_variadic_functions)) Opts.CUDAAllowVariadicFunctions = 1; + if (Args.hasArg(OPT_fcuda_allow_std_complex)) + Opts.CUDAAllowStdComplex = 1; + if (Opts.ObjC1) { if (Arg *arg = Args.getLastArg(OPT_fobjc_runtime_EQ)) { StringRef value = arg->getValue(); Index: lib/Sema/SemaCUDA.cpp =================================================================== --- lib/Sema/SemaCUDA.cpp +++ 
lib/Sema/SemaCUDA.cpp @@ -14,11 +14,13 @@ #include "clang/Sema/Sema.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Decl.h" +#include "clang/AST/DeclTemplate.h" #include "clang/AST/ExprCXX.h" #include "clang/Lex/Preprocessor.h" #include "clang/Sema/SemaDiagnostic.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" using namespace clang; ExprResult Sema::ActOnCUDAExecConfigExpr(Scope *S, SourceLocation LLLLoc, @@ -450,3 +452,44 @@ return true; } + +// Everything within namespace std inside <complex> should be host+device, +// except operator<< and operator>> (ostreams aren't supported in CUDA device +// code). Whitelisting the functions we want, rather than blacklisting the +// stream operators, is a tempting alternative, but libstdc++ uses many helper +// functions, which we'd also have to whitelist. +// +// TODO: Output a better error message if you try to use something from +// <complex> without passing -fcuda-allow-std-complex. +// TODO: Output a nvcc-compat warning if you try to use a non-constexpr function +// from <complex> -- nvcc only lets you use constexpr functions. 
+bool Sema::declShouldBeCUDAHostDevice(const FunctionDecl &FD) { + assert(getLangOpts().CUDA); + + if (!getLangOpts().CUDAAllowStdComplex) + return false; + + const SourceManager &SM = getSourceManager(); + SourceLocation Loc = FD.getLocation(); + if (!SM.isInSystemHeader(Loc)) + return false; + const FileEntry *FE = SM.getFileEntryForID(SM.getFileID(Loc)); + if (!FE) + return false; + StringRef Filename = FE->getName(); + if (Filename != "complex" && !Filename.endswith("/complex")) + return false; + + bool IsInStd = FD.isInStdNamespace(); + if (const auto *Method = dyn_cast<CXXMethodDecl>(&FD)) + if (const auto *Parent = Method->getParent()) + IsInStd |= Parent->isInStdNamespace(); + if (!IsInStd) + return false; + + auto Operator = FD.getOverloadedOperator(); + if (Operator == OO_LessLess || Operator == OO_GreaterGreater) + return false; + + return true; +} Index: lib/Sema/SemaDecl.cpp =================================================================== --- lib/Sema/SemaDecl.cpp +++ lib/Sema/SemaDecl.cpp @@ -8340,6 +8340,12 @@ isExplicitSpecialization || isFunctionTemplateSpecialization); } + // CUDA: Some decls in system headers get an implicit __host__ __device__. + if (getLangOpts().CUDA && declShouldBeCUDAHostDevice(*NewFD)) { + NewFD->addAttr(CUDADeviceAttr::CreateImplicit(Context)); + NewFD->addAttr(CUDAHostAttr::CreateImplicit(Context)); + } + if (getLangOpts().CPlusPlus) { if (FunctionTemplate) { if (NewFD->isInvalidDecl()) Index: test/Driver/cuda-complex.cu =================================================================== --- /dev/null +++ test/Driver/cuda-complex.cu @@ -0,0 +1,15 @@ +// Tests CUDA compilation pipeline construction in Driver. +// REQUIRES: clang-driver + +// Check that --cuda-allow-std-complex passes -fcuda-allow-std-complex to cc1. 
+// RUN: %clang -### -target x86_64-linux-gnu --cuda-allow-std-complex -c %s 2>&1 \ +// RUN: | FileCheck -check-prefix ALLOW-COMPLEX %s + +// ALLOW-COMPLEX: -fcuda-allow-std-complex + +// But if we don't pass --cuda-allow-std-complex, we don't pass +// -fcuda-allow-std-complex to cc1. +// RUN: %clang -### -target x86_64-linux-gnu -c %s 2>&1 \ +// RUN: | FileCheck -check-prefix NO-ALLOW-COMPLEX %s + +// NO-ALLOW-COMPLEX-NOT: -fcuda-allow-std-complex Index: test/SemaCUDA/Inputs/complex =================================================================== --- /dev/null +++ test/SemaCUDA/Inputs/complex @@ -0,0 +1,30 @@ +// Incomplete stub of <complex> used to check that we properly annotate these +// functions as host+device. + +namespace std { + +template <typename T> +class complex { + public: + complex(const T &re = T(), const T &im = T()); + complex &operator+=(const complex &); + + private: + T real; + T imag; +}; + +template <typename T> +complex<T> operator+(const complex<T> &, const complex<T> &); + +template <typename T> +T real(const complex<T> &); + +// Stream operators are not marked as host+device. +template <typename T> +void operator<<(const complex<T> &, const complex<T> &); + +template <typename T> +void operator>>(const complex<T> &, const complex<T> &); + +} // namespace std Index: test/SemaCUDA/complex.cu =================================================================== --- /dev/null +++ test/SemaCUDA/complex.cu @@ -0,0 +1,27 @@ +// RUN: %clang_cc1 -triple nvptx-unknown-cuda -fsyntax-only -fcuda-allow-std-complex -fcuda-is-device -isystem "%S/Inputs" -verify %s +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fsyntax-only -fcuda-allow-std-complex -isystem "%S/Inputs" -verify %s + +// Checks that functions inside a system header named <complex> are marked as +// host+device. 
+ +#include <cuda.h> +#include <complex> + +using std::complex; +using std::real; + +void __device__ foo() { + complex<float> x; + complex<float> y(x); + y += x; + x + y; + real(complex<float>(1, 2)); + + // Our header defines complex-to-complex operator<< and operator>>, + // but these are not implicitly marked as host+device. + + x << y; // expected-error {{invalid operands to binary expression}} + // expected-note@complex:* {{call to __host__ function from __device__ function}} + x >> y; // expected-error {{invalid operands to binary expression}} + // expected-note@complex:* {{call to __host__ function from __device__ function}} +}