Index: clang/include/clang/Basic/CodeGenOptions.h
===================================================================
--- clang/include/clang/Basic/CodeGenOptions.h
+++ clang/include/clang/Basic/CodeGenOptions.h
@@ -190,6 +190,9 @@
   /// debug info.
   std::string DIBugsReportFilePath;
 
+  /// The validator version for dxil.
+  std::string DxilValidatorVersion;
+
   /// The floating-point denormal mode to use.
   llvm::DenormalMode FPDenormalMode = llvm::DenormalMode::getIEEE();
 
Index: clang/include/clang/Basic/DiagnosticDriverKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticDriverKinds.td
+++ clang/include/clang/Basic/DiagnosticDriverKinds.td
@@ -660,4 +660,13 @@
 
 def err_drv_invalid_directx_shader_module : Error<
   "invalid profile : %0">;
+def err_drv_invalid_range_dxil_validator_version : Error<
+  "invalid validator version : %0\n"
+  "Validator version must be less than or equal to current internal version.">;
+def err_drv_invalid_format_dxil_validator_version : Error<
+  "invalid validator version : %0\n"
+  "Format of validator version is \".\" (ex:\"1.4\").">;
+def err_drv_invalid_empty_dxil_validator_version : Error<
+  "invalid validator version : %0\n"
+  "If validator major version is 0, minor version must also be 0.">;
 }
Index: clang/include/clang/Driver/Options.td
===================================================================
--- clang/include/clang/Driver/Options.td
+++ clang/include/clang/Driver/Options.td
@@ -6698,6 +6698,7 @@
 
 def dxc_Group : OptionGroup<"">, Flags<[DXCOption]>,
   HelpText<"dxc compatibility options">;
+
 class DXCJoinedOrSeparate : Option<["/", "-"], name, KIND_JOINED_OR_SEPARATE>,
   Group, Flags<[DXCOption, NoXarchOption]>;
 
@@ -6709,6 +6710,11 @@
 def Fo : DXCJoinedOrSeparate<"Fo">, Alias,
   HelpText<"Output object file.">;
 
+def dxil_validator_version : Option<["/", "-"], "validator-version", KIND_SEPARATE>,
+  Group, Flags<[DXCOption, NoXarchOption, CC1Option, HelpHidden]>,
+  HelpText<"Override validator version for module. Format: ;Default: DXIL.dll version or current internal version.">,
+  MarshallingInfoString>;
+
 def target_profile : DXCJoinedOrSeparate<"T">, MetaVarName<"">,
   HelpText<"Set target profile.">,
   Values<"ps_6_0, ps_6_1, ps_6_2, ps_6_3, ps_6_4, ps_6_5, ps_6_6, ps_6_7,"
