diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -1585,7 +1585,7 @@
 
     /// Extra information which affects how the function is called, like
     /// regparm and the calling convention.
-    unsigned ExtInfo : 13;
+    unsigned ExtInfo : 14;
 
     /// The ref-qualifier associated with a \c FunctionProtoType.
     ///
@@ -1823,7 +1823,7 @@
   Type(TypeClass tc, QualType canon, TypeDependence Dependence)
       : ExtQualsTypeCommonBase(this,
                                canon.isNull() ? QualType(this_(), 0) : canon) {
-    static_assert(sizeof(*this) <= 8 + sizeof(ExtQualsTypeCommonBase),
+    static_assert(sizeof(*this) <= 16 + sizeof(ExtQualsTypeCommonBase),
                   "changing bitfields changed sizeof(Type)!");
     static_assert(alignof(decltype(*this)) % sizeof(void *) == 0,
                   "Insufficient alignment!");
@@ -3663,6 +3663,8 @@
 
     // |  CC  |noreturn|produces|nocallersavedregs|regparm|nocfcheck|cmsenscall|
     // |0 .. 4|   5    |    6   |       7         |8 .. 10|    11   |    12    |
+    // |amxpreserve|
+    // |    13     |
     //
     // regparm is either 0 (no regparm attribute) or the regparm value+1.
     enum { CallConvMask = 0x1F };
@@ -3675,6 +3677,7 @@
     };
     enum { NoCfCheckMask = 0x800 };
     enum { CmseNSCallMask = 0x1000 };
+    enum { AMXPreserveMask = 0x2000 };
     uint16_t Bits = CC_C;
 
     ExtInfo(unsigned Bits) : Bits(static_cast<uint16_t>(Bits)) {}
@@ -3684,14 +3687,15 @@
     // have all the elements (when reading an AST file for example).
     ExtInfo(bool noReturn, bool hasRegParm, unsigned regParm, CallingConv cc,
             bool producesResult, bool noCallerSavedRegs, bool NoCfCheck,
-            bool cmseNSCall) {
+            bool cmseNSCall, bool amxPreserve) {
       assert((!hasRegParm || regParm < 7) && "Invalid regparm value");
       Bits = ((unsigned)cc) | (noReturn ? NoReturnMask : 0) |
              (producesResult ? ProducesResultMask : 0) |
              (noCallerSavedRegs ? NoCallerSavedRegsMask : 0) |
              (hasRegParm ? ((regParm + 1) << RegParmOffset) : 0) |
              (NoCfCheck ? NoCfCheckMask : 0) |
-             (cmseNSCall ? CmseNSCallMask : 0);
+             (cmseNSCall ? CmseNSCallMask : 0) |
+             (amxPreserve ? AMXPreserveMask : 0);
     }
 
     // Constructor with all defaults. Use when for example creating a
@@ -3706,6 +3710,7 @@
     bool getProducesResult() const { return Bits & ProducesResultMask; }
     bool getCmseNSCall() const { return Bits & CmseNSCallMask; }
     bool getNoCallerSavedRegs() const { return Bits & NoCallerSavedRegsMask; }
+    bool getAMXPreserve() const { return Bits & AMXPreserveMask; }
     bool getNoCfCheck() const { return Bits & NoCfCheckMask; }
     bool getHasRegParm() const { return ((Bits & RegParmMask) >> RegParmOffset) != 0; }
 
@@ -3756,6 +3761,13 @@
       return ExtInfo(Bits & ~NoCallerSavedRegsMask);
     }
 
