diff --git a/clang/lib/Driver/Driver.cpp b/clang/lib/Driver/Driver.cpp
--- a/clang/lib/Driver/Driver.cpp
+++ b/clang/lib/Driver/Driver.cpp
@@ -1240,11 +1240,24 @@
     T.setObjectFormat(llvm::Triple::COFF);
     TargetTriple = T.str();
   } else if (IsDXCMode()) {
-    // clang-dxc target is build from target_profile option.
-    // Just set OS to shader model to select HLSLToolChain.
+    // Set OS to shader model to select HLSLToolChain.
     llvm::Triple T(TargetTriple);
     T.setOS(llvm::Triple::ShaderModel);
     TargetTriple = T.str();
+
+    // Build TargetTriple from target_profile option for clang-dxc. If the
+    // profile is invalid, report it and keep the shader-model triple set
+    // above so that HLSLToolChain is still selected.
+    if (const Arg *A = Args.getLastArg(options::OPT_target_profile)) {
+      StringRef TargetProfile = A->getValue();
+      std::string Triple =
+          toolchains::HLSLToolChain::parseTargetProfile(TargetProfile);
+      if (Triple.empty())
+        Diag(diag::err_drv_invalid_directx_shader_module) << TargetProfile;
+      else
+        TargetTriple = std::move(Triple);
+      A->claim();
+    }
   }
 
   if (const Arg *A = Args.getLastArg(options::OPT_target))
diff --git a/clang/lib/Driver/ToolChains/HLSL.h b/clang/lib/Driver/ToolChains/HLSL.h
--- a/clang/lib/Driver/ToolChains/HLSL.h
+++ b/clang/lib/Driver/ToolChains/HLSL.h
@@ -29,8 +29,9 @@
   llvm::opt::DerivedArgList *
   TranslateArgs(const llvm::opt::DerivedArgList &Args, StringRef BoundArch,
                 Action::OffloadKind DeviceOffloadKind) const override;
-  std::string ComputeEffectiveClangTriple(const llvm::opt::ArgList &Args,
-                                          types::ID InputType) const override;
+  /// Parse a DirectX target profile such as "vs_6_0" or "lib_6_x" into a
+  /// target triple string. Returns an empty string if the profile is invalid.
+  static std::string parseTargetProfile(StringRef TargetProfile);
 };
 
 } // end namespace toolchains
diff --git a/clang/lib/Driver/ToolChains/HLSL.cpp b/clang/lib/Driver/ToolChains/HLSL.cpp
--- a/clang/lib/Driver/ToolChains/HLSL.cpp
+++ b/clang/lib/Driver/ToolChains/HLSL.cpp
@@ -138,21 +138,8 @@
                              const ArgList &Args)
     : ToolChain(D, Triple, Args) {}
 
-std::string
-HLSLToolChain::ComputeEffectiveClangTriple(const ArgList &Args,
-                                           types::ID InputType) const {
-  if (Arg *A = Args.getLastArg(options::OPT_target_profile)) {
-    StringRef Profile = A->getValue();
-    std::string Triple = tryParseProfile(Profile);
-    if (Triple == "") {
-      getDriver().Diag(diag::err_drv_invalid_directx_shader_module) << Profile;
-      Triple = ToolChain::ComputeEffectiveClangTriple(Args, InputType);
-    }
-    A->claim();
-    return Triple;
-  } else {
-    return ToolChain::ComputeEffectiveClangTriple(Args, InputType);
-  }
+std::string HLSLToolChain::parseTargetProfile(StringRef TargetProfile) {
+  return tryParseProfile(TargetProfile);
 }
 
 DerivedArgList *
diff --git a/clang/unittests/Driver/ToolChainTest.cpp b/clang/unittests/Driver/ToolChainTest.cpp
--- a/clang/unittests/Driver/ToolChainTest.cpp
+++ b/clang/unittests/Driver/ToolChainTest.cpp
@@ -387,6 +387,40 @@
   std::vector<SmallString<32>> Errors;
 };
 