Index: clang/lib/CodeGen/CGHLSLRuntime.h
===================================================================
--- /dev/null
+++ clang/lib/CodeGen/CGHLSLRuntime.h
@@ -0,0 +1,39 @@
+//===----- CGHLSLRuntime.h - Interface to HLSL Runtimes -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides a class for HLSL code generation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H
+#define LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H
+
+namespace clang {
+
+namespace CodeGen {
+
+class CodeGenModule;
+
+/// Performs the HLSL-specific parts of code generation for one module.
+class CGHLSLRuntime {
+protected:
+  CodeGenModule &CGM;
+
+public:
+  CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}
+  virtual ~CGHLSLRuntime() {}
+
+  /// Run the HLSL work that has to happen at the end of code generation;
+  /// currently this records the DXIL validator version in the module.
+  void finishCodeGen();
+};
+
+} // namespace CodeGen
+} // namespace clang
+
+#endif // LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H
Index: clang/lib/CodeGen/CGHLSLRuntime.cpp
===================================================================
--- /dev/null
+++ clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -0,0 +1,56 @@
+//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides a class for HLSL code generation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CGHLSLRuntime.h"
+#include "CodeGenModule.h"
+#include "clang/Basic/CodeGenOptions.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/VersionTuple.h"
+
+using namespace clang;
+using namespace CodeGen;
+using namespace llvm;
+
+namespace {
+/// Record the DXIL validator version in \p M as the "dx.valver" module flag,
+/// whose value is a metadata node holding the major and minor version as i32.
+///
+/// \p ValVersionStr is expected to have been validated by
+/// HLSLToolChain::TranslateArgs. If it nevertheless does not parse as exactly
+/// "major.minor", no flag is added.
+void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
+  VersionTuple Version;
+  if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
+      Version.getSubminor() || !Version.getMinor()) {
+    return;
+  }
+
+  unsigned Major = Version.getMajor();
+  unsigned Minor = Version.getMinor().getValue();
+
+  auto &Ctx = M.getContext();
+  IRBuilder<> B(Ctx);
+  MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
+                                  ConstantAsMetadata::get(B.getInt32(Minor))});
+  StringRef DxilValKey = "dx.valver";
+  M.addModuleFlag(llvm::Module::ModFlagBehavior::AppendUnique, DxilValKey, Val);
+}
+} // namespace
+
+void CGHLSLRuntime::finishCodeGen() {
+  auto &CGOpts = CGM.getCodeGenOpts();
+  llvm::Module &M = CGM.getModule();
+  addDxilValVersion(CGOpts.DxilValidatorVersion, M);
+}
Index: clang/lib/CodeGen/CMakeLists.txt
===================================================================
--- clang/lib/CodeGen/CMakeLists.txt
+++ clang/lib/CodeGen/CMakeLists.txt
@@ -51,6 +51,7 @@
   CGExprConstant.cpp
   CGExprScalar.cpp
   CGGPUBuiltin.cpp
+  CGHLSLRuntime.cpp
   CGLoopInfo.cpp
   CGNonTrivialStruct.cpp
   CGObjC.cpp
Index: clang/lib/CodeGen/CodeGenModule.h
===================================================================
--- clang/lib/CodeGen/CodeGenModule.h
+++ clang/lib/CodeGen/CodeGenModule.h
@@ -85,6 +85,7 @@
 class CGOpenCLRuntime;
 class CGOpenMPRuntime;
 class CGCUDARuntime;
+class CGHLSLRuntime;
 class CoverageMappingModuleGen;
 class TargetCodeGenInfo;
 
@@ -319,6 +320,7 @@
   std::unique_ptr OpenCLRuntime;
   std::unique_ptr OpenMPRuntime;
   std::unique_ptr CUDARuntime;
+  std::unique_ptr HLSLRuntime;
   std::unique_ptr DebugInfo;
   std::unique_ptr ObjCData;
   llvm::MDNode *NoObjCARCExceptionsMetadata = nullptr;
@@ -512,6 +514,7 @@
   void createOpenCLRuntime();
   void createOpenMPRuntime();
   void createCUDARuntime();
+  void createHLSLRuntime();
 
   bool isTriviallyRecursive(const FunctionDecl *F);
   bool shouldEmitFunction(GlobalDecl GD);
@@ -610,6 +613,13 @@
     return *CUDARuntime;
   }
 
