diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h --- a/clang/include/clang/Sema/HLSLExternalSemaSource.h +++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h @@ -15,6 +15,7 @@ #include "llvm/ADT/DenseMap.h" #include "clang/Sema/ExternalSemaSource.h" +#include "clang/Sema/MultiplexExternalSemaSource.h" namespace clang { class NamespaceDecl; @@ -22,8 +23,9 @@ class HLSLExternalSemaSource : public ExternalSemaSource { Sema *SemaPtr = nullptr; - NamespaceDecl *HLSLNamespace; + NamespaceDecl *HLSLNamespace = nullptr; CXXRecordDecl *ResourceDecl; + ExternalSemaSource *ExternalSema = nullptr; using CompletionFunction = std::function; llvm::DenseMap Completions; @@ -35,6 +37,7 @@ void completeBufferType(CXXRecordDecl *Record); public: + HLSLExternalSemaSource(ExternalSemaSource *ExtSema) : ExternalSema(ExtSema) {} ~HLSLExternalSemaSource() override; /// Initialize the semantic source with the Sema instance diff --git a/clang/lib/Frontend/FrontendAction.cpp b/clang/lib/Frontend/FrontendAction.cpp --- a/clang/lib/Frontend/FrontendAction.cpp +++ b/clang/lib/Frontend/FrontendAction.cpp @@ -1026,9 +1026,18 @@ // Setup HLSL External Sema Source if (CI.getLangOpts().HLSL && CI.hasASTContext()) { - IntrusiveRefCntPtr HLSLSema( - new HLSLExternalSemaSource()); - CI.getASTContext().setExternalSource(HLSLSema); + if (auto *SemaSource = dyn_cast_if_present( + CI.getASTContext().getExternalSource())) { + IntrusiveRefCntPtr HLSLSema( + new HLSLExternalSemaSource(SemaSource)); + IntrusiveRefCntPtr MultiSema( + new MultiplexExternalSemaSource(SemaSource, HLSLSema.get())); + CI.getASTContext().setExternalSource(MultiSema); + } else { + IntrusiveRefCntPtr HLSLSema( + new HLSLExternalSemaSource(nullptr)); + CI.getASTContext().setExternalSource(HLSLSema); + } } FailureCleanup.release(); diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp --- 
a/clang/lib/Sema/HLSLExternalSemaSource.cpp +++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -23,6 +23,17 @@ using namespace clang; using namespace hlsl; +static NamedDecl *findDecl(Sema &S, IdentifierInfo &II, + Sema::LookupNameKind Kind) { + DeclarationNameInfo NameInfo{DeclarationName{&II}, SourceLocation()}; + LookupResult R(S, NameInfo, Kind); + S.LookupName(R, S.getCurScope()); + NamedDecl *D = nullptr; + if (!R.isAmbiguous() && !R.empty()) + D = R.getRepresentativeDecl(); + return D; +} + namespace { struct TemplateParameterListBuilder; @@ -30,10 +41,16 @@ struct BuiltinTypeDeclBuilder { CXXRecordDecl *Record = nullptr; ClassTemplateDecl *Template = nullptr; + ClassTemplateDecl *PrevTemplate = nullptr; NamespaceDecl *HLSLNamespace = nullptr; llvm::StringMap Fields; + bool ReusePrevDecl = false; BuiltinTypeDeclBuilder(CXXRecordDecl *R) : Record(R) { + if (Record->isCompleteDefinition()) { + ReusePrevDecl = true; + return; + } Record->startDefinition(); Template = Record->getDescribedClassTemplate(); } @@ -42,6 +59,30 @@ : HLSLNamespace(Namespace) { ASTContext &AST = S.getASTContext(); IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier); + CXXRecordDecl *PrevRecord = nullptr; + if (NamedDecl *PrevDecl = + findDecl(S, II, Sema::LookupNameKind::LookupOrdinaryName)) + if (PrevDecl->getDeclContext() == HLSLNamespace->getPreviousDecl()) { + PrevRecord = llvm::dyn_cast(PrevDecl); + if (!PrevRecord) { + if ((PrevTemplate = llvm::dyn_cast(PrevDecl))) { + PrevTemplate = PrevTemplate->getCanonicalDecl(); + Template = PrevTemplate; + PrevRecord = Template->getTemplatedDecl(); + } + } + if (PrevRecord) { + PrevRecord = PrevRecord->getCanonicalDecl(); + // Mark ExternalLexicalStorage so complete type will be called for + // ExternalAST path. 
+ // PrevRecord->setHasExternalLexicalStorage(); + if (PrevRecord->isCompleteDefinition()) { + Record = PrevRecord; + ReusePrevDecl = true; + return; + } + } + } Record = CXXRecordDecl::Create(AST, TagDecl::TagKind::TTK_Class, HLSLNamespace, SourceLocation(), @@ -54,15 +95,20 @@ Record->addAttr(FinalAttr::CreateImplicit(AST, SourceRange(), AttributeCommonInfo::AS_Keyword, FinalAttr::Keyword_final)); + if (PrevRecord) + Record->setPreviousDecl(PrevRecord); } ~BuiltinTypeDeclBuilder() { - if (HLSLNamespace && !Template) + if (HLSLNamespace && !Template && !ReusePrevDecl) HLSLNamespace->addDecl(Record); } BuiltinTypeDeclBuilder & addTemplateArgumentList(llvm::ArrayRef TemplateArgs) { + if (ReusePrevDecl) + return *this; + ASTContext &AST = Record->getASTContext(); auto *ParamList = @@ -283,8 +329,8 @@ BuiltinTypeDeclBuilder &completeDefinition() { assert(Record->isBeingDefined() && "Definition must be started before completing it."); - - Record->completeDefinition(); + if (!Record->isCompleteDefinition()) + Record->completeDefinition(); return *this; } @@ -303,6 +349,9 @@ TemplateParameterListBuilder & addTypeParameter(StringRef Name, QualType DefaultValue = QualType()) { + if (Builder.ReusePrevDecl) + return *this; + unsigned Position = static_cast(Params.size()); auto *Decl = TemplateTypeParmDecl::Create( AST, Builder.Record->getDeclContext(), SourceLocation(), @@ -317,7 +366,7 @@ } BuiltinTypeDeclBuilder &finalizeTemplateArgs() { - if (Params.empty()) + if (Params.empty() || Builder.ReusePrevDecl) return Builder; auto *ParamList = TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(), @@ -334,7 +383,8 @@ QualType T = Builder.Template->getInjectedClassNameSpecialization(); T = AST.getInjectedClassNameType(Builder.Record, T); - + if (Builder.PrevTemplate) + Builder.Template->setPreviousDecl(Builder.PrevTemplate); return Builder; } }; @@ -349,12 +399,29 @@ void HLSLExternalSemaSource::InitializeSema(Sema &S) { SemaPtr = &S; ASTContext &AST = 
SemaPtr->getASTContext(); + IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier); - HLSLNamespace = - NamespaceDecl::Create(AST, AST.getTranslationUnitDecl(), false, - SourceLocation(), SourceLocation(), &HLSL, nullptr); + NamespaceDecl *PrevHLSLNamespace = nullptr; + if (ExternalSema) { + // If the translation unit has external storage force external decls to + // load. + if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage()) + (void)AST.getTranslationUnitDecl()->decls_begin(); + + PrevHLSLNamespace = llvm::dyn_cast_if_present( + findDecl(S, HLSL, Sema::LookupNameKind::LookupNamespaceName)); + // Try to initialize from ExternalSema. + if (PrevHLSLNamespace && + !isa(PrevHLSLNamespace->getParent())) + PrevHLSLNamespace = nullptr; + } + + HLSLNamespace = NamespaceDecl::Create( + AST, AST.getTranslationUnitDecl(), false, SourceLocation(), + SourceLocation(), &HLSL, PrevHLSLNamespace); HLSLNamespace->setImplicit(true); AST.getTranslationUnitDecl()->addDecl(HLSLNamespace); + defineTrivialHLSLTypes(); forwardDeclareHLSLTypes(); @@ -374,6 +441,12 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() { ASTContext &AST = SemaPtr->getASTContext(); + IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier); + auto *PrevVector = llvm::dyn_cast_if_present( + findDecl(*SemaPtr, II, Sema::LookupNameKind::LookupOrdinaryName)); + if (PrevVector && + PrevVector->getDeclContext() == HLSLNamespace->getPreviousDecl()) + return; llvm::SmallVector TemplateParams; @@ -398,8 +471,6 @@ TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(), TemplateParams, SourceLocation(), nullptr); - IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier); - QualType AliasType = AST.getDependentSizedExtVectorType( AST.getTemplateTypeParmType(0, 0, false, TypeParam), DeclRefExpr::Create( @@ -440,9 +511,11 @@ .addTypeParameter("element_type", SemaPtr->getASTContext().FloatTy) .finalizeTemplateArgs() .Record; - 
Completions.insert(std::make_pair( - Decl, std::bind(&HLSLExternalSemaSource::completeBufferType, this, - std::placeholders::_1))); + Decl = Decl->getCanonicalDecl(); + if (!Decl->isCompleteDefinition()) + Completions.insert(std::make_pair( + Decl, std::bind(&HLSLExternalSemaSource::completeBufferType, this, + std::placeholders::_1))); } void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) { diff --git a/clang/test/AST/HLSL/Inputs/pch.hlsl b/clang/test/AST/HLSL/Inputs/pch.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/Inputs/pch.hlsl @@ -0,0 +1,4 @@ + +float2 foo(float2 a, float2 b) { + return a + b; +} diff --git a/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl b/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl @@ -0,0 +1,6 @@ + +float2 foo(float2 a, float2 b) { + return a + b; +} + +RWBuffer Buf; diff --git a/clang/test/AST/HLSL/pch.hlsl b/clang/test/AST/HLSL/pch.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/pch.hlsl @@ -0,0 +1,17 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -emit-pch -o %t %S/Inputs/pch.hlsl +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -include-pch %t -fsyntax-only -ast-dump-all %s \ +// RUN: | FileCheck %s + +// Make sure PCH works by using function declared in PCH header and declare a RWBuffer in current file. 
+// CHECK:FunctionDecl 0x[[FOO:[0-9a-f]+]] <{{.*}}:2:1, line:4:1> line:2:8 imported used foo 'float2 (float2, float2)' +// CHECK:VarDecl 0x{{[0-9a-f]+}} <{{.*}}:10:1, col:23> col:23 Buffer 'hlsl::RWBuffer':'hlsl::RWBuffer<>' +hlsl::RWBuffer Buffer; + +float2 bar(float2 a, float2 b) { +// CHECK:CallExpr 0x{{[0-9a-f]+}} 'float2':'float __attribute__((ext_vector_type(2)))' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} 'float2 (*)(float2, float2)' +// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)' + return foo(a, b); +} diff --git a/clang/test/AST/HLSL/pch_with_buf.hlsl b/clang/test/AST/HLSL/pch_with_buf.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/pch_with_buf.hlsl @@ -0,0 +1,18 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -finclude-default-header -emit-pch -o %t %S/Inputs/pch_with_buf.hlsl +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -include-pch %t -fsyntax-only -ast-dump-all %s | FileCheck %s + +// Make sure PCH works by using function declared in PCH header. +// CHECK:FunctionDecl 0x[[FOO:[0-9a-f]+]] <{{.*}}:2:1, line:4:1> line:2:8 imported used foo 'float2 (float2, float2)' +// Make sure buffer defined in PCH works. +// CHECK:VarDecl 0x{{[0-9a-f]+}} col:17 imported Buf 'RWBuffer':'hlsl::RWBuffer<>' +// Make sure declare a RWBuffer in current file works. +// CHECK:VarDecl 0x{{[0-9a-f]+}} <{{.*}}:11:1, col:23> col:23 Buf2 'hlsl::RWBuffer':'hlsl::RWBuffer<>' +hlsl::RWBuffer Buf2; + +float2 bar(float2 a, float2 b) { +// CHECK:CallExpr 0x{{[0-9a-f]+}} 'float2':'float __attribute__((ext_vector_type(2)))' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} 'float2 (*)(float2, float2)' +// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)' + return foo(a, b); +}