+/// Build a dxc-mode compilation for \p TargetProfile (e.g. "-Tvs_6_0") and
+/// check that it selects \p ExpectTriple without reporting any error.
+static void validateTargetProfile(StringRef TargetProfile,
+                                  StringRef ExpectTriple, Driver &TheDriver,
+                                  DiagnosticsEngine &Diags) {
+  std::string ProfileArg = TargetProfile.str();
+  std::unique_ptr<Compilation> C(TheDriver.BuildCompilation(
+      {"clang", "--driver-mode=dxc", ProfileArg.c_str(), "foo.hlsl"}));
+  EXPECT_TRUE(C);
+  EXPECT_EQ(TheDriver.getTargetTriple(), ExpectTriple.str());
+  EXPECT_EQ(Diags.getNumErrors(), 0u);
+}
+
+/// Build a dxc-mode compilation for the invalid \p TargetProfile and check
+/// that the last reported error is \p ExpectError. \p NumOfErrors is the
+/// expected Diags.getNumErrors(), which is cumulative across calls.
+static void validateTargetProfile(StringRef TargetProfile,
+                                  StringRef ExpectError, Driver &TheDriver,
+                                  DiagnosticsEngine &Diags,
+                                  SimpleDiagnosticConsumer *DiagConsumer,
+                                  unsigned NumOfErrors) {
+  std::string ProfileArg = TargetProfile.str();
+  std::unique_ptr<Compilation> C(TheDriver.BuildCompilation(
+      {"clang", "--driver-mode=dxc", ProfileArg.c_str(), "foo.hlsl"}));
+  EXPECT_TRUE(C);
+  EXPECT_EQ(Diags.getNumErrors(), NumOfErrors);
+  EXPECT_FALSE(DiagConsumer->Errors.empty());
+  if (!DiagConsumer->Errors.empty())
+    EXPECT_STREQ(DiagConsumer->Errors.back().c_str(),
+                 ExpectError.str().c_str());
+  Diags.Clear();
+  DiagConsumer->clear();
+}
+
 TEST(DxcModeTest, TargetProfileValidation) {
   IntrusiveRefCntPtr<DiagnosticIDs> DiagID(new DiagnosticIDs());
 
@@ -400,111 +434,35 @@
   IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts = new DiagnosticOptions();
   DiagnosticsEngine Diags(DiagID, &*DiagOpts, DiagConsumer);
   Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem);