+  /// Return a reference to the configured HLSL runtime. Only valid when
+  /// compiling HLSL, since the runtime is not created otherwise.
+  CGHLSLRuntime &getHLSLRuntime() {
+    assert(HLSLRuntime != nullptr);
+    return *HLSLRuntime;
+  }
+
   ObjCEntrypoints &getObjCEntrypoints() const {
     assert(ObjCData != nullptr);
     return *ObjCData;
Index: clang/lib/CodeGen/CodeGenModule.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenModule.cpp
+++ clang/lib/CodeGen/CodeGenModule.cpp
@@ -16,6 +16,7 @@
 #include "CGCXXABI.h"
 #include "CGCall.h"
 #include "CGDebugInfo.h"
+#include "CGHLSLRuntime.h"
 #include "CGObjCRuntime.h"
 #include "CGOpenCLRuntime.h"
 #include "CGOpenMPRuntime.h"
@@ -146,6 +147,8 @@
     createOpenMPRuntime();
   if (LangOpts.CUDA)
     createCUDARuntime();
+  if (LangOpts.HLSL)
+    createHLSLRuntime();
 
   // Enable TBAA unless it's suppressed. ThreadSanitizer needs TBAA even at O0.
   if (LangOpts.Sanitize.has(SanitizerKind::Thread) ||
@@ -262,6 +265,10 @@
   CUDARuntime.reset(CreateNVCUDARuntime(*this));
 }
 
+void CodeGenModule::createHLSLRuntime() {
+  HLSLRuntime.reset(new CGHLSLRuntime(*this));
+}
+
 void CodeGenModule::addReplacement(StringRef Name, llvm::Constant *C) {
   Replacements[Name] = C;
 }
@@ -832,6 +839,11 @@
     }
   }
 
+  // HLSL related end of code gen work items.
+  if (LangOpts.HLSL) {
+    getHLSLRuntime().finishCodeGen();
+  }
+
   if (uint32_t PLevel = Context.getLangOpts().PICLevel) {
     assert(PLevel < 3 && "Invalid PIC Level");
     getModule().setPICLevel(static_cast(PLevel));
Index: clang/lib/Driver/ToolChains/Clang.cpp
===================================================================
--- clang/lib/Driver/ToolChains/Clang.cpp
+++ clang/lib/Driver/ToolChains/Clang.cpp
@@ -3459,6 +3459,17 @@
   }
 }
 
+/// Forward the HLSL driver options listed below to the cc1 command line.
+static void RenderHLSLOptions(const ArgList &Args, ArgStringList &CmdArgs,
+                              types::ID InputType) {
+  const unsigned ForwardedArguments[] = {options::OPT_dxil_validator_version};
+
+  for (const auto &Arg : ForwardedArguments)
+    if (const auto *A = Args.getLastArg(Arg)) {
+      A->renderAsInput(Args, CmdArgs);
+    }
+}
+
 static void RenderARCMigrateToolOptions(const Driver &D, const ArgList &Args,
                                         ArgStringList &CmdArgs) {
   bool ARCMTEnabled = false;
@@ -6204,6 +6215,11 @@
   // Forward -cl options to -cc1
   RenderOpenCLOptions(Args, CmdArgs, InputType);
 
+  // Forward hlsl options to -cc1
+  if (C.getDriver().IsDXCMode()) {
+    RenderHLSLOptions(Args, CmdArgs, InputType);
+  }
+
   if (IsHIP) {
     if (Args.hasFlag(options::OPT_fhip_new_launch_api,
                      options::OPT_fno_hip_new_launch_api, true))
Index: clang/lib/Driver/ToolChains/HLSL.h
===================================================================
--- clang/lib/Driver/ToolChains/HLSL.h
+++ clang/lib/Driver/ToolChains/HLSL.h
@@ -26,6 +26,9 @@
   }
   bool isPICDefaultForced() const override { return false; }
 
+  llvm::opt::DerivedArgList *
+  TranslateArgs(const llvm::opt::DerivedArgList &Args, StringRef BoundArch,
+                Action::OffloadKind DeviceOffloadKind) const override;
   std::string ComputeEffectiveClangTriple(const llvm::opt::ArgList &Args,
                                           types::ID InputType) const override;
 };
Index: clang/lib/Driver/ToolChains/HLSL.cpp
===================================================================
--- clang/lib/Driver/ToolChains/HLSL.cpp
+++ clang/lib/Driver/ToolChains/HLSL.cpp
@@ -108,6 +108,38 @@
   return "";
 }
 