+    ExtInfo withAMXPreserve(bool amxPreserve) const {
+      if (amxPreserve)
+        return ExtInfo(Bits | AMXPreserveMask);
+      else
+        return ExtInfo(Bits & ~AMXPreserveMask);
+    }
+
     ExtInfo withNoCfCheck(bool noCfCheck) const {
       if (noCfCheck)
         return ExtInfo(Bits | NoCfCheckMask);
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -287,6 +287,9 @@
   def : Property<"cmseNSCall", Bool> {
     let Read = [{ node->getExtInfo().getCmseNSCall() }];
   }
+  def : Property<"amxPreserve", Bool> {
+    let Read = [{ node->getExtInfo().getAMXPreserve() }];
+  }
 }
 
 let Class = FunctionNoProtoType in {
@@ -294,7 +297,7 @@
     auto extInfo = FunctionType::ExtInfo(noReturn, hasRegParm, regParm,
                                          callingConvention, producesResult,
                                          noCallerSavedRegs, noCfCheck,
-                                         cmseNSCall);
+                                         cmseNSCall, amxPreserve);
     return ctx.getFunctionNoProtoType(returnType, extInfo);
   }]>;
 }
@@ -328,7 +331,7 @@
     auto extInfo = FunctionType::ExtInfo(noReturn, hasRegParm, regParm,
                                          callingConvention, producesResult,
                                          noCallerSavedRegs, noCfCheck,
-                                         cmseNSCall);
+                                         cmseNSCall, amxPreserve);
     FunctionProtoType::ExtProtoInfo epi;
     epi.ExtInfo = extInfo;
     epi.Variadic = variadic;
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -2892,6 +2892,12 @@
   let SimpleHandler = 1;
 }
 
+def AMXPreserve : InheritableAttr, TargetSpecificAttr<TargetAnyX86> {
+  let Spellings = [Clang<"amxpreserve">];
+  let Documentation = [AMXPreserveDocs];
+  let SimpleHandler = 1;
+}
+
 def AnyX86Interrupt : InheritableAttr, TargetSpecificAttr<TargetAnyX86> {
   // NOTE: If you add any additional spellings, ARMInterrupt's,
   // M68kInterrupt's, MSP430Interrupt's and MipsInterrupt's spellings must match.
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -4457,6 +4457,39 @@
   }];
 }
 
+def AMXPreserveDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+Use this attribute to indicate that the specified function has no
+caller-saved AMX registers. The compiler does not save or restore any AMX
+register around a call to such a function, so it is the user's
+responsibility to ensure that a function declared with the 'amxpreserve'
+attribute does not clobber any AMX register.
+
+Like 'no_caller_saved_registers', the 'amxpreserve' attribute is not a
+calling convention; it only overrides the decision of which AMX registers
+must be saved by the caller.
+
+For example:
+
+  .. code-block:: c
+
+    __attribute__ ((amxpreserve))
+    void f () {
+      ...
+    }
+
+    void bar () {
+      ...
+      f();
+      ...
+    }
+
+  In this case, the compiler does not save and restore AMX registers around
+  the call to f() in bar().
+  }];
+}
+
 def X86ForceAlignArgPointerDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/CodeGen/CGFunctionInfo.h b/clang/include/clang/CodeGen/CGFunctionInfo.h
--- a/clang/include/clang/CodeGen/CGFunctionInfo.h
+++ b/clang/include/clang/CodeGen/CGFunctionInfo.h
@@ -579,6 +579,9 @@
   /// Whether this function saved caller registers.
   unsigned NoCallerSavedRegs : 1;
 
+  /// Whether this function preserves AMX state.
+  unsigned AMXPreserve : 1;
+
   /// How many arguments to pass inreg.
   unsigned HasRegParm : 1;
   unsigned RegParm : 3;
@@ -671,6 +674,9 @@
   /// Whether this function no longer saves caller registers.
   bool isNoCallerSavedRegs() const { return NoCallerSavedRegs; }
 
+  /// Whether this function preserves AMX state.
+  bool isAMXPreserve() const { return AMXPreserve; }
+
   /// Whether this function has nocf_check attribute.
   bool isNoCfCheck() const { return NoCfCheck; }
 