-  std::unique_ptr<Compilation> C(
-      TheDriver.BuildCompilation({"clang", "--driver-mode=dxc", "foo.hlsl"}));
-  EXPECT_TRUE(C);
-  EXPECT_TRUE(!C->containsError());
-
-  auto &TC = C->getDefaultToolChain();
-  bool ContainsError = false;
-  auto Args = TheDriver.ParseArgStrings({"-Tvs_6_0"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  auto Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "dxil--shadermodel6.0-vertex");
-  EXPECT_EQ(Diags.getNumErrors(), 0u);
 
-  Args = TheDriver.ParseArgStrings({"-Ths_6_1"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "dxil--shadermodel6.1-hull");
-  EXPECT_EQ(Diags.getNumErrors(), 0u);
-
-  Args = TheDriver.ParseArgStrings({"-Tds_6_2"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "dxil--shadermodel6.2-domain");
-  EXPECT_EQ(Diags.getNumErrors(), 0u);
-
-  Args = TheDriver.ParseArgStrings({"-Tds_6_2"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "dxil--shadermodel6.2-domain");
-  EXPECT_EQ(Diags.getNumErrors(), 0u);
-
-  Args = TheDriver.ParseArgStrings({"-Tgs_6_3"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "dxil--shadermodel6.3-geometry");
-  EXPECT_EQ(Diags.getNumErrors(), 0u);
-
-  Args = TheDriver.ParseArgStrings({"-Tps_6_4"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "dxil--shadermodel6.4-pixel");
-  EXPECT_EQ(Diags.getNumErrors(), 0u);
-
-  Args = TheDriver.ParseArgStrings({"-Tcs_6_5"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "dxil--shadermodel6.5-compute");
-  EXPECT_EQ(Diags.getNumErrors(), 0u);
-
-  Args = TheDriver.ParseArgStrings({"-Tms_6_6"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "dxil--shadermodel6.6-mesh");
-  EXPECT_EQ(Diags.getNumErrors(), 0u);
-
-  Args = TheDriver.ParseArgStrings({"-Tas_6_7"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "dxil--shadermodel6.7-amplification");
-  EXPECT_EQ(Diags.getNumErrors(), 0u);
-
-  Args = TheDriver.ParseArgStrings({"-Tlib_6_x"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "dxil--shadermodel6.15-library");
-  EXPECT_EQ(Diags.getNumErrors(), 0u);
+  validateTargetProfile("-Tvs_6_0", "dxil--shadermodel6.0-vertex", TheDriver,
+                        Diags);
+  validateTargetProfile("-Ths_6_1", "dxil--shadermodel6.1-hull", TheDriver,
+                        Diags);
+  validateTargetProfile("-Tds_6_2", "dxil--shadermodel6.2-domain", TheDriver,
+                        Diags);
+  validateTargetProfile("-Tgs_6_3", "dxil--shadermodel6.3-geometry", TheDriver,
+                        Diags);
+  validateTargetProfile("-Tps_6_4", "dxil--shadermodel6.4-pixel", TheDriver,
+                        Diags);
+  validateTargetProfile("-Tcs_6_5", "dxil--shadermodel6.5-compute", TheDriver,
+                        Diags);
+  validateTargetProfile("-Tms_6_6", "dxil--shadermodel6.6-mesh", TheDriver,
+                        Diags);
+  validateTargetProfile("-Tas_6_7", "dxil--shadermodel6.7-amplification",
+                        TheDriver, Diags);
+  validateTargetProfile("-Tlib_6_x", "dxil--shadermodel6.15-library", TheDriver,
+                        Diags);
 
   // Invalid tests.
-  Args = TheDriver.ParseArgStrings({"-Tpss_6_1"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel");
-  EXPECT_EQ(Diags.getNumErrors(), 1u);
-  EXPECT_STREQ(DiagConsumer->Errors.back().c_str(),
-               "invalid profile : pss_6_1");
-  Diags.Clear();
-  DiagConsumer->clear();
-
-  Args = TheDriver.ParseArgStrings({"-Tps_6_x"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel");
-  EXPECT_EQ(Diags.getNumErrors(), 2u);
-  EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), "invalid profile : ps_6_x");
-  Diags.Clear();
-  DiagConsumer->clear();
-
-  Args = TheDriver.ParseArgStrings({"-Tlib_6_1"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel");
-  EXPECT_EQ(Diags.getNumErrors(), 3u);
-  EXPECT_STREQ(DiagConsumer->Errors.back().c_str(),
-               "invalid profile : lib_6_1");
-  Diags.Clear();
-  DiagConsumer->clear();
-
-  Args = TheDriver.ParseArgStrings({"-Tfoo"}, false, ContainsError);
-  EXPECT_FALSE(ContainsError);
-  Triple = TC.ComputeEffectiveClangTriple(Args);
-  EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel");
-  EXPECT_EQ(Diags.getNumErrors(), 4u);
-  EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), "invalid profile : foo");
-  Diags.Clear();
-  DiagConsumer->clear();
+  validateTargetProfile("-Tpss_6_1", "invalid profile : pss_6_1", TheDriver,
+                        Diags, DiagConsumer, 1);
+  validateTargetProfile("-Tps_6_x", "invalid profile : ps_6_x", TheDriver,
+                        Diags, DiagConsumer, 2);
+  validateTargetProfile("-Tlib_6_1", "invalid profile : lib_6_1", TheDriver,
+                        Diags, DiagConsumer, 3);
+  validateTargetProfile("-Tfoo", "invalid profile : foo", TheDriver, Diags,
+                        DiagConsumer, 4);
 }
 
 TEST(DxcModeTest, ValidatorVersionValidation) {