diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -6919,7 +6919,7 @@
 def hlsl_entrypoint : Option<["-"], "hlsl-entry", KIND_SEPARATE>,
                       Group<dxc_Group>,
                       Flags<[CC1Option]>,
-                      MarshallingInfoString<TargetOpts<"HLSLEntry">>,
+                      MarshallingInfoString<TargetOpts<"HLSLEntry">, "\"main\"">,
                       HelpText<"Entry point name for hlsl">;
 def dxc_entrypoint : Option<["--", "/", "-"], "E", KIND_JOINED_OR_SEPARATE>,
                      Group<dxc_Group>,
diff --git a/clang/unittests/Driver/ToolChainTest.cpp b/clang/unittests/Driver/ToolChainTest.cpp
--- a/clang/unittests/Driver/ToolChainTest.cpp
+++ b/clang/unittests/Driver/ToolChainTest.cpp
@@ -14,8 +14,11 @@
 #include "clang/Basic/DiagnosticIDs.h"
 #include "clang/Basic/DiagnosticOptions.h"
 #include "clang/Basic/LLVM.h"
+#include "clang/Basic/TargetOptions.h"
 #include "clang/Driver/Compilation.h"
 #include "clang/Driver/Driver.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/Utils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/Host.h"
@@ -571,6 +574,37 @@
   DiagConsumer->clear();
 }
 
+TEST(DxcModeTest, DefaultEntry) {
+  // Give the driver an in-memory filesystem that contains the input file.
+  IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem(
+      new llvm::vfs::InMemoryFileSystem);
+  InMemoryFileSystem->addFile("foo.hlsl", 0,
+                              llvm::MemoryBuffer::getMemBuffer("\n"));
+
+  IntrusiveRefCntPtr<DiagnosticsEngine> Diags =
+      CompilerInstance::createDiagnostics(new DiagnosticOptions());
+
+  // createInvocation() takes its options by value. Pass copies so that the
+  // same options serve both invocations below; moving them into the first
+  // call would hand a moved-from object to the second one.
+  CreateInvocationOptions CIOpts;
+  CIOpts.Diags = Diags;
+  CIOpts.VFS = InMemoryFileSystem;
+
+  // Without -E, the entry point must default to "main".
+  const char *Args[] = {"clang", "--driver-mode=dxc", "-Tcs_6_7", "foo.hlsl"};
+  std::unique_ptr<CompilerInvocation> CInvok = createInvocation(Args, CIOpts);
+  ASSERT_TRUE(CInvok);
+  EXPECT_STREQ(CInvok->getTargetOpts().HLSLEntry.c_str(), "main");
+
+  // An explicit -E must override the default entry point.
+  const char *EntryArgs[] = {"clang", "--driver-mode=dxc", "-Ebar",
+                             "-Tcs_6_7", "foo.hlsl"};
+  CInvok = createInvocation(EntryArgs, CIOpts);
+  ASSERT_TRUE(CInvok);
+  EXPECT_STREQ(CInvok->getTargetOpts().HLSLEntry.c_str(), "bar");
+}
+
 TEST(ToolChainTest, Toolsets) {
   // Ignore this test on Windows hosts.
   llvm::Triple Host(llvm::sys::getProcessTriple());