@@ -700,7 +706,7 @@
     return FunctionType::ExtInfo(isNoReturn(), getHasRegParm(), getRegParm(),
                                  getASTCallingConvention(), isReturnsRetained(),
                                  isNoCallerSavedRegs(), isNoCfCheck(),
-                                 isCmseNSCall());
+                                 isCmseNSCall(), isAMXPreserve());
   }
 
   CanQualType getReturnType() const { return getArgsBuffer()[0].type; }
@@ -742,6 +748,7 @@
     ID.AddInteger(RegParm);
     ID.AddBoolean(NoCfCheck);
     ID.AddBoolean(CmseNSCall);
+    ID.AddBoolean(AMXPreserve);
     ID.AddInteger(Required.getOpaqueData());
     ID.AddBoolean(HasExtParameterInfos);
     if (HasExtParameterInfos) {
@@ -770,6 +777,7 @@
     ID.AddInteger(info.getRegParm());
     ID.AddBoolean(info.getNoCfCheck());
     ID.AddBoolean(info.getCmseNSCall());
+    ID.AddBoolean(info.getAMXPreserve());
     ID.AddInteger(required.getOpaqueData());
     ID.AddBoolean(!paramInfos.empty());
     if (!paramInfos.empty()) {
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -9632,6 +9632,8 @@
     return {};
   if (lbaseInfo.getNoCallerSavedRegs() != rbaseInfo.getNoCallerSavedRegs())
     return {};
+  if (lbaseInfo.getAMXPreserve() != rbaseInfo.getAMXPreserve())
+    return {};
   if (lbaseInfo.getNoCfCheck() != rbaseInfo.getNoCfCheck())
     return {};
 
diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -617,6 +617,8 @@
     return false;
   if (EI1.getNoCallerSavedRegs() != EI2.getNoCallerSavedRegs())
     return false;
+  if (EI1.getAMXPreserve() != EI2.getAMXPreserve())
+    return false;
   if (EI1.getNoCfCheck() != EI2.getNoCfCheck())
     return false;
 
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -1000,6 +1000,8 @@
        << Info.getRegParm() << ")))";
   if (Info.getNoCallerSavedRegs())
     OS << " __attribute__((no_caller_saved_registers))";
+  if (Info.getAMXPreserve())
+    OS << " __attribute__((amxpreserve))";
   if (Info.getNoCfCheck())
     OS << " __attribute__((nocf_check))";
 }
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -823,6 +823,7 @@
   FI->NoReturn = info.getNoReturn();
   FI->ReturnsRetained = info.getProducesResult();
   FI->NoCallerSavedRegs = info.getNoCallerSavedRegs();
+  FI->AMXPreserve = info.getAMXPreserve();
   FI->NoCfCheck = info.getNoCfCheck();
   FI->Required = required;
   FI->HasRegParm = info.getHasRegParm();
@@ -2116,6 +2117,8 @@
       FuncAttrs.addAttribute(llvm::Attribute::NoCfCheck);
     if (TargetDecl->hasAttr<LeafAttr>())
       FuncAttrs.addAttribute(llvm::Attribute::NoCallback);
+    if (TargetDecl->hasAttr<AMXPreserveAttr>())
+      FuncAttrs.addAttribute(llvm::Attribute::AMXPreserve);
 
     HasOptnone = TargetDecl->hasAttr<OptimizeNoneAttr>();
     if (auto *AllocSize = TargetDecl->getAttr<AllocSizeAttr>()) {
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -3526,6 +3526,24 @@
     RequiresAdjustment = true;
   }
 
+  if (OldTypeInfo.getAMXPreserve() != NewTypeInfo.getAMXPreserve()) {
+    if (NewTypeInfo.getAMXPreserve()) {
+      // The amxpreserve bit may come from the function type (e.g. a typedef)
+      // rather than from an attribute on this declaration, so New may not
+      // carry an AMXPreserveAttr; never stream a null Attr into a diagnostic.
+      if (const auto *Attr = New->getAttr<AMXPreserveAttr>())
+        Diag(New->getLocation(), diag::err_function_attribute_mismatch)
+            << Attr;
+      else
+        Diag(New->getLocation(), diag::err_function_attribute_mismatch)
+            << "'amxpreserve'";
+      Diag(OldLocation, diag::note_previous_declaration);
+      return true;
+    }
+
+    NewTypeInfo = NewTypeInfo.withAMXPreserve(true);
+    RequiresAdjustment = true;
+  }
+
   if (RequiresAdjustment) {
     const FunctionType *AdjustedType = New->getType()->getAs<FunctionType>();
     AdjustedType = Context.adjustFunctionType(AdjustedType, NewTypeInfo);
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -135,6 +135,7 @@
   case ParsedAttr::AT_CmseNSCall:                                              \
   case ParsedAttr::AT_AnyX86NoCallerSavedRegisters:                            \
   case ParsedAttr::AT_AnyX86NoCfCheck:                                         \
+  case ParsedAttr::AT_AMXPreserve:                                             \
   CALLING_CONV_ATTRS_CASELIST
 
 // Microsoft-specific type qualifiers.
@@ -7514,6 +7515,20 @@
     return true;
   }
 
+  if (attr.getKind() == ParsedAttr::AT_AMXPreserve) {
+    if (S.CheckAttrTarget(attr) || S.CheckAttrNoArgs(attr))
+      return true;
+
+    // Delay if this is not a function type.
+    if (!unwrapped.isFunctionType())
+      return false;
+
+    FunctionType::ExtInfo EI =
+        unwrapped.get()->getExtInfo().withAMXPreserve(true);
+    type = unwrapped.wrap(S, S.Context.adjustFunctionType(unwrapped.get(), EI));
+    return true;
+  }
+
   if (attr.getKind() == ParsedAttr::AT_AnyX86NoCallerSavedRegisters) {
     if (S.CheckAttrTarget(attr) || S.CheckAttrNoArgs(attr))
       return true;
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -595,6 +595,7 @@
   Abv->Add(BitCodeAbbrevOp(0));                         // NoCallerSavedRegs
   Abv->Add(BitCodeAbbrevOp(0));                         // NoCfCheck
   Abv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 1)); // CmseNSCall
