diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h --- a/clang/include/clang/Sema/HLSLExternalSemaSource.h +++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h @@ -15,6 +15,7 @@ #include "llvm/ADT/DenseMap.h" #include "clang/Sema/ExternalSemaSource.h" +#include "clang/Sema/MultiplexExternalSemaSource.h" namespace clang { class NamespaceDecl; @@ -22,8 +23,9 @@ class HLSLExternalSemaSource : public ExternalSemaSource { Sema *SemaPtr = nullptr; - NamespaceDecl *HLSLNamespace; + NamespaceDecl *HLSLNamespace = nullptr; CXXRecordDecl *ResourceDecl; + ExternalSemaSource *ExternalSema = nullptr; using CompletionFunction = std::function; llvm::DenseMap Completions; @@ -48,6 +50,27 @@ using ExternalASTSource::CompleteType; /// Complete an incomplete HLSL builtin type void CompleteType(TagDecl *Tag) override; + void SetExternalSema(ExternalSemaSource *ExtSema) { ExternalSema = ExtSema; } +}; + +/// Members of ChainedHLSLExternalSemaSource, factored out into a base so they +/// are constructed before the MultiplexExternalSemaSource base that uses them.
+struct ChainedHLSLExternalSemaSourceMembers { + explicit ChainedHLSLExternalSemaSourceMembers(ExternalSemaSource *ExtSema) + : ExternalSema(ExtSema) { + HLSLSema.SetExternalSema(ExtSema); + } + HLSLExternalSemaSource HLSLSema; + IntrusiveRefCntPtr ExternalSema; +}; + +class ChainedHLSLExternalSemaSource + : private ChainedHLSLExternalSemaSourceMembers, + public MultiplexExternalSemaSource { +public: + explicit ChainedHLSLExternalSemaSource(ExternalSemaSource *ExtSema) + : ChainedHLSLExternalSemaSourceMembers(ExtSema), + MultiplexExternalSemaSource(*ExternalSema.get(), HLSLSema) {} }; } // namespace clang diff --git a/clang/lib/Frontend/FrontendAction.cpp b/clang/lib/Frontend/FrontendAction.cpp --- a/clang/lib/Frontend/FrontendAction.cpp +++ b/clang/lib/Frontend/FrontendAction.cpp @@ -1026,9 +1026,16 @@ // Setup HLSL External Sema Source if (CI.getLangOpts().HLSL && CI.hasASTContext()) { - IntrusiveRefCntPtr HLSLSema( - new HLSLExternalSemaSource()); - CI.getASTContext().setExternalSource(HLSLSema); + if (auto *SemaSource = dyn_cast_if_present( + CI.getASTContext().getExternalSource())) { + IntrusiveRefCntPtr HLSLSema( + new ChainedHLSLExternalSemaSource(SemaSource)); + CI.getASTContext().setExternalSource(HLSLSema); + } else { + IntrusiveRefCntPtr HLSLSema( + new HLSLExternalSemaSource()); + CI.getASTContext().setExternalSource(HLSLSema); + } } FailureCleanup.release(); diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp --- a/clang/lib/Sema/HLSLExternalSemaSource.cpp +++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -23,6 +23,17 @@ using namespace clang; using namespace hlsl; +static NamedDecl *findDecl(Sema &S, IdentifierInfo &II, + Sema::LookupNameKind Kind) { + DeclarationNameInfo NameInfo{DeclarationName{&II}, SourceLocation()}; + LookupResult R(S, NameInfo, Kind); + S.LookupName(R, S.getCurScope()); + NamedDecl *D = nullptr; + if (!R.isAmbiguous() && !R.empty()) + D = R.getRepresentativeDecl(); + return D; +} + namespace {
struct TemplateParameterListBuilder; @@ -32,8 +43,13 @@ ClassTemplateDecl *Template = nullptr; NamespaceDecl *HLSLNamespace = nullptr; llvm::StringMap Fields; + bool ReusePrevDecl = false; BuiltinTypeDeclBuilder(CXXRecordDecl *R) : Record(R) { + if (Record->isCompleteDefinition()) { + ReusePrevDecl = true; + return; + } Record->startDefinition(); Template = Record->getDescribedClassTemplate(); } @@ -42,6 +58,27 @@ : HLSLNamespace(Namespace) { ASTContext &AST = S.getASTContext(); IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier); + CXXRecordDecl *PrevRecord = nullptr; + if (NamedDecl *PrevDecl = + findDecl(S, II, Sema::LookupNameKind::LookupOrdinaryName)) + if (PrevDecl->getDeclContext() == HLSLNamespace) { + PrevRecord = llvm::dyn_cast(PrevDecl); + if (!PrevRecord) { + if (auto *PrevTemplate = + llvm::dyn_cast(PrevDecl)) { + Template = PrevTemplate; + PrevRecord = PrevTemplate->getTemplatedDecl(); + } + } + if (PrevRecord) { + Record = PrevRecord; + ReusePrevDecl = true; + // Mark HasExternalLexicalStorage so CompleteType will be called on the + // ExternalAST path.
+ Record->setHasExternalLexicalStorage(); + return; + } + } Record = CXXRecordDecl::Create(AST, TagDecl::TagKind::TTK_Class, HLSLNamespace, SourceLocation(), @@ -57,12 +94,15 @@ } ~BuiltinTypeDeclBuilder() { - if (HLSLNamespace && !Template) + if (HLSLNamespace && !Template && !ReusePrevDecl) HLSLNamespace->addDecl(Record); } BuiltinTypeDeclBuilder & addTemplateArgumentList(llvm::ArrayRef TemplateArgs) { + if (ReusePrevDecl) + return *this; + ASTContext &AST = Record->getASTContext(); auto *ParamList = @@ -187,8 +227,8 @@ BuiltinTypeDeclBuilder &completeDefinition() { - assert(Record->isBeingDefined() && + assert((Record->isCompleteDefinition() || Record->isBeingDefined()) && "Definition must be started before completing it."); - - Record->completeDefinition(); + if (!Record->isCompleteDefinition()) + Record->completeDefinition(); return *this; } @@ -207,6 +247,9 @@ TemplateParameterListBuilder & addTypeParameter(StringRef Name, QualType DefaultValue = QualType()) { + if (Builder.ReusePrevDecl) + return *this; + unsigned Position = static_cast(Params.size()); auto *Decl = TemplateTypeParmDecl::Create( AST, Builder.Record->getDeclContext(), SourceLocation(), @@ -221,7 +264,7 @@ } BuiltinTypeDeclBuilder &finalizeTemplateArgs() { - if (Params.empty()) + if (Params.empty() || Builder.ReusePrevDecl) return Builder; auto *ParamList = TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(), @@ -253,12 +296,28 @@ void HLSLExternalSemaSource::InitializeSema(Sema &S) { SemaPtr = &S; ASTContext &AST = SemaPtr->getASTContext(); + IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier); - HLSLNamespace = - NamespaceDecl::Create(AST, AST.getTranslationUnitDecl(), false, - SourceLocation(), SourceLocation(), &HLSL, nullptr); - HLSLNamespace->setImplicit(true); - AST.getTranslationUnitDecl()->addDecl(HLSLNamespace); + if (ExternalSema) { + // If the translation unit has external storage, force external decls to + // load.
+ if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage()) + (void)AST.getTranslationUnitDecl()->decls_begin(); + + NamespaceDecl *ExternalHLSL = llvm::dyn_cast_if_present( + findDecl(S, HLSL, Sema::LookupNameKind::LookupNamespaceName)); + // Try to initialize from ExternalSema. + if (ExternalHLSL && isa(ExternalHLSL->getParent())) + HLSLNamespace = ExternalHLSL; + } + + if (!HLSLNamespace) { + HLSLNamespace = NamespaceDecl::Create(AST, AST.getTranslationUnitDecl(), + false, SourceLocation(), + SourceLocation(), &HLSL, nullptr); + HLSLNamespace->setImplicit(true); + AST.getTranslationUnitDecl()->addDecl(HLSLNamespace); + } defineTrivialHLSLTypes(); forwardDeclareHLSLTypes(); @@ -278,6 +337,11 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() { ASTContext &AST = SemaPtr->getASTContext(); + IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier); + auto *PrevVector = llvm::dyn_cast_if_present( + findDecl(*SemaPtr, II, Sema::LookupNameKind::LookupOrdinaryName)); + if (PrevVector && PrevVector->getDeclContext() == HLSLNamespace) + return; llvm::SmallVector TemplateParams; @@ -302,8 +366,6 @@ TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(), TemplateParams, SourceLocation(), nullptr); - IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier); - QualType AliasType = AST.getDependentSizedExtVectorType( AST.getTemplateTypeParmType(0, 0, false, TypeParam), DeclRefExpr::Create( @@ -344,9 +406,10 @@ .addTypeParameter("element_type", SemaPtr->getASTContext().FloatTy) .finalizeTemplateArgs() .Record; - Completions.insert(std::make_pair( - Decl, std::bind(&HLSLExternalSemaSource::completeBufferType, this, - std::placeholders::_1))); + if (!Decl->isCompleteDefinition()) + Completions.insert(std::make_pair( + Decl, std::bind(&HLSLExternalSemaSource::completeBufferType, this, + std::placeholders::_1))); } void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) { diff --git a/clang/test/AST/HLSL/Inputs/pch.hlsl
b/clang/test/AST/HLSL/Inputs/pch.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/Inputs/pch.hlsl @@ -0,0 +1,4 @@ + +float2 foo(float2 a, float2 b) { + return a + b; +} diff --git a/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl b/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl @@ -0,0 +1,6 @@ + +float2 foo(float2 a, float2 b) { + return a + b; +} + +RWBuffer Buf; diff --git a/clang/test/AST/HLSL/pch.hlsl b/clang/test/AST/HLSL/pch.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/pch.hlsl @@ -0,0 +1,17 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -emit-pch -o %t %S/Inputs/pch.hlsl +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -include-pch %t -fsyntax-only -ast-dump-all %s \ +// RUN: | FileCheck %s + +// Make sure PCH works by using a function declared in the PCH header and by declaring an RWBuffer in the current file.
+// CHECK:FunctionDecl 0x[[FOO:[0-9a-f]+]] <{{.*}}:2:1, line:4:1> line:2:8 imported used foo 'float2 (float2, float2)' +// CHECK:VarDecl 0x{{[0-9a-f]+}} <{{.*}}:10:1, col:23> col:23 Buffer 'hlsl::RWBuffer':'hlsl::RWBuffer<>' +hlsl::RWBuffer Buffer; + +float2 bar(float2 a, float2 b) { +// CHECK:CallExpr 0x{{[0-9a-f]+}} 'float2':'float __attribute__((ext_vector_type(2)))' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} 'float2 (*)(float2, float2)' +// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)' + return foo(a, b); +} diff --git a/clang/test/AST/HLSL/pch_with_buf.hlsl b/clang/test/AST/HLSL/pch_with_buf.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/pch_with_buf.hlsl @@ -0,0 +1,18 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -finclude-default-header -emit-pch -o %t %S/Inputs/pch_with_buf.hlsl +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -include-pch %t -fsyntax-only -ast-dump-all %s | FileCheck %s + +// Make sure PCH works by using a function declared in the PCH header. +// CHECK:FunctionDecl 0x[[FOO:[0-9a-f]+]] <{{.*}}:2:1, line:4:1> line:2:8 imported used foo 'float2 (float2, float2)' +// Make sure a buffer defined in the PCH works. +// CHECK:VarDecl 0x{{[0-9a-f]+}} col:17 imported Buf 'RWBuffer':'hlsl::RWBuffer<>' +// Make sure declaring an RWBuffer in the current file works. +// CHECK:VarDecl 0x{{[0-9a-f]+}} <{{.*}}:11:1, col:23> col:23 Buf2 'hlsl::RWBuffer':'hlsl::RWBuffer<>' +hlsl::RWBuffer Buf2; + +float2 bar(float2 a, float2 b) { +// CHECK:CallExpr 0x{{[0-9a-f]+}} 'float2':'float __attribute__((ext_vector_type(2)))' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} 'float2 (*)(float2, float2)' +// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)' + return foo(a, b); +}