diff --git a/clang/include/clang/Basic/DiagnosticDriverKinds.td b/clang/include/clang/Basic/DiagnosticDriverKinds.td --- a/clang/include/clang/Basic/DiagnosticDriverKinds.td +++ b/clang/include/clang/Basic/DiagnosticDriverKinds.td @@ -660,4 +660,13 @@ def err_drv_invalid_directx_shader_module : Error< "invalid profile : %0">; +def err_drv_invalid_range_dxil_validator_version : Error< + "invalid validator version : %0\n" + "Validator version must be less than or equal to current internal version.">; +def err_drv_invalid_format_dxil_validator_version : Error< + "invalid validator version : %0\n" + "Format of validator version is \".\" (ex:\"1.4\").">; +def err_drv_invalid_empty_dxil_validator_version : Error< + "invalid validator version : %0\n" + "If validator major version is 0, minor version must also be 0.">; } diff --git a/clang/include/clang/Basic/TargetOptions.h b/clang/include/clang/Basic/TargetOptions.h --- a/clang/include/clang/Basic/TargetOptions.h +++ b/clang/include/clang/Basic/TargetOptions.h @@ -110,8 +110,11 @@ /// The version of the darwin target variant SDK which was used during the /// compilation. llvm::VersionTuple DarwinTargetVariantSDKVersion; + + /// DXIL validator version in "major.minor" form (emitted as dx.valver).
+ std::string DxilValidatorVersion; }; -} // end namespace clang +} // end namespace clang #endif diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -6697,20 +6697,21 @@ def dxc_Group : OptionGroup<"">, Flags<[DXCOption]>, HelpText<"dxc compatibility options">; - class DXCJoinedOrSeparate : Option<["/", "-"], name, KIND_JOINED_OR_SEPARATE>, Group, Flags<[DXCOption, NoXarchOption]>; def dxc_help : Option<["/", "-", "--"], "help", KIND_JOINED>, Group, Flags<[DXCOption, NoXarchOption]>, Alias, HelpText<"Display available options">; - - def Fo : DXCJoinedOrSeparate<"Fo">, Alias, - HelpText<"Output object file.">; - + HelpText<"Output object file">; +def dxil_validator_version : Option<["/", "-"], "validator-version", KIND_SEPARATE>, + Group, Flags<[DXCOption, NoXarchOption, CC1Option, HelpHidden]>, + HelpText<"Override validator version for module. Format: ;" + "Default: DXIL.dll version or current internal version">, + MarshallingInfoString>; def target_profile : DXCJoinedOrSeparate<"T">, MetaVarName<"">, - HelpText<"Set target profile.">, + HelpText<"Set target profile">, Values<"ps_6_0, ps_6_1, ps_6_2, ps_6_3, ps_6_4, ps_6_5, ps_6_6, ps_6_7," "vs_6_0, vs_6_1, vs_6_2, vs_6_3, vs_6_4, vs_6_5, vs_6_6, vs_6_7," "gs_6_0, gs_6_1, gs_6_2, gs_6_3, gs_6_4, gs_6_5, gs_6_6, gs_6_7," diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h new file mode 100644 --- /dev/null +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -0,0 +1,38 @@ +//===----- CGHLSLRuntime.h - Interface to HLSL Runtimes -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides a class for HLSL-specific code generation. It is owned by +// CodeGenModule (created only for HLSL) and emits module-level metadata such +// as the DXIL validator version. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H +#define LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H + +namespace clang { + +namespace CodeGen { + +class CodeGenModule; +/// Per-module HLSL code generation state, owned by a CodeGenModule. +class CGHLSLRuntime { +protected: + CodeGenModule &CGM; + +public: + explicit CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {} + virtual ~CGHLSLRuntime() = default; + /// Called once module codegen is otherwise done; emits HLSL module metadata. + void finishCodeGen(); +}; + +} // namespace CodeGen +} // namespace clang + +#endif diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp new file mode 100644 --- /dev/null +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -0,0 +1,52 @@ +//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This implements CGHLSLRuntime, which emits HLSL-specific module-level +// metadata (currently the dx.valver DXIL validator version flag) at the end +// of module code generation.
+// +//===----------------------------------------------------------------------===// + +#include "CGHLSLRuntime.h" +#include "CodeGenModule.h" +#include "clang/Basic/TargetOptions.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" + +using namespace clang; +using namespace CodeGen; +using namespace llvm; + +namespace { +void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) { + // HLSLToolChain::TranslateArgs validates the driver's value, but -cc1 can be + // given any string (or none), so emit nothing unless it is major.minor. + VersionTuple Version; + if (Version.tryParse(ValVersionStr) || Version.getBuild() || + Version.getSubminor() || !Version.getMinor()) { + return; + } + + uint64_t Major = Version.getMajor(); + uint64_t Minor = Version.getMinor().getValue(); + + auto &Ctx = M.getContext(); + IRBuilder<> B(Ctx); + MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)), + ConstantAsMetadata::get(B.getInt32(Minor))}); + StringRef DxilValKey = "dx.valver"; + M.addModuleFlag(llvm::Module::ModFlagBehavior::AppendUnique, DxilValKey, Val); +} +} // namespace + +void CGHLSLRuntime::finishCodeGen() { + const auto &TargetOpts = CGM.getTarget().getTargetOpts(); + + llvm::Module &M = CGM.getModule(); + addDxilValVersion(TargetOpts.DxilValidatorVersion, M); +} diff --git a/clang/lib/CodeGen/CMakeLists.txt b/clang/lib/CodeGen/CMakeLists.txt --- a/clang/lib/CodeGen/CMakeLists.txt +++ b/clang/lib/CodeGen/CMakeLists.txt @@ -51,6 +51,7 @@ CGExprConstant.cpp CGExprScalar.cpp CGGPUBuiltin.cpp + CGHLSLRuntime.cpp CGLoopInfo.cpp CGNonTrivialStruct.cpp CGObjC.cpp diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h --- a/clang/lib/CodeGen/CodeGenModule.h +++ b/clang/lib/CodeGen/CodeGenModule.h @@ -85,6 +85,7 @@ class CGOpenCLRuntime; class CGOpenMPRuntime; class CGCUDARuntime; +class CGHLSLRuntime; class CoverageMappingModuleGen; class TargetCodeGenInfo; @@ -319,6 +320,7 @@ std::unique_ptr OpenCLRuntime; std::unique_ptr OpenMPRuntime;
std::unique_ptr CUDARuntime; + std::unique_ptr HLSLRuntime; std::unique_ptr DebugInfo; std::unique_ptr ObjCData; llvm::MDNode *NoObjCARCExceptionsMetadata = nullptr; @@ -512,6 +514,7 @@ void createOpenCLRuntime(); void createOpenMPRuntime(); void createCUDARuntime(); + void createHLSLRuntime(); bool isTriviallyRecursive(const FunctionDecl *F); bool shouldEmitFunction(GlobalDecl GD); @@ -610,6 +613,12 @@ return *CUDARuntime; } + /// Return the HLSL runtime; it is only created when LangOpts.HLSL is set. + CGHLSLRuntime &getHLSLRuntime() { + assert(HLSLRuntime != nullptr); + return *HLSLRuntime; + } + ObjCEntrypoints &getObjCEntrypoints() const { assert(ObjCData != nullptr); return *ObjCData; diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp --- a/clang/lib/CodeGen/CodeGenModule.cpp +++ b/clang/lib/CodeGen/CodeGenModule.cpp @@ -16,6 +16,7 @@ #include "CGCXXABI.h" #include "CGCall.h" #include "CGDebugInfo.h" +#include "CGHLSLRuntime.h" #include "CGObjCRuntime.h" #include "CGOpenCLRuntime.h" #include "CGOpenMPRuntime.h" @@ -146,6 +147,8 @@ createOpenMPRuntime(); if (LangOpts.CUDA) createCUDARuntime(); + if (LangOpts.HLSL) + createHLSLRuntime(); // Enable TBAA unless it's suppressed. ThreadSanitizer needs TBAA even at O0. if (LangOpts.Sanitize.has(SanitizerKind::Thread) || @@ -262,6 +265,10 @@ CUDARuntime.reset(CreateNVCUDARuntime(*this)); } +void CodeGenModule::createHLSLRuntime() { + HLSLRuntime.reset(new CGHLSLRuntime(*this)); +} + void CodeGenModule::addReplacement(StringRef Name, llvm::Constant *C) { Replacements[Name] = C; } @@ -832,6 +839,10 @@ } } + // Let the HLSL runtime emit its module-level metadata (e.g. dx.valver).
+ if (LangOpts.HLSL) + getHLSLRuntime().finishCodeGen(); + if (uint32_t PLevel = Context.getLangOpts().PICLevel) { assert(PLevel < 3 && "Invalid PIC Level"); getModule().setPICLevel(static_cast(PLevel)); diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp @@ -3459,6 +3459,15 @@ } } +static void RenderHLSLOptions(const ArgList &Args, ArgStringList &CmdArgs, + types::ID InputType) { + const unsigned ForwardedArguments[] = {options::OPT_dxil_validator_version}; + + for (const auto &Arg : ForwardedArguments) + if (const auto *A = Args.getLastArg(Arg)) + A->renderAsInput(Args, CmdArgs); +} + static void RenderARCMigrateToolOptions(const Driver &D, const ArgList &Args, ArgStringList &CmdArgs) { bool ARCMTEnabled = false; @@ -6204,6 +6213,10 @@ // Forward -cl options to -cc1 RenderOpenCLOptions(Args, CmdArgs, InputType); + // Forward hlsl options to -cc1 + if (C.getDriver().IsDXCMode()) + RenderHLSLOptions(Args, CmdArgs, InputType); + if (IsHIP) { if (Args.hasFlag(options::OPT_fhip_new_launch_api, options::OPT_fno_hip_new_launch_api, true)) diff --git a/clang/lib/Driver/ToolChains/HLSL.h b/clang/lib/Driver/ToolChains/HLSL.h --- a/clang/lib/Driver/ToolChains/HLSL.h +++ b/clang/lib/Driver/ToolChains/HLSL.h @@ -26,6 +26,9 @@ } bool isPICDefaultForced() const override { return false; } + llvm::opt::DerivedArgList * + TranslateArgs(const llvm::opt::DerivedArgList &Args, StringRef BoundArch, + Action::OffloadKind DeviceOffloadKind) const override; std::string ComputeEffectiveClangTriple(const llvm::opt::ArgList &Args, types::ID InputType) const override; }; diff --git a/clang/lib/Driver/ToolChains/HLSL.cpp b/clang/lib/Driver/ToolChains/HLSL.cpp --- a/clang/lib/Driver/ToolChains/HLSL.cpp +++ b/clang/lib/Driver/ToolChains/HLSL.cpp @@ -108,6 +108,33 @@ return ""; } +/// Returns true if ValVersionStr is a legal DXIL validator version of the form +/// major.minor; otherwise reports a driver diagnostic and returns false. +bool isLegalValidatorVersion(StringRef ValVersionStr, const Driver &D) { + VersionTuple
Version; + if (Version.tryParse(ValVersionStr) || Version.getBuild() || + Version.getSubminor() || !Version.getMinor()) { + D.Diag(diag::err_drv_invalid_format_dxil_validator_version) + << ValVersionStr; + return false; + } + + uint64_t Major = Version.getMajor(); + uint64_t Minor = Version.getMinor().getValue(); + if (Major == 0 && Minor != 0) { + D.Diag(diag::err_drv_invalid_empty_dxil_validator_version) << ValVersionStr; + return false; + } + VersionTuple MinVer(1, 0); + // NOTE(review): "0.0" passes the check above but is rejected here, and the + // range diagnostic describes an upper bound; confirm the intended range. + if (Version < MinVer) { + D.Diag(diag::err_drv_invalid_range_dxil_validator_version) << ValVersionStr; + return false; + } + return true; +} + } // namespace /// DirectX Toolchain @@ -131,3 +158,32 @@ return ToolChain::ComputeEffectiveClangTriple(Args, InputType); } } + +/// Copies Args, dropping any -validator-version whose value fails +/// isLegalValidatorVersion (which has already diagnosed it), and adds the +/// default validator version if none remains. The caller owns the result. +DerivedArgList * +HLSLToolChain::TranslateArgs(const DerivedArgList &Args, StringRef BoundArch, + Action::OffloadKind DeviceOffloadKind) const { + DerivedArgList *DAL = new DerivedArgList(Args.getBaseArgs()); + + const OptTable &Opts = getDriver().getOpts(); + + for (Arg *A : Args) { + if (A->getOption().getID() == options::OPT_dxil_validator_version) { + StringRef ValVerStr = A->getValue(); + if (!isLegalValidatorVersion(ValVerStr, getDriver())) + continue; + } + DAL->append(A); + } + // Add default validator version if not set. + // TODO: remove this once read validator version from validator.
+ if (!DAL->hasArg(options::OPT_dxil_validator_version)) { + const StringRef DefaultValidatorVer = "1.7"; + DAL->AddSeparateArg(nullptr, + Opts.getOption(options::OPT_dxil_validator_version), + DefaultValidatorVer); + } + return DAL; +} diff --git a/clang/test/CodeGenHLSL/validator_version.hlsl b/clang/test/CodeGenHLSL/validator_version.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/CodeGenHLSL/validator_version.hlsl @@ -0,0 +1,10 @@ +// RUN: %clang -cc1 -triple dxil-pc-shadermodel6.3-library -S -emit-llvm -xhlsl -validator-version 1.1 -o - %s | FileCheck %s + +// CHECK:!"dx.valver", ![[valver:[0-9]+]]} +// CHECK:![[valver]] = !{i32 1, i32 1} + +float bar(float a, float b); + +float foo(float a, float b) { + return bar(a, b); +} diff --git a/clang/unittests/Driver/ToolChainTest.cpp b/clang/unittests/Driver/ToolChainTest.cpp --- a/clang/unittests/Driver/ToolChainTest.cpp +++ b/clang/unittests/Driver/ToolChainTest.cpp @@ -367,27 +367,28 @@ EXPECT_EQ(getDriverMode(Args[0], llvm::makeArrayRef(Args).slice(1)), "bar"); } +struct SimpleDiagnosticConsumer : public DiagnosticConsumer { + void HandleDiagnostic(DiagnosticsEngine::Level DiagLevel, + const Diagnostic &Info) override { + if (DiagLevel == DiagnosticsEngine::Level::Error) { + Errors.emplace_back(); + Info.FormatDiagnostic(Errors.back()); + } else { + Msgs.emplace_back(); + Info.FormatDiagnostic(Msgs.back()); + } + } + void clear() override { + Msgs.clear(); + Errors.clear(); + DiagnosticConsumer::clear(); + } + std::vector> Msgs; + std::vector> Errors; +}; + TEST(DxcModeTest, TargetProfileValidation) { IntrusiveRefCntPtr DiagID(new DiagnosticIDs()); - struct SimpleDiagnosticConsumer : public DiagnosticConsumer { - void HandleDiagnostic(DiagnosticsEngine::Level DiagLevel, - const Diagnostic &Info) override { - if (DiagLevel == DiagnosticsEngine::Level::Error) { - Errors.emplace_back(); - Info.FormatDiagnostic(Errors.back()); - } else { - Msgs.emplace_back(); - Info.FormatDiagnostic(Msgs.back()); - }
- } - void clear() override { - Msgs.clear(); - Errors.clear(); - DiagnosticConsumer::clear(); - } - std::vector> Msgs; - std::vector> Errors; - }; IntrusiveRefCntPtr InMemoryFileSystem( new llvm::vfs::InMemoryFileSystem); @@ -472,7 +473,8 @@ Triple = TC.ComputeEffectiveClangTriple(Args); EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel"); EXPECT_EQ(Diags.getNumErrors(), 1u); - EXPECT_STREQ(DiagConsumer->Errors.back().data(), "invalid profile : pss_6_1"); + EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), + "invalid profile : pss_6_1"); Diags.Clear(); DiagConsumer->clear(); @@ -481,7 +483,7 @@ Triple = TC.ComputeEffectiveClangTriple(Args); EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel"); EXPECT_EQ(Diags.getNumErrors(), 2u); - EXPECT_STREQ(DiagConsumer->Errors.back().data(), "invalid profile : ps_6_x"); + EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), "invalid profile : ps_6_x"); Diags.Clear(); DiagConsumer->clear(); @@ -490,7 +492,8 @@ Triple = TC.ComputeEffectiveClangTriple(Args); EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel"); EXPECT_EQ(Diags.getNumErrors(), 3u); - EXPECT_STREQ(DiagConsumer->Errors.back().data(), "invalid profile : lib_6_1"); + EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), + "invalid profile : lib_6_1"); Diags.Clear(); DiagConsumer->clear(); @@ -499,7 +502,114 @@ Triple = TC.ComputeEffectiveClangTriple(Args); EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel"); EXPECT_EQ(Diags.getNumErrors(), 4u); - EXPECT_STREQ(DiagConsumer->Errors.back().data(), "invalid profile : foo"); + EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), "invalid profile : foo"); + Diags.Clear(); + DiagConsumer->clear(); +} + +TEST(DxcModeTest, ValidatorVersionValidation) { + IntrusiveRefCntPtr DiagID(new DiagnosticIDs()); + + IntrusiveRefCntPtr InMemoryFileSystem( + new llvm::vfs::InMemoryFileSystem); + + InMemoryFileSystem->addFile("foo.hlsl", 0, + llvm::MemoryBuffer::getMemBuffer("\n")); + + auto *DiagConsumer = new
SimpleDiagnosticConsumer; + IntrusiveRefCntPtr DiagOpts = new DiagnosticOptions(); + DiagnosticsEngine Diags(DiagID, &*DiagOpts, DiagConsumer); + Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem); + std::unique_ptr C( + TheDriver.BuildCompilation({"clang", "--driver-mode=dxc", "foo.hlsl"})); + EXPECT_TRUE(C); + EXPECT_TRUE(!C->containsError()); + + auto &TC = C->getDefaultToolChain(); + bool ContainsError = false; + auto Args = TheDriver.ParseArgStrings({"-validator-version", "1.1"}, false, + ContainsError); + EXPECT_FALSE(ContainsError); + auto DAL = std::make_unique(Args); + for (auto *A : Args) + DAL->append(A); + // TranslateArgs hands back a new DerivedArgList; own it so it is freed. + std::unique_ptr<llvm::opt::DerivedArgList> TranslatedArgs{ + TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)}; + EXPECT_NE(TranslatedArgs.get(), nullptr); + if (TranslatedArgs) { + auto *A = TranslatedArgs->getLastArg( + clang::driver::options::OPT_dxil_validator_version); + EXPECT_NE(A, nullptr); + if (A) + EXPECT_STREQ(A->getValue(), "1.1"); + } + EXPECT_EQ(Diags.getNumErrors(), 0u); + + // Invalid versions; the error count accumulates across Diags.Clear().
+ Args = TheDriver.ParseArgStrings({"-validator-version", "0.1"}, false, + ContainsError); + EXPECT_FALSE(ContainsError); + DAL = std::make_unique(Args); + for (auto *A : Args) + DAL->append(A); + + TranslatedArgs.reset( + TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); + EXPECT_EQ(Diags.getNumErrors(), 1u); + EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), + "invalid validator version : 0.1\nIf validator major version is " + "0, minor version must also be 0."); + Diags.Clear(); + DiagConsumer->clear(); + + Args = TheDriver.ParseArgStrings({"-validator-version", "1"}, false, + ContainsError); + EXPECT_FALSE(ContainsError); + DAL = std::make_unique(Args); + for (auto *A : Args) + DAL->append(A); + + TranslatedArgs.reset( + TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); + EXPECT_EQ(Diags.getNumErrors(), 2u); + EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), + "invalid validator version : 1\nFormat of validator version is " + "\".\" (ex:\"1.4\")."); + Diags.Clear(); + DiagConsumer->clear(); + + Args = TheDriver.ParseArgStrings({"-validator-version", "-Tlib_6_7"}, false, + ContainsError); + EXPECT_FALSE(ContainsError); + DAL = std::make_unique(Args); + for (auto *A : Args) + DAL->append(A); + + TranslatedArgs.reset( + TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); + EXPECT_EQ(Diags.getNumErrors(), 3u); + EXPECT_STREQ( + DiagConsumer->Errors.back().c_str(), + "invalid validator version : -Tlib_6_7\nFormat of validator version is " + "\".\" (ex:\"1.4\")."); + Diags.Clear(); + DiagConsumer->clear(); + + Args = TheDriver.ParseArgStrings({"-validator-version", "foo"}, false, + ContainsError); + EXPECT_FALSE(ContainsError); + DAL = std::make_unique(Args); + for (auto *A : Args) + DAL->append(A); + + TranslatedArgs.reset( + TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)); + EXPECT_EQ(Diags.getNumErrors(), 4u); + EXPECT_STREQ( + DiagConsumer->Errors.back().c_str(), + "invalid validator version : foo\nFormat of validator version is " + "\".\"
(ex:\"1.4\")."); Diags.Clear(); DiagConsumer->clear(); }