+/// Check whether \p ValVersionStr is an acceptable DXIL validator version,
+/// reporting an error through \p D when it is not.
+///
+/// The string must parse as exactly "major.minor", a major version of 0
+/// requires a minor version of 0, and the version must be at least 1.0.
+/// NOTE(review): the last rule also rejects "0.0", and the diagnostic used for
+/// it describes an upper bound rather than the lower bound tested here;
+/// confirm the intended behaviour.
+/// \returns true if the version is acceptable.
+bool isLegalValidatorVersion(StringRef ValVersionStr, const Driver &D) {
+  VersionTuple Version;
+  if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
+      Version.getSubminor() || !Version.getMinor()) {
+    D.Diag(diag::err_drv_invalid_format_dxil_validator_version)
+        << ValVersionStr;
+    return false;
+  }
+
+  uint64_t Major = Version.getMajor();
+  uint64_t Minor = Version.getMinor().getValue();
+  if (Major == 0 && Minor != 0) {
+    D.Diag(diag::err_drv_invalid_empty_dxil_validator_version) << ValVersionStr;
+    return false;
+  }
+  VersionTuple MinVer(1, 0);
+  if (Version < MinVer) {
+    D.Diag(diag::err_drv_invalid_range_dxil_validator_version) << ValVersionStr;
+    return false;
+  }
+  return true;
+}
+
 } // namespace
 
 /// DirectX Toolchain
@@ -131,3 +163,38 @@
     return ToolChain::ComputeEffectiveClangTriple(Args, InputType);
   }
 }
+
+/// Validate -validator-version and make sure one is always present.
+///
+/// Copies \p Args into a new list, dropping any validator-version argument
+/// that fails validation (the error has been reported by then), and appends
+/// the default version when no valid one remains.
+/// \returns the new list, which is allocated with new.
+DerivedArgList *
+HLSLToolChain::TranslateArgs(const DerivedArgList &Args, StringRef BoundArch,
+                             Action::OffloadKind DeviceOffloadKind) const {
+  DerivedArgList *DAL = new DerivedArgList(Args.getBaseArgs());
+
+  const OptTable &Opts = getDriver().getOpts();
+
+  for (Arg *A : Args) {
+    if (A->getOption().getID() == options::OPT_dxil_validator_version) {
+      StringRef ValVerStr = A->getValue();
+      // An invalid version has already been diagnosed; skip it so that the
+      // default below is used instead.
+      if (!isLegalValidatorVersion(ValVerStr, getDriver())) {
+        continue;
+      }
+    }
+    DAL->append(A);
+  }
+  // Add default validator version if not set.
+  // TODO: Remove this once the validator version is read from the validator.
+  if (!DAL->hasArg(options::OPT_dxil_validator_version)) {
+    const StringRef DefaultValidatorVer = "1.7";
+    DAL->AddSeparateArg(nullptr,
+                        Opts.getOption(options::OPT_dxil_validator_version),
+                        DefaultValidatorVer);
+  }
+  return DAL;
+}
Index: clang/test/CodeGenHLSL/validator_version.hlsl
===================================================================
--- /dev/null
+++ clang/test/CodeGenHLSL/validator_version.hlsl
@@ -0,0 +1,10 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -S -emit-llvm -xhlsl -validator-version 1.1 -o - %s | FileCheck %s
+
+// CHECK:!"dx.valver", ![[valver:[0-9]+]]}
+// CHECK:![[valver]] = !{i32 1, i32 1}
+
+float bar(float a, float b);
+
+float foo(float a, float b) {
+  return bar(a, b);
+}
Index: clang/unittests/Driver/ToolChainTest.cpp
===================================================================
--- clang/unittests/Driver/ToolChainTest.cpp
+++ clang/unittests/Driver/ToolChainTest.cpp
@@ -504,4 +504,101 @@
   DiagConsumer->clear();
 }
 
