diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -158,6 +158,8 @@ - Fixed a crash in C++20 mode in Clang and Clangd when compile source with compilation errors. `Issue 53628 `_ +- The template arguments of a variable template being accessed as a + member will now be represented in the AST. Improvements to Clang's diagnostics diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h --- a/clang/include/clang/Sema/HLSLExternalSemaSource.h +++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h @@ -22,7 +22,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource { Sema *SemaPtr = nullptr; - NamespaceDecl *HLSLNamespace; + NamespaceDecl *HLSLNamespace = nullptr; CXXRecordDecl *ResourceDecl; using CompletionFunction = std::function; diff --git a/clang/lib/Driver/ToolChains/PS4CPU.cpp b/clang/lib/Driver/ToolChains/PS4CPU.cpp --- a/clang/lib/Driver/ToolChains/PS4CPU.cpp +++ b/clang/lib/Driver/ToolChains/PS4CPU.cpp @@ -159,17 +159,32 @@ const bool IsPS5 = TC.getTriple().isPS5(); assert(IsPS4 || IsPS5); + ArgStringList DbgOpts; + // This tells LTO to perform JustMyCode instrumentation. - if (UseLTO && UseJMC) { - if (IsPS4 && D.getLTOMode() == LTOK_Thin) { - CmdArgs.push_back("-lto-thin-debug-options=-enable-jmc-instrument"); - } else if (IsPS4 && D.getLTOMode() == LTOK_Full) { - CmdArgs.push_back("-lto-debug-options=-enable-jmc-instrument"); - } else if (IsPS5) { - CmdArgs.push_back("-mllvm"); - CmdArgs.push_back("-enable-jmc-instrument"); - } else - llvm_unreachable("new LTO mode?"); + if (UseLTO && UseJMC) + DbgOpts.push_back("-enable-jmc-instrument"); + + // We default to creating the arange section, but LTO does not. Enable it + // here. + if (UseLTO) + DbgOpts.push_back("-generate-arange-section"); + + if (UseLTO) { + if (IsPS4) { + StringRef F = (D.getLTOMode() == LTOK_Thin) ? 
+ "-lto-thin-debug-options=" : "-lto-debug-options="; + F = makeArgString(Args, F.data(), DbgOpts.front(), ""); + DbgOpts.erase(DbgOpts.begin()); + for (auto X : DbgOpts) + F = makeArgString(Args, F.data(), " ", X); + CmdArgs.push_back(F.data()); + } else { + for (auto D : DbgOpts) { + CmdArgs.push_back("-mllvm"); + CmdArgs.push_back(D); + } + } } if (!Args.hasArg(options::OPT_nostdlib, options::OPT_nodefaultlibs)) diff --git a/clang/lib/Frontend/FrontendAction.cpp b/clang/lib/Frontend/FrontendAction.cpp --- a/clang/lib/Frontend/FrontendAction.cpp +++ b/clang/lib/Frontend/FrontendAction.cpp @@ -28,6 +28,7 @@ #include "clang/Lex/PreprocessorOptions.h" #include "clang/Parse/ParseAST.h" #include "clang/Sema/HLSLExternalSemaSource.h" +#include "clang/Sema/MultiplexExternalSemaSource.h" #include "clang/Serialization/ASTDeserializationListener.h" #include "clang/Serialization/ASTReader.h" #include "clang/Serialization/GlobalModuleIndex.h" @@ -1026,9 +1027,15 @@ // Setup HLSL External Sema Source if (CI.getLangOpts().HLSL && CI.hasASTContext()) { - IntrusiveRefCntPtr HLSLSema( + IntrusiveRefCntPtr HLSLSema( new HLSLExternalSemaSource()); - CI.getASTContext().setExternalSource(HLSLSema); + if (auto *SemaSource = dyn_cast_if_present( + CI.getASTContext().getExternalSource())) { + IntrusiveRefCntPtr MultiSema( + new MultiplexExternalSemaSource(SemaSource, HLSLSema.get())); + CI.getASTContext().setExternalSource(MultiSema); + } else + CI.getASTContext().setExternalSource(HLSLSema); } FailureCleanup.release(); diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -15,11 +15,6 @@ // abs builtins -__attribute__((clang_builtin_alias(__builtin_abs))) int abs(int In); -__attribute__((clang_builtin_alias(__builtin_labs))) int64_t abs(int64_t In); -__attribute__((clang_builtin_alias(__builtin_fabsf))) float abs(float In); 
-__attribute__((clang_builtin_alias(__builtin_fabs))) double abs(double In); - #ifdef __HLSL_ENABLE_16_BIT __attribute__((clang_builtin_alias(__builtin_elementwise_abs))) int16_t abs(int16_t); @@ -76,5 +71,33 @@ __attribute__((clang_builtin_alias(__builtin_sqrtf16))) half sqrt(half In); #endif +// ceil builtins +#ifdef __HLSL_ENABLE_16_BIT +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) half ceil(half); +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) +half2 ceil(half2); +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) +half3 ceil(half3); +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) +half4 ceil(half4); +#endif + +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) float +ceil(float); +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) +float2 ceil(float2); +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) +float3 ceil(float3); +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) +float4 ceil(float4); + +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) double +ceil(double); +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) +double2 ceil(double2); +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) +double3 ceil(double3); +__attribute__((clang_builtin_alias(__builtin_elementwise_ceil))) +double4 ceil(double4); #endif //_HLSL_HLSL_INTRINSICS_H_ diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp --- a/clang/lib/Sema/HLSLExternalSemaSource.cpp +++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -30,6 +30,7 @@ struct BuiltinTypeDeclBuilder { CXXRecordDecl *Record = nullptr; ClassTemplateDecl *Template = nullptr; + ClassTemplateDecl *PrevTemplate = nullptr; NamespaceDecl *HLSLNamespace = nullptr; llvm::StringMap Fields; @@ -43,48 +44,46 @@ ASTContext &AST = S.getASTContext(); IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier); + 
LookupResult Result(S, &II, SourceLocation(), Sema::LookupTagName); + CXXRecordDecl *PrevDecl = nullptr; + if (S.LookupQualifiedName(Result, HLSLNamespace)) { + NamedDecl *Found = Result.getFoundDecl(); + if (auto *TD = dyn_cast(Found)) { + PrevDecl = TD->getTemplatedDecl(); + PrevTemplate = TD; + } else + PrevDecl = dyn_cast(Found); + assert(PrevDecl && "Unexpected lookup result type."); + } + + if (PrevDecl && PrevDecl->isCompleteDefinition()) { + Record = PrevDecl; + return; + } + Record = CXXRecordDecl::Create(AST, TagDecl::TagKind::TTK_Class, HLSLNamespace, SourceLocation(), - SourceLocation(), &II, nullptr, true); + SourceLocation(), &II, PrevDecl, true); Record->setImplicit(true); Record->setLexicalDeclContext(HLSLNamespace); Record->setHasExternalLexicalStorage(); - // Don't let anyone derive from built-in types + // Don't let anyone derive from built-in types. Record->addAttr(FinalAttr::CreateImplicit(AST, SourceRange(), AttributeCommonInfo::AS_Keyword, FinalAttr::Keyword_final)); } ~BuiltinTypeDeclBuilder() { - if (HLSLNamespace && !Template) + if (HLSLNamespace && !Template && Record->getDeclContext() == HLSLNamespace) HLSLNamespace->addDecl(Record); } - BuiltinTypeDeclBuilder & - addTemplateArgumentList(llvm::ArrayRef TemplateArgs) { - ASTContext &AST = Record->getASTContext(); - - auto *ParamList = - TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(), - TemplateArgs, SourceLocation(), nullptr); - Template = ClassTemplateDecl::Create( - AST, Record->getDeclContext(), SourceLocation(), - DeclarationName(Record->getIdentifier()), ParamList, Record); - Record->setDescribedClassTemplate(Template); - Template->setImplicit(true); - Template->setLexicalDeclContext(Record->getDeclContext()); - Record->getDeclContext()->addDecl(Template); - - // Requesting the class name specialization will fault in required types. 
- QualType T = Template->getInjectedClassNameSpecialization(); - T = AST.getInjectedClassNameType(Record, T); - return *this; - } - BuiltinTypeDeclBuilder & addMemberVariable(StringRef Name, QualType Type, AccessSpecifier Access = AccessSpecifier::AS_private) { + if (Record->isCompleteDefinition()) + return *this; assert(Record->isBeingDefined() && "Definition must be started before adding members!"); ASTContext &AST = Record->getASTContext(); @@ -104,6 +103,8 @@ BuiltinTypeDeclBuilder & addHandleMember(AccessSpecifier Access = AccessSpecifier::AS_private) { + if (Record->isCompleteDefinition()) + return *this; QualType Ty = Record->getASTContext().VoidPtrTy; if (Template) { if (const auto *TTD = dyn_cast( @@ -116,6 +117,8 @@ BuiltinTypeDeclBuilder & annotateResourceClass(HLSLResourceAttr::ResourceClass RC) { + if (Record->isCompleteDefinition()) + return *this; Record->addAttr( HLSLResourceAttr::CreateImplicit(Record->getASTContext(), RC)); return *this; @@ -147,6 +150,8 @@ BuiltinTypeDeclBuilder &addDefaultHandleConstructor(Sema &S, ResourceClass RC) { + if (Record->isCompleteDefinition()) + return *this; ASTContext &AST = Record->getASTContext(); QualType ConstructorType = @@ -197,12 +202,16 @@ } BuiltinTypeDeclBuilder &addArraySubscriptOperators() { + if (Record->isCompleteDefinition()) + return *this; addArraySubscriptOperator(true); addArraySubscriptOperator(false); return *this; } BuiltinTypeDeclBuilder &addArraySubscriptOperator(bool IsConst) { + if (Record->isCompleteDefinition()) + return *this; assert(Fields.count("h") > 0 && "Subscript operator must be added after the handle."); @@ -279,11 +288,15 @@ } BuiltinTypeDeclBuilder &startDefinition() { + if (Record->isCompleteDefinition()) + return *this; Record->startDefinition(); return *this; } BuiltinTypeDeclBuilder &completeDefinition() { + if (Record->isCompleteDefinition()) + return *this; assert(Record->isBeingDefined() && "Definition must be started before completing it."); @@ -306,6 +319,8 @@ 
TemplateParameterListBuilder & addTypeParameter(StringRef Name, QualType DefaultValue = QualType()) { + if (Builder.Record->isCompleteDefinition()) + return *this; unsigned Position = static_cast(Params.size()); auto *Decl = TemplateTypeParmDecl::Create( AST, Builder.Record->getDeclContext(), SourceLocation(), @@ -332,6 +347,9 @@ Builder.Record->setDescribedClassTemplate(Builder.Template); Builder.Template->setImplicit(true); Builder.Template->setLexicalDeclContext(Builder.Record->getDeclContext()); + // NOTE: setPreviousDecl before addDecl so new decl replace old decl when + // make visible. + Builder.Template->setPreviousDecl(Builder.PrevTemplate); Builder.Record->getDeclContext()->addDecl(Builder.Template); Params.clear(); @@ -352,12 +370,24 @@ void HLSLExternalSemaSource::InitializeSema(Sema &S) { SemaPtr = &S; ASTContext &AST = SemaPtr->getASTContext(); + // If the translation unit has external storage force external decls to load. + if (AST.getTranslationUnitDecl()->hasExternalLexicalStorage()) + (void)AST.getTranslationUnitDecl()->decls_begin(); + IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier); - HLSLNamespace = - NamespaceDecl::Create(AST, AST.getTranslationUnitDecl(), false, - SourceLocation(), SourceLocation(), &HLSL, nullptr); + LookupResult Result(S, &HLSL, SourceLocation(), Sema::LookupNamespaceName); + NamespaceDecl *PrevDecl = nullptr; + if (S.LookupQualifiedName(Result, AST.getTranslationUnitDecl())) + PrevDecl = Result.getAsSingle(); + HLSLNamespace = NamespaceDecl::Create(AST, AST.getTranslationUnitDecl(), + false, SourceLocation(), + SourceLocation(), &HLSL, PrevDecl); HLSLNamespace->setImplicit(true); + HLSLNamespace->setHasExternalLexicalStorage(); AST.getTranslationUnitDecl()->addDecl(HLSLNamespace); + + // Force external decls in the HLSL namespace to load from the PCH. 
+ (void)HLSLNamespace->getCanonicalDecl()->decls_begin(); defineTrivialHLSLTypes(); forwardDeclareHLSLTypes(); @@ -443,9 +473,11 @@ .addTypeParameter("element_type", SemaPtr->getASTContext().FloatTy) .finalizeTemplateArgs() .Record; - Completions.insert(std::make_pair( - Decl, std::bind(&HLSLExternalSemaSource::completeBufferType, this, - std::placeholders::_1))); + if (!Decl->isCompleteDefinition()) + Completions.insert( + std::make_pair(Decl->getCanonicalDecl(), + std::bind(&HLSLExternalSemaSource::completeBufferType, + this, std::placeholders::_1))); } void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) { @@ -457,6 +489,7 @@ // declaration and complete that. if (auto TDecl = dyn_cast(Record)) Record = TDecl->getSpecializedTemplate()->getTemplatedDecl(); + Record = Record->getCanonicalDecl(); auto It = Completions.find(Record); if (It == Completions.end()) return; diff --git a/clang/lib/Sema/SemaExprMember.cpp b/clang/lib/Sema/SemaExprMember.cpp --- a/clang/lib/Sema/SemaExprMember.cpp +++ b/clang/lib/Sema/SemaExprMember.cpp @@ -1161,10 +1161,10 @@ if (!Var->getTemplateSpecializationKind()) Var->setTemplateSpecializationKind(TSK_ImplicitInstantiation, MemberLoc); - return BuildMemberExpr( - BaseExpr, IsArrow, OpLoc, &SS, TemplateKWLoc, Var, FoundDecl, - /*HadMultipleCandidates=*/false, MemberNameInfo, - Var->getType().getNonReferenceType(), VK_LValue, OK_Ordinary); + return BuildMemberExpr(BaseExpr, IsArrow, OpLoc, &SS, TemplateKWLoc, Var, + FoundDecl, /*HadMultipleCandidates=*/false, + MemberNameInfo, Var->getType().getNonReferenceType(), + VK_LValue, OK_Ordinary, TemplateArgs); } // We found something that we didn't expect. Complain. 
diff --git a/clang/test/AST/HLSL/Inputs/pch.hlsl b/clang/test/AST/HLSL/Inputs/pch.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/Inputs/pch.hlsl @@ -0,0 +1,4 @@ + +float2 foo(float2 a, float2 b) { + return a + b; +} diff --git a/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl b/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/Inputs/pch_with_buf.hlsl @@ -0,0 +1,6 @@ + +float2 foo(float2 a, float2 b) { + return a + b; +} + +RWBuffer Buf; diff --git a/clang/test/AST/HLSL/pch.hlsl b/clang/test/AST/HLSL/pch.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/pch.hlsl @@ -0,0 +1,17 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -emit-pch -o %t %S/Inputs/pch.hlsl +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -include-pch %t -fsyntax-only -ast-dump-all %s \ +// RUN: | FileCheck %s + +// Make sure PCH works by using function declared in PCH header and declare a RWBuffer in current file. 
+// CHECK:FunctionDecl 0x[[FOO:[0-9a-f]+]] <{{.*}}:2:1, line:4:1> line:2:8 imported used foo 'float2 (float2, float2)' +// CHECK:VarDecl 0x{{[0-9a-f]+}} <{{.*}}:10:1, col:23> col:23 Buffer 'hlsl::RWBuffer':'hlsl::RWBuffer<>' +hlsl::RWBuffer Buffer; + +float2 bar(float2 a, float2 b) { +// CHECK:CallExpr 0x{{[0-9a-f]+}} 'float2':'float __attribute__((ext_vector_type(2)))' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} 'float2 (*)(float2, float2)' +// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)' + return foo(a, b); +} diff --git a/clang/test/AST/HLSL/pch_with_buf.hlsl b/clang/test/AST/HLSL/pch_with_buf.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/AST/HLSL/pch_with_buf.hlsl @@ -0,0 +1,18 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -finclude-default-header -emit-pch -o %t %S/Inputs/pch_with_buf.hlsl +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl \ +// RUN: -finclude-default-header -include-pch %t -fsyntax-only -ast-dump-all %s | FileCheck %s + +// Make sure PCH works by using function declared in PCH header. +// CHECK:FunctionDecl 0x[[FOO:[0-9a-f]+]] <{{.*}}:2:1, line:4:1> line:2:8 imported used foo 'float2 (float2, float2)' +// Make sure buffer defined in PCH works. +// CHECK:VarDecl 0x{{[0-9a-f]+}} col:17 imported Buf 'RWBuffer':'hlsl::RWBuffer<>' +// Make sure declare a RWBuffer in current file works. 
+// CHECK:VarDecl 0x{{[0-9a-f]+}} <{{.*}}:11:1, col:23> col:23 Buf2 'hlsl::RWBuffer':'hlsl::RWBuffer<>' +hlsl::RWBuffer Buf2; + +float2 bar(float2 a, float2 b) { +// CHECK:CallExpr 0x{{[0-9a-f]+}} 'float2':'float __attribute__((ext_vector_type(2)))' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} 'float2 (*)(float2, float2)' +// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)' + return foo(a, b); +} diff --git a/clang/test/CodeGenHLSL/builtins/ceil.hlsl b/clang/test/CodeGenHLSL/builtins/ceil.hlsl new file mode 100644 --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/ceil.hlsl @@ -0,0 +1,78 @@ +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \ +// RUN: -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \ +// RUN: -D__HLSL_ENABLE_16_BIT -o - | FileCheck %s --check-prefix=NO_HALF + + +// CHECK: define noundef half @ +// CHECK: call half @llvm.ceil.f16( +// NO_HALF: define noundef float @"?test_ceil_half@@YA$halff@$halff@@Z"( +// NO_HALF: call float @llvm.ceil.f32(float %0) +half test_ceil_half ( half p0 ) { + return ceil ( p0 ); +} +// CHECK: define noundef <2 x half> @ +// CHECK: call <2 x half> @llvm.ceil.v2f16( +// NO_HALF: define noundef <2 x float> @"?test_ceil_half2@@YAT?$__vector@$halff@$01@__clang@@T12@@Z"( +// NO_HALF: call <2 x float> @llvm.ceil.v2f32( +half2 test_ceil_half2 ( half2 p0 ) { + return ceil ( p0 ); +} +// CHECK: define noundef <3 x half> @ +// CHECK: call <3 x half> @llvm.ceil.v3f16( +// NO_HALF: define noundef <3 x float> @"?test_ceil_half3@@YAT?$__vector@$halff@$02@__clang@@T12@@Z"( +// NO_HALF: call <3 x float> @llvm.ceil.v3f32( +half3 test_ceil_half3 ( half3 p0 ) { + return ceil ( p0 ); +} +// CHECK: define noundef <4 x 
half> @ +// CHECK: call <4 x half> @llvm.ceil.v4f16( +// NO_HALF: define noundef <4 x float> @"?test_ceil_half4@@YAT?$__vector@$halff@$03@__clang@@T12@@Z"( +// NO_HALF: call <4 x float> @llvm.ceil.v4f32( +half4 test_ceil_half4 ( half4 p0 ) { + return ceil ( p0 ); +} + +// CHECK: define noundef float @ +// CHECK: call float @llvm.ceil.f32( +float test_ceil_float ( float p0 ) { + return ceil ( p0 ); +} +// CHECK: define noundef <2 x float> @ +// CHECK: call <2 x float> @llvm.ceil.v2f32( +float2 test_ceil_float2 ( float2 p0 ) { + return ceil ( p0 ); +} +// CHECK: define noundef <3 x float> @ +// CHECK: call <3 x float> @llvm.ceil.v3f32( +float3 test_ceil_float3 ( float3 p0 ) { + return ceil ( p0 ); +} +// CHECK: define noundef <4 x float> @ +// CHECK: call <4 x float> @llvm.ceil.v4f32( +float4 test_ceil_float4 ( float4 p0 ) { + return ceil ( p0 ); +} + +// CHECK: define noundef double @ +// CHECK: call double @llvm.ceil.f64( +double test_ceil_double ( double p0 ) { + return ceil ( p0 ); +} +// CHECK: define noundef <2 x double> @ +// CHECK: call <2 x double> @llvm.ceil.v2f64( +double2 test_ceil_double2 ( double2 p0 ) { + return ceil ( p0 ); +} +// CHECK: define noundef <3 x double> @ +// CHECK: call <3 x double> @llvm.ceil.v3f64( +double3 test_ceil_double3 ( double3 p0 ) { + return ceil ( p0 ); +} +// CHECK: define noundef <4 x double> @ +// CHECK: call <4 x double> @llvm.ceil.v4f64( +double4 test_ceil_double4 ( double4 p0 ) { + return ceil ( p0 ); +} diff --git a/clang/test/Driver/debug-options.c b/clang/test/Driver/debug-options.c --- a/clang/test/Driver/debug-options.c +++ b/clang/test/Driver/debug-options.c @@ -107,6 +107,14 @@ // RUN: | FileCheck -check-prefix=CI %s // RUN: %clang -### -c %s -gsce -target x86_64-unknown-linux 2>&1 \ // RUN: | FileCheck -check-prefix=NOCI %s +// RUN: %clang -### %s -g -flto=thin -target x86_64-scei-ps4 2>&1 \ +// RUN: | FileCheck -check-prefix=SNLDTLTOGARANGE %s +// RUN: %clang -### %s -g -flto=full -target x86_64-scei-ps4 2>&1 \ 
+// RUN: | FileCheck -check-prefix=SNLDFLTOGARANGE %s +// RUN: %clang -### %s -g -flto -target x86_64-scei-ps5 2>&1 \ +// RUN: | FileCheck -check-prefix=LLDGARANGE %s +// RUN: %clang -### %s -g -target x86_64-scei-ps5 2>&1 \ +// RUN: | FileCheck -check-prefix=LDGARANGE %s // On the AIX, -g defaults to -gdbx and limited debug info. // RUN: %clang -### -c -g %s -target powerpc-ibm-aix-xcoff 2>&1 \ @@ -365,6 +373,13 @@ // NOPUB-NOT: -ggnu-pubnames // NOPUB-NOT: -gpubnames // + +// LDGARANGE: {{".*ld.*"}} {{.*}} +// LDGARANGE-NOT: "-generate-arange-section" +// LLDGARANGE: {{".*lld.*"}} {{.*}} "-generate-arange-section" +// SNLDTLTOGARANGE: {{".*orbis-ld.*"}} {{.*}} "-lto-thin-debug-options=-generate-arange-section" +// SNLDFLTOGARANGE: {{".*orbis-ld.*"}} {{.*}} "-lto-debug-options=-generate-arange-section" + // PUB: -gpubnames // // RNGBSE: -fdebug-ranges-base-address diff --git a/clang/test/Driver/ps4-ps5-linker-jmc.c b/clang/test/Driver/ps4-ps5-linker-jmc.c --- a/clang/test/Driver/ps4-ps5-linker-jmc.c +++ b/clang/test/Driver/ps4-ps5-linker-jmc.c @@ -6,10 +6,10 @@ // RUN: %clang --target=x86_64-scei-ps5 -fjmc %s -### 2>&1 | FileCheck --check-prefixes=CHECK-PS5,CHECK-PS5-LIB %s // RUN: %clang --target=x86_64-scei-ps5 -flto -fjmc %s -### 2>&1 | FileCheck --check-prefixes=CHECK-PS5-LTO,CHECK-PS5-LIB %s -// CHECK-PS4-NOT: "-enable-jmc-instrument" +// CHECK-PS4-NOT: -enable-jmc-instrument -// CHECK-PS4-THIN-LTO: "-lto-thin-debug-options=-enable-jmc-instrument" -// CHECK-PS4-FULL-LTO: "-lto-debug-options=-enable-jmc-instrument" +// CHECK-PS4-THIN-LTO: -lto-thin-debug-options=-enable-jmc-instrument +// CHECK-PS4-FULL-LTO: -lto-debug-options=-enable-jmc-instrument // CHECK-PS5-NOT: "-enable-jmc-instrument" diff --git a/clang/test/SemaCXX/cxx1z-ast-print.cpp b/clang/test/SemaCXX/cxx1z-ast-print.cpp --- a/clang/test/SemaCXX/cxx1z-ast-print.cpp +++ b/clang/test/SemaCXX/cxx1z-ast-print.cpp @@ -4,7 +4,7 @@ template static int x; // expected-note {{forward declaration of template 
entity is here}} template static int y; // expected-note {{forward declaration of template entity is here}} }; -// CHECK: int k = TypeSuffix().x + TypeSuffix().y; +// CHECK: int k = TypeSuffix().x<0L> + TypeSuffix().y<0L>; int k = TypeSuffix().x<0L> + TypeSuffix().y<0L>; // expected-warning {{instantiation of variable 'TypeSuffix::x<0>' required here, but no definition is available}} \ // expected-note {{add an explicit instantiation declaration to suppress this warning if 'TypeSuffix::x<0>' is explicitly instantiated in another translation unit}} \ // expected-warning {{instantiation of variable 'TypeSuffix::y<0L>' required here, but no definition is available}} \ diff --git a/clang/unittests/Tooling/Syntax/BuildTreeTest.cpp b/clang/unittests/Tooling/Syntax/BuildTreeTest.cpp --- a/clang/unittests/Tooling/Syntax/BuildTreeTest.cpp +++ b/clang/unittests/Tooling/Syntax/BuildTreeTest.cpp @@ -2262,8 +2262,6 @@ template static constexpr T x = 42; }; -// FIXME: `` should be a child of `MemberExpression` and `;` of -// `ExpressionStatement`. This is a bug in clang, in `getSourceRange` methods. void test(S s) [[{ s.x; }]] @@ -2272,18 +2270,18 @@ CompoundStatement |-'{' OpenParen |-ExpressionStatement Statement -| `-MemberExpression Expression -| |-IdExpression Object -| | `-UnqualifiedId UnqualifiedId -| | `-'s' -| |-'.' AccessToken -| `-IdExpression Member -| `-UnqualifiedId UnqualifiedId -| `-'x' -|-'<' -|-'int' -|-'>' -|-';' +| |-MemberExpression Expression +| | |-IdExpression Object +| | | `-UnqualifiedId UnqualifiedId +| | | `-'s' +| | |-'.' AccessToken +| | `-IdExpression Member +| | `-UnqualifiedId UnqualifiedId +| | |-'x' +| | |-'<' +| | |-'int' +| | `-'>' +| `-';' `-'}' CloseParen )txt"})); } diff --git a/flang/test/Semantics/atomic01.f90 b/flang/test/Semantics/atomic01.f90 new file mode 100644 --- /dev/null +++ b/flang/test/Semantics/atomic01.f90 @@ -0,0 +1,93 @@ +! RUN: %python %S/test_errors.py %s %flang_fc1 +! XFAIL: * +! 
This test checks for semantic errors in atomic_add() subroutine based on the +! statement specification in section 16.9.20 of the Fortran 2018 standard. + +program test_atomic_add + use iso_fortran_env, only : atomic_int_kind + implicit none + + integer(kind=atomic_int_kind) atom_object[*], atom_array(2)[*], quantity, array(1), coarray[*], non_coarray + integer non_atom_object[*], non_atom, non_scalar(1), status, stat_array(1), coindexed[*] + logical non_integer + + !___ standard-conforming calls with required arguments _______ + + call atomic_add(atom_object, quantity) + call atomic_add(atom_object[1], quantity) + call atomic_add(atom_array(1), quantity) + call atomic_add(atom_array(1)[1], quantity) + call atomic_add(atom_object, array(1)) + call atomic_add(atom_object, coarray[1]) + call atomic_add(atom=atom_object, value=quantity) + call atomic_add(value=quantity, atom=atom_object) + + !___ standard-conforming calls with all arguments ____________ + call atomic_add(atom_object, quantity, status) + call atomic_add(atom_object, quantity, stat_array(1)) + call atomic_add(atom=atom_object, value=quantity, stat=status) + call atomic_add(stat=status, value=quantity, atom=atom_object) + + !___ non-standard-conforming calls _______ + + ! atom must be of kind atomic_int_kind + call atomic_add(non_atom_object, quantity) + + ! atom must be a coarray + call atomic_add(non_coarray, quantity) + + ! atom must be a scalar variable + call atomic_add(atom_array, quantity) + + ! atom has an unknown keyword argument + call atomic_add(atoms=atom_object, value=quantity) + + ! atom has an argument mismatch + call atomic_add(atom=non_atom_object, value=quantity) + + ! value must be an integer + call atomic_add(atom_object, non_integer) + + ! value must be an integer scalar + call atomic_add(atom_object, array) + + ! value must be of kind atomic_int_kind + call atomic_add(atom_object, non_atom) + + ! 
value has an unknown keyword argument + call atomic_add(atom_object, values=quantity) + + ! value has an argument mismatch + call atomic_add(atom_object, value=non_integer) + + ! stat must be an integer + call atomic_add(atom_object, quantity, non_integer) + + ! stat must be an integer scalar + call atomic_add(atom_object, quantity, non_scalar) + + ! stat is an intent(out) argument + call atomic_add(atom_object, quantity, 8) + + ! stat has an unknown keyword argument + call atomic_add(atom_object, quantity, statuses=status) + + ! stat has an argument mismatch + call atomic_add(atom_object, quantity, stat=non_integer) + + ! stat must not be coindexed + call atomic_add(atom_object, quantity, coindexed[1]) + + ! Too many arguments + call atomic_add(atom_object, quantity, status, stat_array(1)) + + ! Repeated atom keyword + call atomic_add(atom=atom_object, atom=atom_array(1), value=quantity) + + ! Repeated value keyword + call atomic_add(atom=atom_object, value=quantity, value=array(1)) + + ! Repeated stat keyword + call atomic_add(atom=atom_object, value=quantity, stat=status, stat=stat_array(1)) + +end program test_atomic_add diff --git a/lld/COFF/Config.h b/lld/COFF/Config.h --- a/lld/COFF/Config.h +++ b/lld/COFF/Config.h @@ -208,6 +208,9 @@ // Used for /map. std::string mapFile; + // Used for /mapinfo. 
+ bool mapInfo = false; + // Used for /thinlto-index-only: llvm::StringRef thinLTOIndexOnlyArg; diff --git a/lld/COFF/Driver.cpp b/lld/COFF/Driver.cpp --- a/lld/COFF/Driver.cpp +++ b/lld/COFF/Driver.cpp @@ -1922,6 +1922,16 @@ config->lldmapFile = getMapFile(args, OPT_lldmap, OPT_lldmap_file); config->mapFile = getMapFile(args, OPT_map, OPT_map_file); + if (config->mapFile != "" && args.hasArg(OPT_map_info)) { + for (auto *arg : args.filtered(OPT_map_info)) { + std::string s = StringRef(arg->getValue()).lower(); + if (s == "exports") + config->mapInfo = true; + else + error("unknown option: /mapinfo:" + s); + } + } + if (config->lldmapFile != "" && config->lldmapFile == config->mapFile) { warn("/lldmap and /map have the same output file '" + config->mapFile + "'.\n>>> ignoring /lldmap"); diff --git a/lld/COFF/MapFile.cpp b/lld/COFF/MapFile.cpp --- a/lld/COFF/MapFile.cpp +++ b/lld/COFF/MapFile.cpp @@ -315,6 +315,19 @@ for (Defined *sym : staticSyms) os << staticSymStr[sym] << '\n'; + // Print out the exported functions + if (config->mapInfo) { + os << "\n"; + os << " Exports\n"; + os << "\n"; + os << " ordinal name\n\n"; + for (Export &e : config->exports) { + os << format(" %7d", e.ordinal) << " " << e.name << "\n"; + if (!e.extName.empty() && e.extName != e.name) + os << " exported name: " << e.extName << "\n"; + } + } + t4.stop(); t1.stop(); } diff --git a/lld/COFF/Options.td b/lld/COFF/Options.td --- a/lld/COFF/Options.td +++ b/lld/COFF/Options.td @@ -287,6 +287,7 @@ def lldmap_file : P_priv<"lldmap">; def map : F<"map">; def map_file : P_priv<"map">; +def map_info : P<"mapinfo", "Include the specified information in a map file">; def show_timing : F<"time">; def summary : F<"summary">; diff --git a/lld/test/COFF/map.test b/lld/test/COFF/map.test --- a/lld/test/COFF/map.test +++ b/lld/test/COFF/map.test @@ -8,6 +8,9 @@ # RUN: lld-link /out:%t.exe /entry:main %t.obj %t-dll.lib /map /lldmap:%T/foo-lld.map # RUN: FileCheck -check-prefix=MAP -strict-whitespace %s < 
%t.map # RUN: FileCheck -check-prefix=LLDMAP -strict-whitespace %s < %T/foo-lld.map +# RUN: lld-link /out:%t.dll /dll %t-dll.obj /export:exportfn1 \ +# RUN: /export:foo=exportfn2 /map /mapinfo:exports +# RUN: FileCheck -check-prefix=MAPINFO -strict-whitespace %s < %t.map # MAP: {{.*}} # MAP-EMPTY: @@ -38,3 +41,12 @@ # LLDMAP-NEXT: 00001000 00000026 4096 .text # LLDMAP-NEXT: 00001000 00000008 4 {{.*}}map.test.tmp.obj:(.text) # LLDMAP-NEXT: 00001000 00000000 0 main + +# MAPINFO: Exports +# MAPINFO-EMPTY: +# MAPINFO-NEXT: ordinal name +# MAPINFO-EMPTY: +# MAPINFO-NEXT: 1 exportfn1 +# MAPINFO-NEXT: 2 exportfn3 +# MAPINFO-NEXT: 3 exportfn2 +# MAPINFO-NEXT: exported name: foo diff --git a/llvm/runtimes/CMakeLists.txt b/llvm/runtimes/CMakeLists.txt --- a/llvm/runtimes/CMakeLists.txt +++ b/llvm/runtimes/CMakeLists.txt @@ -366,7 +366,7 @@ if(runtimes) # Create a runtimes target that uses this file as its top-level CMake file. # The runtimes target is a configuration of all the runtime libraries - # together in a single CMake invocaiton. + # together in a single CMake invocation. set(extra_deps "") if("openmp" IN_LIST LLVM_ENABLE_RUNTIMES) if(TARGET opt) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -110,6 +110,10 @@ /// Patterns that are used to bubble up extract slice op above linalg op. void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns); +/// Adds patterns that swaps tensor.extract_slice(linalg.fill(%cst, %init)) into +/// linalg.fill(%cst, tensor.extract_slice(%init)). +void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns); + /// Return true if two `linalg.generic` operations with producer/consumer /// relationship through `fusedOperand` can be fused using elementwise op /// fusion.
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4221,6 +4221,9 @@ def SPV_OC_OpUConvert : I32EnumAttrCase<"OpUConvert", 113>; def SPV_OC_OpSConvert : I32EnumAttrCase<"OpSConvert", 114>; def SPV_OC_OpFConvert : I32EnumAttrCase<"OpFConvert", 115>; +def SPV_OC_OpPtrCastToGeneric : I32EnumAttrCase<"OpPtrCastToGeneric", 121>; +def SPV_OC_OpGenericCastToPtr : I32EnumAttrCase<"OpGenericCastToPtr", 122>; +def SPV_OC_OpGenericCastToPtrExplicit : I32EnumAttrCase<"OpGenericCastToPtrExplicit", 123>; def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>; def SPV_OC_OpSNegate : I32EnumAttrCase<"OpSNegate", 126>; def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>; @@ -4372,7 +4375,8 @@ SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImageDrefGather, SPV_OC_OpImage, SPV_OC_OpImageQuerySize, SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, - SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, + SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpPtrCastToGeneric, + SPV_OC_OpGenericCastToPtr, SPV_OC_OpGenericCastToPtrExplicit, SPV_OC_OpBitcast, SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td @@ -331,4 +331,144 @@ }]; } +// ----- +def SPV_PtrCastToGenericOp : SPV_Op<"PtrCastToGeneric", [NoSideEffect]> { + let summary = "Convert a pointer’s Storage Class to Generic."; + + let description = [{ + Result Type must be an OpTypePointer. 
Its Storage Class must be Generic. + + Pointer must point to the Workgroup, CrossWorkgroup, or Function Storage + Class. + + Result Type and Pointer must point to the same type. + + + + #### Example: + + ```mlir + %1 = spv.PtrCastToGenericOp %0 : !spv.ptr to + !spv.ptr + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Kernel]> + ]; + + let arguments = (ins + SPV_AnyPtr:$pointer + ); + + let results = (outs + SPV_AnyPtr:$result + ); + + let assemblyFormat = [{ + $pointer attr-dict `:` type($pointer) `to` type($result) + }]; +} + +// ----- + +def SPV_GenericCastToPtrOp : SPV_Op<"GenericCastToPtr", [NoSideEffect]> { + let summary = "Convert a pointer’s Storage Class to a non-Generic class."; + + let description = [{ + Result Type must be an OpTypePointer. Its Storage Class must be + Workgroup, CrossWorkgroup, or Function. + + Pointer must point to the Generic Storage Class. + + Result Type and Pointer must point to the same type. + + + + #### Example: + + ```mlir + %1 = spv.GenericCastToPtrOp %0 : !spv.ptr to + !spv.ptr + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Kernel]> + ]; + + let arguments = (ins + SPV_AnyPtr:$pointer + ); + + let results = (outs + SPV_AnyPtr:$result + ); + + let assemblyFormat = [{ + $pointer attr-dict `:` type($pointer) `to` type($result) + }]; +} + +// ----- + +def SPV_GenericCastToPtrExplicitOp : SPV_Op<"GenericCastToPtrExplicit", [NoSideEffect]> { + let summary = [{ + Attempts to explicitly convert Pointer to Storage storage-class pointer + value. + }]; + + let description = [{ + Result Type must be an OpTypePointer. Its Storage Class must be Storage. + + Pointer must have a type of OpTypePointer whose Type is the same as the + Type of Result Type.Pointer must point to the Generic Storage Class. If + the cast fails, the instruction result is an OpConstantNull pointer in + the Storage Storage Class. 
+ + Storage must be one of the following literal values from Storage Class: + Workgroup, CrossWorkgroup, or Function. + + + + ``` + [TODO] + ```mlir + + #### Example: + + ```mlir + %1 = spv.GenericCastToPtrExplicitOp %0 : !spv.ptr to + !spv.ptr + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Kernel]> + ]; + + let arguments = (ins + SPV_AnyPtr:$pointer + ); + + let results = (outs + SPV_AnyPtr:$result + ); + + let assemblyFormat = [{ + $pointer attr-dict `:` type($pointer) `to` type($result) + }]; + + let autogenSerialization = 0; +} + #endif // MLIR_DIALECT_SPIRV_IR_CAST_OPS diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/Transforms/DialectConversion.h" @@ -325,6 +326,15 @@ MLIRContext *context = &getContext(); Operation *op = getOperation(); + if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) { + spirv::TargetEnv targetEnv(attr); + if (targetEnv.allows(spirv::Capability::Kernel)) { + memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass; + } else if (targetEnv.allows(spirv::Capability::Shader)) { + memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass; + } + } + auto target = spirv::getMemorySpaceToStorageClassTarget(*context); spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- 
a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -272,6 +272,8 @@ } }; +template struct VectorReductionPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -317,18 +319,18 @@ #define INT_OR_FLOAT_CASE(kind, fop) \ case vector::CombiningKind::kind: \ - result = rewriter.create(loc, resultType, result, next); \ + result = rewriter.create(loc, resultType, result, next); \ break INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp); - INT_OR_FLOAT_CASE(MAXF, GLFMaxOp); - INT_OR_FLOAT_CASE(MINF, GLFMinOp); - INT_OR_FLOAT_CASE(MINUI, GLUMinOp); - INT_OR_FLOAT_CASE(MINSI, GLSMinOp); - INT_OR_FLOAT_CASE(MAXUI, GLUMaxOp); - INT_OR_FLOAT_CASE(MAXSI, GLSMaxOp); + INT_OR_FLOAT_CASE(MAXF, SPVFMaxOp); + INT_OR_FLOAT_CASE(MINF, SPVFMinOp); + INT_OR_FLOAT_CASE(MINUI, SPVUMinOp); + INT_OR_FLOAT_CASE(MINSI, SPVSMinOp); + INT_OR_FLOAT_CASE(MAXUI, SPVUMaxOp); + INT_OR_FLOAT_CASE(MAXSI, SPVSMaxOp); case vector::CombiningKind::AND: case vector::CombiningKind::OR: @@ -403,15 +405,23 @@ }; } // namespace +#define CL_MAX_MIN_OPS \ + spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \ + spirv::CLSMaxOp, spirv::CLSMinOp + +#define GL_MAX_MIN_OPS \ + spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \ + spirv::GLSMaxOp, spirv::GLSMinOp void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add, - VectorFmaOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern, - VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, - VectorSplatPattern>(typeConverter, patterns.getContext()); + patterns.add< + VectorBitcastConvert, VectorBroadcastConvert, + VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, + VectorFmaOpConvert, VectorInsertElementOpConvert, + 
VectorInsertOpConvert, VectorReductionPattern, + VectorReductionPattern, VectorInsertStridedSliceOpConvert, + VectorShuffleOpConvert, VectorSplatPattern>(typeConverter, + patterns.getContext()); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -24,6 +24,7 @@ Promotion.cpp Split.cpp SplitReduction.cpp + SwapExtractSliceWithFillPatterns.cpp Tiling.cpp TilingInterfaceImpl.cpp Transforms.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp @@ -0,0 +1,41 @@ +//===- SwapExtractSliceWithFillPatterns.cpp -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::linalg; + +/// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst, +/// tensor.extract_slice(%init)) when the linalg.fill op have no other users. +/// This helps to reduce the fill footprint. 
+struct SwapExtractSliceOfFill final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp, + PatternRewriter &rewriter) const override { + auto fillOp = extractOp.getSource().getDefiningOp(); + if (!fillOp || !fillOp->hasOneUse()) + return failure(); + + auto newExtractOp = rewriter.create( + extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0], + extractOp.getMixedOffsets(), extractOp.getMixedSizes(), + extractOp.getMixedStrides()); + rewriter.replaceOpWithNewOp(extractOp, fillOp.getInputs(), + ValueRange{newExtractOp.getResult()}); + return success(); + } +}; + +void mlir::linalg::populateSwapExtractSliceWithFillPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1525,6 +1525,90 @@ return success(); } +//===----------------------------------------------------------------------===// +// spv.PtrCastToGenericOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::PtrCastToGenericOp::verify() { + auto operandType = pointer().getType().cast(); + auto resultType = result().getType().cast(); + + spirv::StorageClass operandStorage = operandType.getStorageClass(); + if (operandStorage != spirv::StorageClass::Workgroup && + operandStorage != spirv::StorageClass::CrossWorkgroup && + operandStorage != spirv::StorageClass::Function) + return emitError("pointer must point to the Workgroup, CrossWorkgroup" + ", or Function Storage Class"); + + spirv::StorageClass resultStorage = resultType.getStorageClass(); + if (resultStorage != spirv::StorageClass::Generic) + return emitError("result type must be of storage class Generic"); + + Type operandPointeeType = operandType.getPointeeType(); + Type resultPointeeType = 
resultType.getPointeeType(); + if (operandPointeeType != resultPointeeType) + return emitOpError("pointer operand's pointee type must have the same " + "as the op result type, but found ") + << operandPointeeType << " vs " << resultPointeeType; + return success(); +} + +//===----------------------------------------------------------------------===// +// spv.GenericCastToPtrOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GenericCastToPtrOp::verify() { + auto operandType = pointer().getType().cast(); + auto resultType = result().getType().cast(); + + spirv::StorageClass operandStorage = operandType.getStorageClass(); + if (operandStorage != spirv::StorageClass::Generic) + return emitError("pointer type must be of storage class Generic"); + + spirv::StorageClass resultStorage = resultType.getStorageClass(); + if (resultStorage != spirv::StorageClass::Workgroup && + resultStorage != spirv::StorageClass::CrossWorkgroup && + resultStorage != spirv::StorageClass::Function) + return emitError("result must point to the Workgroup, CrossWorkgroup, " + "or Function Storage Class"); + + Type operandPointeeType = operandType.getPointeeType(); + Type resultPointeeType = resultType.getPointeeType(); + if (operandPointeeType != resultPointeeType) + return emitOpError("pointer operand's pointee type must have the same " + "as the op result type, but found ") + << operandPointeeType << " vs " << resultPointeeType; + return success(); +} + +//===----------------------------------------------------------------------===// +// spv.GenericCastToPtrExplicitOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GenericCastToPtrExplicitOp::verify() { + auto operandType = pointer().getType().cast(); + auto resultType = result().getType().cast(); + + spirv::StorageClass operandStorage = operandType.getStorageClass(); + if (operandStorage != spirv::StorageClass::Generic) + return 
emitError("pointer type must be of storage class Generic"); + + spirv::StorageClass resultStorage = resultType.getStorageClass(); + if (resultStorage != spirv::StorageClass::Workgroup && + resultStorage != spirv::StorageClass::CrossWorkgroup && + resultStorage != spirv::StorageClass::Function) + return emitError("result must point to the Workgroup, CrossWorkgroup, " + "or Function Storage Class"); + + Type operandPointeeType = operandType.getPointeeType(); + Type resultPointeeType = resultType.getPointeeType(); + if (operandPointeeType != resultPointeeType) + return emitOpError("pointer operand's pointee type must have the same " + "as the op result type, but found ") + << operandPointeeType << " vs " << resultPointeeType; + return success(); +} + //===----------------------------------------------------------------------===// // spv.BranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -523,6 +523,38 @@ return success(); } +template <> +LogicalResult Deserializer::processOp( + ArrayRef words) { + if (words.size() != 4) { + return emitError(unknownLoc, + "expected 4 words in GenericCastToPtrExplicitOp" + " but got : ") + << words.size(); + } + SmallVector resultTypes; + SmallVector operands; + uint32_t valueID = 0; + auto type = getType(words[0]); + + if (!type) + return emitError(unknownLoc, "unknown type result : ") << words[0]; + resultTypes.push_back(type); + + valueID = words[1]; + + auto arg = getValue(words[2]); + if (!arg) + return emitError(unknownLoc, "unknown result : ") << words[2]; + operands.push_back(arg); + + Location loc = createFileLineColLoc(opBuilder); + Operation *op = opBuilder.create( + loc, resultTypes, operands); + valueMap[valueID] = op->getResult(0); 
+ return success(); +} + // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and // various Deserializer::processOp<...>() specializations. #define GET_DESERIALIZATION_FNS diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -667,6 +667,32 @@ return success(); } +template <> +LogicalResult Serializer::processOp( + spirv::GenericCastToPtrExplicitOp op) { + SmallVector operands; + Type resultTy; + Location loc = op->getLoc(); + uint32_t resultTypeID = 0; + uint32_t resultID = 0; + resultTy = op->getResult(0).getType(); + if (failed(processType(loc, resultTy, resultTypeID))) + return failure(); + operands.push_back(resultTypeID); + + resultID = getNextID(); + operands.push_back(resultID); + valueIDMap[op->getResult(0)] = resultID; + + for (Value operand : op->getOperands()) + operands.push_back(getValueID(operand)); + spirv::StorageClass resultStorage = + resultTy.cast().getStorageClass(); + operands.push_back(static_cast(resultStorage)); + encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit, + operands); + return success(); +} // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and // various Serializer::processOp<...>() specializations. diff --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir @@ -114,3 +114,56 @@ %0 = "dialect.memref_producer"() : () -> (memref) return } + +// ----- + +/// Checks memory maps to OpenCL mapping if Kernel capability is enabled. 
+module attributes { spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>> } { +func.func @operand_result() { + // CHECK: memref> + %0 = "dialect.memref_producer"() : () -> (memref) + // CHECK: memref<4xi32, #spv.storage_class> + %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>) + // CHECK: memref> + %2 = "dialect.memref_producer"() : () -> (memref) + // CHECK: memref<*xf16, #spv.storage_class> + %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>) + + + "dialect.memref_consumer"(%0) : (memref) -> () + // CHECK: memref<4xi32, #spv.storage_class> + "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> () + // CHECK: memref> + "dialect.memref_consumer"(%2) : (memref) -> () + // CHECK: memref<*xf16, #spv.storage_class> + "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> () + + return +} +} + +// ----- + +/// Checks memory maps to Vulkan mapping if Shader capability is enabled. +module attributes { spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>> } { +func.func @operand_result() { + // CHECK: memref> + %0 = "dialect.memref_producer"() : () -> (memref) + // CHECK: memref<4xi32, #spv.storage_class> + %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>) + // CHECK: memref> + %2 = "dialect.memref_producer"() : () -> (memref) + // CHECK: memref<*xf16, #spv.storage_class> + %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>) + + + "dialect.memref_consumer"(%0) : (memref) -> () + // CHECK: memref<4xi32, #spv.storage_class> + "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> () + // CHECK: memref> + "dialect.memref_consumer"(%2) : (memref) -> () + // CHECK: memref<*xf16, #spv.storage_class> + "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> () + return +} +} \ No newline at end of file diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ 
b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -33,6 +33,90 @@ return %0 : vector<1xf32> } +// CHECK-LABEL: func @cl_reduction_maxf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MAX0:.+]] = spv.CL.fmax %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spv.CL.fmax %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spv.CL.fmax %[[MAX1]], %[[S]] +// CHECK: return %[[MAX2]] +func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// CHECK-LABEL: func @cl_reduction_minf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MIN0:.+]] = spv.CL.fmin %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spv.CL.fmin %[[MIN0]], %[[S2]] +// CHECK: %[[MIN2:.+]] = spv.CL.fmin %[[MIN1]], %[[S]] +// CHECK: return %[[MIN2]] +func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// CHECK-LABEL: func @cl_reduction_maxsi +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MAX0:.+]] = spv.CL.s_max %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spv.CL.s_max %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spv.CL.s_max %[[MAX1]], %[[S]] +// CHECK: return 
%[[MAX2]] +func.func @cl_reduction_maxsi(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction , %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// CHECK-LABEL: func @cl_reduction_minsi +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MIN0:.+]] = spv.CL.s_min %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spv.CL.s_min %[[MIN0]], %[[S2]] +// CHECK: %[[MIN2:.+]] = spv.CL.s_min %[[MIN1]], %[[S]] +// CHECK: return %[[MIN2]] +func.func @cl_reduction_minsi(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction , %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// CHECK-LABEL: func @cl_reduction_maxui +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MAX0:.+]] = spv.CL.u_max %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spv.CL.u_max %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spv.CL.u_max %[[MAX1]], %[[S]] +// CHECK: return %[[MAX2]] +func.func @cl_reduction_maxui(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction , %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + +// CHECK-LABEL: func @cl_reduction_minui +// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xi32> +// CHECK: %[[MIN0:.+]] = spv.CL.u_min %[[S0]], %[[S1]] +// CHECK: %[[MIN1:.+]] = spv.CL.u_min %[[MIN0]], %[[S2]] 
+// CHECK: %[[MIN2:.+]] = spv.CL.u_min %[[MIN1]], %[[S]] +// CHECK: return %[[MIN2]] +func.func @cl_reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 { + %reduce = vector.reduction , %v, %s : vector<3xi32> into i32 + return %reduce : i32 +} + } // end module // ----- diff --git a/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir b/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/swap-extract-slice-with-fill.mlir @@ -0,0 +1,28 @@ +//RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-swap-extract-slice-with-fill-pattern %s | FileCheck %s + +// CHECK-LABEL: func.func @swap_fill_insert_slice +// CHECK-SAME: (%[[INIT:.+]]: tensor, %[[OFFSET0:.+]]: index, %[[SIZE1:.+]]: index) +// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EXT:.+]] = tensor.extract_slice %[[INIT]][%[[OFFSET0]], 8, 4] [1, %[[SIZE1]], 6] [1, 3, 1] +// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EXT]] : tensor) -> tensor +// CHECK: return %[[FILL]] +func.func @swap_fill_insert_slice(%init : tensor, %offset0: index, %size1: index) -> tensor { + %f0 = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill ins(%f0 : f32) outs(%init : tensor) -> tensor + %1 = tensor.extract_slice %0[%offset0, 8, 4] [1, %size1, 6] [1, 3, 1] + : tensor to tensor + return %1: tensor +} + +// ----- + +// CHECK-LABEL: func.func @dont_swap_fill_insert_slice_multi_user +// CHECK: linalg.fill +// CHECK: tensor.extract_slice +func.func @dont_swap_fill_insert_slice_multi_user(%init : tensor, %offset0: index, %size1: index) -> (tensor, tensor<2x?x6xf32>) { + %f0 = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill ins(%f0 : f32) outs(%init : tensor) -> tensor + %1 = tensor.extract_slice %0[%offset0, 8, 4] [2, %size1, 6] [1, 3, 1] + : tensor to tensor<2x?x6xf32> + return %0, %1: tensor, tensor<2x?x6xf32> +} diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir 
b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir @@ -260,3 +260,106 @@ spv.ReturnValue %0 : i64 } +// ----- + +//===----------------------------------------------------------------------===// +// spv.PtrCastToGeneric +//===----------------------------------------------------------------------===// + +func.func @ptrcasttogeneric1(%arg0 : !spv.ptr) { + // CHECK: {{%.*}} = spv.PtrCastToGeneric {{%.*}} : !spv.ptr to !spv.ptr + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + return +} +// ----- + +func.func @ptrcasttogeneric2(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointer must point to the Workgroup, CrossWorkgroup, or Function Storage Class}} + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @ptrcasttogeneric3(%arg0 : !spv.ptr) { + // expected-error @+1 {{result type must be of storage class Generic}} + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @ptrcasttogeneric4(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointee type must have the same as the op result type}} + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr, Generic> + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.GenericCastToPtr +//===----------------------------------------------------------------------===// + +func.func @genericcasttoptr1(%arg0 : !spv.ptr, Generic>) { + // CHECK: {{%.*}} = spv.GenericCastToPtr {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + return +} +// ----- + +func.func @genericcasttoptr2(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointer type must be of storage class Generic}} + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @genericcasttoptr3(%arg0 : !spv.ptr) { + // 
expected-error @+1 {{result must point to the Workgroup, CrossWorkgroup, or Function Storage Class}} + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @genericcasttoptr4(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointee type must have the same as the op result type}} + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr to !spv.ptr, Workgroup> + return +} +// ----- + +//===----------------------------------------------------------------------===// +// spv.GenericCastToPtrExplicit +//===----------------------------------------------------------------------===// + +func.func @genericcasttoptrexplicit1(%arg0 : !spv.ptr, Generic>) { + // CHECK: {{%.*}} = spv.GenericCastToPtrExplicit {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + return +} +// ----- + +func.func @genericcasttoptrexplicit2(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointer type must be of storage class Generic}} + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @genericcasttoptrexplicit3(%arg0 : !spv.ptr) { + // expected-error @+1 {{result must point to the Workgroup, CrossWorkgroup, or Function Storage Class}} + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr to !spv.ptr + return +} + +// ----- + +func.func @genericcasttoptrexplicit4(%arg0 : !spv.ptr) { + // expected-error @+1 {{pointee type must have the same as the op result type}} + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr to !spv.ptr, Workgroup> + return +} diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir --- a/mlir/test/Target/SPIRV/cast-ops.mlir +++ b/mlir/test/Target/SPIRV/cast-ops.mlir @@ -71,3 +71,23 @@ spv.ReturnValue %0 : i64 } } + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { + spv.func @ptr_cast_to_generic(%arg0 : !spv.ptr) "None" { + // CHECK: {{%.*}} = spv.PtrCastToGeneric 
{{%.*}} : !spv.ptr to !spv.ptr + %0 = spv.PtrCastToGeneric %arg0 : !spv.ptr to !spv.ptr + spv.Return + } + spv.func @generic_cast_to_ptr(%arg0 : !spv.ptr, Generic>) "None" { + // CHECK: {{%.*}} = spv.GenericCastToPtr {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtr %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + spv.Return + } + spv.func @generic_cast_to_ptr_explicit(%arg0 : !spv.ptr, Generic>) "None" { + // CHECK: {{%.*}} = spv.GenericCastToPtrExplicit {{%.*}} : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + %0 = spv.GenericCastToPtrExplicit %arg0 : !spv.ptr, Generic> to !spv.ptr, CrossWorkgroup> + spv.Return + } +} diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -123,6 +123,11 @@ llvm::cl::desc("Test rewrite of linalgOp + extract_slice into " "extract_slice + linalgOp"), llvm::cl::init(false)}; + Option testSwapExtractSliceWithFill{ + *this, "test-swap-extract-slice-with-fill-pattern", + llvm::cl::desc( + "Test patterns to swap tensor.extract_slice(linalg.fill())"), + llvm::cl::init(false)}; }; } // namespace @@ -508,6 +513,12 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } +static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) { + RewritePatternSet patterns(funcOp.getContext()); + populateSwapExtractSliceWithFillPatterns(patterns); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); +} + /// Apply transformations specified as patterns. 
void TestLinalgTransforms::runOnOperation() { auto lambda = [&](void *) { @@ -551,6 +562,8 @@ return applySplitReduction(getOperation()); if (testBubbleUpExtractSliceOpPattern) return applyBubbleUpExtractSliceOpPattern(getOperation()); + if (testSwapExtractSliceWithFill) + return applySwapExtractSliceWithFillPattern(getOperation()); } namespace mlir {