+  Abv->Add(BitCodeAbbrevOp(0));                         // AMXPreserve
   // FunctionProtoType
   Abv->Add(BitCodeAbbrevOp(0));                         // IsVariadic
   Abv->Add(BitCodeAbbrevOp(0));                         // HasTrailingReturn
diff --git a/clang/test/Sema/attr-target-mv.c b/clang/test/Sema/attr-target-mv.c
--- a/clang/test/Sema/attr-target-mv.c
+++ b/clang/test/Sema/attr-target-mv.c
@@ -76,6 +76,11 @@
 // expected-note@+1 {{function multiversioning caused by this declaration}}
 int __attribute__((target("arch=ivybridge"))) prev_no_target2(void);
 
+void __attribute__((target("sse4.2"))) addtl_amx_attrs(void);
+//expected-error@+2 {{attribute 'target' multiversioning cannot be combined with attribute 'amxpreserve'}}
+void __attribute__((amxpreserve,target("arch=sandybridge")))
+addtl_amx_attrs(void);
+
 void __attribute__((target("sse4.2"))) addtl_attrs(void);
 //expected-error@+2 {{attribute 'target' multiversioning cannot be combined with attribute 'no_caller_saved_registers'}}
 void __attribute__((no_caller_saved_registers,target("arch=sandybridge")))