+TEST(DxcModeTest, ValidatorVersionValidation) {
+  IntrusiveRefCntPtr DiagID(new DiagnosticIDs());
+  struct SimpleDiagnosticConsumer : public DiagnosticConsumer {
+    void HandleDiagnostic(DiagnosticsEngine::Level DiagLevel,
+                          const Diagnostic &Info) override {
+      if (DiagLevel == DiagnosticsEngine::Level::Error) {
+        Errors.emplace_back();
+        Info.FormatDiagnostic(Errors.back());
+        Errors.back().append({0});
+      } else {
+        Msgs.emplace_back();
+        Info.FormatDiagnostic(Msgs.back());
+        Msgs.back().append({0});
+      }
+    }
+    void clear() override {
+      Msgs.clear();
+      Errors.clear();
+      DiagnosticConsumer::clear();
+    }
+    std::vector> Msgs;
+    std::vector> Errors;
+  };
+
+  IntrusiveRefCntPtr InMemoryFileSystem(
+      new llvm::vfs::InMemoryFileSystem);
+
+  InMemoryFileSystem->addFile("foo.hlsl", 0,
+                              llvm::MemoryBuffer::getMemBuffer("\n"));
+
+  auto *DiagConsumer = new SimpleDiagnosticConsumer;
+  IntrusiveRefCntPtr DiagOpts = new DiagnosticOptions();
+  DiagnosticsEngine Diags(DiagID, &*DiagOpts, DiagConsumer);
+  Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem);
+  std::unique_ptr C(
+      TheDriver.BuildCompilation({"clang", "--driver-mode=dxc", "foo.hlsl"}));
+  ASSERT_TRUE(C);
+  EXPECT_TRUE(!C->containsError());
+
+  auto &TC = C->getDefaultToolChain();
+  bool ContainsError = false;
+  auto Args = TheDriver.ParseArgStrings({"-validator-version", "1.1"}, false,
+                                        ContainsError);
+  EXPECT_FALSE(ContainsError);
+  auto DAL = std::make_unique(Args);
+  for (auto *A : Args) {
+    DAL->append(A);
+  }
+  // HLSLToolChain::TranslateArgs returns a list allocated with new; own it so
+  // that it is released when it is replaced or when the test ends.
+  std::unique_ptr<llvm::opt::DerivedArgList> TranslatedArgs{
+      TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None)};
+  EXPECT_NE(TranslatedArgs, nullptr);
+  if (TranslatedArgs) {
+    auto *A = TranslatedArgs->getLastArg(
+        clang::driver::options::OPT_dxil_validator_version);
+    EXPECT_NE(A, nullptr);
+    if (A) {
+      EXPECT_STREQ(A->getValue(), "1.1");
+    }
+  }
+  EXPECT_EQ(Diags.getNumErrors(), 0u);
+
+  // Invalid tests.
+  Args = TheDriver.ParseArgStrings({"-validator-version", "0.1"}, false,
+                                   ContainsError);
+  EXPECT_FALSE(ContainsError);
+  DAL = std::make_unique(Args);
+  for (auto *A : Args) {
+    DAL->append(A);
+  }
+  TranslatedArgs.reset(
+      TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None));
+  EXPECT_EQ(Diags.getNumErrors(), 1u);
+  EXPECT_STREQ(DiagConsumer->Errors.back().data(),
+               "invalid validator version : 0.1\nIf validator major version is "
+               "0, minor version must also be 0.");
+  Diags.Clear();
+  DiagConsumer->clear();
+
+  Args = TheDriver.ParseArgStrings({"-validator-version", "1"}, false,
+                                   ContainsError);
+  EXPECT_FALSE(ContainsError);
+  DAL = std::make_unique(Args);
+  for (auto *A : Args) {
+    DAL->append(A);
+  }
+  TranslatedArgs.reset(
+      TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None));
+  EXPECT_EQ(Diags.getNumErrors(), 2u);
+  EXPECT_STREQ(DiagConsumer->Errors.back().data(),
+               "invalid validator version : 1\nFormat of validator version is "
+               "\".\" (ex:\"1.4\").");
+  Diags.Clear();
+  DiagConsumer->clear();
+}
+
 } // end anonymous namespace.