diff --git a/clang/test/SemaCXX/attr-non-x86-amx-preserve.cpp b/clang/test/SemaCXX/attr-non-x86-amx-preserve.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCXX/attr-non-x86-amx-preserve.cpp
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -std=c++11 -triple armv7-unknown-linux-gnueabi -fsyntax-only -verify %s
+
+struct a {
+  int __attribute__((amxpreserve)) b; // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+  static void foo(int *a) __attribute__((amxpreserve)) {} // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+};
+
+struct a test __attribute__((amxpreserve)); // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+__attribute__((amxpreserve(999))) void bar(int *) {} // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+__attribute__((amxpreserve)) void foo(int *){} // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+[[clang::amxpreserve]] void foo2(int *) {} // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+typedef __attribute__((amxpreserve)) void (*foo3)(int *); // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+typedef void (*foo5)(int *);
+
+int (*foo4)(double a, __attribute__((amxpreserve)) float b); // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+
+int main(int argc, char **argv) {
+  void (*fp)(int *) = foo;
+  a::foo(&argc);
+  foo3 func = foo2;
+  func(&argc);
+  foo5 __attribute__((amxpreserve)) func2 = foo2; // expected-warning {{unknown attribute 'amxpreserve' ignored}}
+  return 0;
+}
diff --git a/clang/test/SemaCXX/attr-x86-amx-preserve.cpp b/clang/test/SemaCXX/attr-x86-amx-preserve.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCXX/attr-x86-amx-preserve.cpp
@@ -0,0 +1,33 @@
+// RUN: %clang_cc1 -std=c++11 -triple x86_64-unknown-linux-gnu -fsyntax-only -verify %s
+
+struct a {
+  int b __attribute__((amxpreserve)); // expected-warning {{'amxpreserve' only applies to function types; type here is 'int'}}
+  static void foo(int *a) __attribute__((amxpreserve)) {}
+};
+
+struct a test __attribute__((amxpreserve)); // expected-warning {{'amxpreserve' only applies to function types; type here is 'struct a'}}
+
+__attribute__((amxpreserve(999))) void bar(int *) {} // expected-error {{'amxpreserve' attribute takes no arguments}}
+
+void __attribute__((amxpreserve)) foo(int *){}
+
+__attribute__((amxpreserve)) void foo2(int *) {}
+
+typedef __attribute__((amxpreserve)) void (*foo3)(int *);
+
+int (*foo4)(double a, __attribute__((amxpreserve)) float b); // expected-warning {{'amxpreserve' only applies to function types; type here is 'float'}}
+
+typedef void (*foo5)(int *);
+
+void foo6(){} // expected-note {{previous declaration is here}}
+
+void __attribute__((amxpreserve)) foo6(); // expected-error {{function declared with 'amxpreserve' attribute was previously declared without the 'amxpreserve' attribute}}
+
+int main(int argc, char **argv) {
+  void (*fp)(int *) = foo; // expected-error {{cannot initialize a variable of type 'void (*)(int *)' with an lvalue of type 'void (int *) __attribute__((amxpreserve))'}}
+  a::foo(&argc);
+  foo3 func = foo2;
+  func(&argc);
+  foo5 __attribute__((amxpreserve)) func2 = foo2;
+  return 0;
+}
diff --git a/clang/unittests/AST/StructuralEquivalenceTest.cpp b/clang/unittests/AST/StructuralEquivalenceTest.cpp
--- a/clang/unittests/AST/StructuralEquivalenceTest.cpp
+++ b/clang/unittests/AST/StructuralEquivalenceTest.cpp
@@ -476,6 +476,16 @@
   EXPECT_FALSE(testStructuralMatch(t));
 }
 
+TEST_F(StructuralEquivalenceFunctionTest,
+       FunctionsWithDifferentAMXPreserveAttr) {
+  if (llvm::Triple(llvm::sys::getDefaultTargetTriple()).getArch() !=
+      llvm::Triple::x86_64)
+    return;
+  auto t = makeNamedDecls("__attribute__((amxpreserve)) void foo();",
+                          " void foo();", Lang_C99);
+  EXPECT_FALSE(testStructuralMatch(t));
+}
+
 struct StructuralEquivalenceCXXMethodTest : StructuralEquivalenceTest {
 };
 