diff --git a/llvm/lib/Target/X86/X86CallingConv.td b/llvm/lib/Target/X86/X86CallingConv.td
--- a/llvm/lib/Target/X86/X86CallingConv.td
+++ b/llvm/lib/Target/X86/X86CallingConv.td
@@ -1209,6 +1209,9 @@
 def CSR_64_Intel_OCL_BI_AVX512 : CalleeSavedRegs<(add RBX, RSI, R14, R15,
                                                   (sequence "ZMM%u", 16, 31),
                                                   K4, K5, K6, K7)>;
 
+// Save all AMX registers. These are additional callee-saved registers that
+// are specified by the function attribute "amxpreserve".
+def CSR_64_AMX_Ext : CalleeSavedRegs<(add (sequence "TMM%u", 0, 7))>;
 // Only R12 is preserved for PHP calls in HHVM.
 def CSR_64_HHVM : CalleeSavedRegs<(add R12)>;
diff --git a/llvm/lib/Target/X86/X86FastISel.cpp b/llvm/lib/Target/X86/X86FastISel.cpp
--- a/llvm/lib/Target/X86/X86FastISel.cpp
+++ b/llvm/lib/Target/X86/X86FastISel.cpp
@@ -3190,6 +3190,10 @@
   if ((CB && CB->hasFnAttr("no_callee_saved_registers")))
     return false;
 
+  // Calls with the "amxpreserve" attribute need special handling.
+  if ((CB && CB->hasFnAttr("amxpreserve")))
+    return false;
+
   // Functions using thunks for indirect calls need to use SDISel.
   if (Subtarget->useIndirectThunkCalls())
     return false;
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -4595,7 +4595,18 @@
     // to use the CSR_NoRegs_RegMask.
     if (CB && CB->hasFnAttr("no_callee_saved_registers"))
      AdaptedCC = (CallingConv::ID)CallingConv::GHC;
-    return RegInfo->getCallPreservedMask(MF, AdaptedCC);
+    const uint32_t *StaticMask = RegInfo->getCallPreservedMask(MF, AdaptedCC);
+    if (!CB || !CB->hasFnAttr("amxpreserve"))
+      return StaticMask;
+    // Allocate a new register mask, copy the static mask, and mark the AMX
+    // registers as preserved across this call as well.
+    uint32_t *DynamicMask = MF.allocateRegMask();
+    const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
+    unsigned RegMaskSize = MachineOperand::getRegMaskSize(TRI->getNumRegs());
+    memcpy(DynamicMask, StaticMask, sizeof(DynamicMask[0]) * RegMaskSize);
+    const uint32_t *AMXPreservedMask = RegInfo->getAMXPreservedMask();
+    for (unsigned I = 0; I < RegMaskSize; ++I)
+      DynamicMask[I] |= AMXPreservedMask[I];
+    return const_cast<const uint32_t *>(DynamicMask);
   }();
 
   assert(Mask && "Missing call preserved mask for calling convention");
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.h b/llvm/lib/Target/X86/X86RegisterInfo.h
--- a/llvm/lib/Target/X86/X86RegisterInfo.h
+++ b/llvm/lib/Target/X86/X86RegisterInfo.h
@@ -48,6 +48,12 @@
   /// variable size stack objects.
   unsigned BasePtr;
 
+  /// Map each static callee-saved register list to the same list extended
+  /// with the AMX registers. The extended lists need storage, so they are
+  /// cached here in X86RegisterInfo.
+  mutable DenseMap<const MCPhysReg *, std::unique_ptr<MCPhysReg[]>>
+      CalleeSavedRegsExt;
+
 public:
   explicit X86RegisterInfo(const Triple &TT);
 
@@ -95,6 +101,9 @@
   unsigned getRegPressureLimit(const TargetRegisterClass *RC,
                                MachineFunction &MF) const override;
 
+  std::pair<const MCPhysReg *, int>
+  getStaticCalleeSavedRegs(const MachineFunction *MF) const;
+
   /// getCalleeSavedRegs - Return a null-terminated list of all of the
   /// callee-save registers on this target.
   const MCPhysReg *
@@ -104,6 +113,7 @@
   const uint32_t *getCallPreservedMask(const MachineFunction &MF,
                                        CallingConv::ID) const override;
   const uint32_t *getNoPreservedMask() const override;
+  const uint32_t *getAMXPreservedMask() const;
 
   // Calls involved in thread-local variable lookup save more registers than
   // normal calls, so they need a different mask to represent this.
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp
--- a/llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -272,8 +272,8 @@
   }
 }
 
-const MCPhysReg *
-X86RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const {
+std::pair<const MCPhysReg *, int>
+X86RegisterInfo::getStaticCalleeSavedRegs(const MachineFunction *MF) const {
   assert(MF && "MachineFunction required");
 
   const X86Subtarget &Subtarget = MF->getSubtarget<X86Subtarget>();
@@ -285,6 +285,8 @@
 
   CallingConv::ID CC = F.getCallingConv();
 
+#define CSR_ARRAY(list)                                                        \
+  std::make_pair(list, int(sizeof(list) / sizeof(list[0])))
+
   // If attribute NoCallerSavedRegisters exists then we set X86_INTR calling
   // convention because it has the CSR list.
   if (MF->getFunction().hasFnAttribute("no_caller_saved_registers"))
@@ -293,92 +295,94 @@
   // If atribute specified, override the CSRs normally specified by the
   // calling convention and use the empty set instead.
   if (MF->getFunction().hasFnAttribute("no_callee_saved_registers"))
-    return CSR_NoRegs_SaveList;
+    return CSR_ARRAY(CSR_NoRegs_SaveList);
 
   switch (CC) {
   case CallingConv::GHC:
   case CallingConv::HiPE:
-    return CSR_NoRegs_SaveList;
+    return CSR_ARRAY(CSR_NoRegs_SaveList);
   case CallingConv::AnyReg:
     if (HasAVX)
-      return CSR_64_AllRegs_AVX_SaveList;
-    return CSR_64_AllRegs_SaveList;
+      return CSR_ARRAY(CSR_64_AllRegs_AVX_SaveList);
+    return CSR_ARRAY(CSR_64_AllRegs_SaveList);
   case CallingConv::PreserveMost:
-    return CSR_64_RT_MostRegs_SaveList;
+    return CSR_ARRAY(CSR_64_RT_MostRegs_SaveList);
   case CallingConv::PreserveAll:
     if (HasAVX)
-      return CSR_64_RT_AllRegs_AVX_SaveList;
-    return CSR_64_RT_AllRegs_SaveList;
+      return CSR_ARRAY(CSR_64_RT_AllRegs_AVX_SaveList);
+    return CSR_ARRAY(CSR_64_RT_AllRegs_SaveList);
   case CallingConv::CXX_FAST_TLS:
     if (Is64Bit)
-      return MF->getInfo<X86MachineFunctionInfo>()->isSplitCSR() ?
-          CSR_64_CXX_TLS_Darwin_PE_SaveList : CSR_64_TLS_Darwin_SaveList;
+      return MF->getInfo<X86MachineFunctionInfo>()->isSplitCSR()
+                 ? CSR_ARRAY(CSR_64_CXX_TLS_Darwin_PE_SaveList)
+                 : CSR_ARRAY(CSR_64_TLS_Darwin_SaveList);
     break;
   case CallingConv::Intel_OCL_BI: {
     if (HasAVX512 && IsWin64)
-      return CSR_Win64_Intel_OCL_BI_AVX512_SaveList;
+      return CSR_ARRAY(CSR_Win64_Intel_OCL_BI_AVX512_SaveList);
     if (HasAVX512 && Is64Bit)
-      return CSR_64_Intel_OCL_BI_AVX512_SaveList;
+      return CSR_ARRAY(CSR_64_Intel_OCL_BI_AVX512_SaveList);
     if (HasAVX && IsWin64)
-      return CSR_Win64_Intel_OCL_BI_AVX_SaveList;
+      return CSR_ARRAY(CSR_Win64_Intel_OCL_BI_AVX_SaveList);
     if (HasAVX && Is64Bit)
-      return CSR_64_Intel_OCL_BI_AVX_SaveList;
+      return CSR_ARRAY(CSR_64_Intel_OCL_BI_AVX_SaveList);
     if (!HasAVX && !IsWin64 && Is64Bit)
-      return CSR_64_Intel_OCL_BI_SaveList;
+      return CSR_ARRAY(CSR_64_Intel_OCL_BI_SaveList);
     break;
   }
   case CallingConv::HHVM:
-    return CSR_64_HHVM_SaveList;
+    return CSR_ARRAY(CSR_64_HHVM_SaveList);
   case CallingConv::X86_RegCall:
     if (Is64Bit) {
       if (IsWin64) {
-        return (HasSSE ? CSR_Win64_RegCall_SaveList :
-                CSR_Win64_RegCall_NoSSE_SaveList);
+        return (HasSSE ? CSR_ARRAY(CSR_Win64_RegCall_SaveList)
+                       : CSR_ARRAY(CSR_Win64_RegCall_NoSSE_SaveList));
       } else {
-        return (HasSSE ? CSR_SysV64_RegCall_SaveList :
-                CSR_SysV64_RegCall_NoSSE_SaveList);
+        return (HasSSE ? CSR_ARRAY(CSR_SysV64_RegCall_SaveList)
+                       : CSR_ARRAY(CSR_SysV64_RegCall_NoSSE_SaveList));
      }
     } else {
-      return (HasSSE ? CSR_32_RegCall_SaveList :
-              CSR_32_RegCall_NoSSE_SaveList);
+      return (HasSSE ? CSR_ARRAY(CSR_32_RegCall_SaveList)
+                     : CSR_ARRAY(CSR_32_RegCall_NoSSE_SaveList));
     }
   case CallingConv::CFGuard_Check:
     assert(!Is64Bit && "CFGuard check mechanism only used on 32-bit X86");
-    return (HasSSE ? CSR_Win32_CFGuard_Check_SaveList
-                   : CSR_Win32_CFGuard_Check_NoSSE_SaveList);
+    return (HasSSE ? CSR_ARRAY(CSR_Win32_CFGuard_Check_SaveList)
+                   : CSR_ARRAY(CSR_Win32_CFGuard_Check_NoSSE_SaveList));
   case CallingConv::Cold:
     if (Is64Bit)
-      return CSR_64_MostRegs_SaveList;
+      return CSR_ARRAY(CSR_64_MostRegs_SaveList);
     break;
   case CallingConv::Win64:
     if (!HasSSE)
-      return CSR_Win64_NoSSE_SaveList;
-    return CSR_Win64_SaveList;
+      return CSR_ARRAY(CSR_Win64_NoSSE_SaveList);
+    return CSR_ARRAY(CSR_Win64_SaveList);
   case CallingConv::SwiftTail:
     if (!Is64Bit)
-      return CSR_32_SaveList;
-    return IsWin64 ? CSR_Win64_SwiftTail_SaveList : CSR_64_SwiftTail_SaveList;
+      return CSR_ARRAY(CSR_32_SaveList);
+    return IsWin64 ? CSR_ARRAY(CSR_Win64_SwiftTail_SaveList)
+                   : CSR_ARRAY(CSR_64_SwiftTail_SaveList);
   case CallingConv::X86_64_SysV:
     if (CallsEHReturn)
-      return CSR_64EHRet_SaveList;
-    return CSR_64_SaveList;
+      return CSR_ARRAY(CSR_64EHRet_SaveList);
+    return CSR_ARRAY(CSR_64_SaveList);
   case CallingConv::X86_INTR:
     if (Is64Bit) {
       if (HasAVX512)
-        return CSR_64_AllRegs_AVX512_SaveList;
+        return CSR_ARRAY(CSR_64_AllRegs_AVX512_SaveList);
       if (HasAVX)
-        return CSR_64_AllRegs_AVX_SaveList;
+        return CSR_ARRAY(CSR_64_AllRegs_AVX_SaveList);
       if (HasSSE)
-        return CSR_64_AllRegs_SaveList;
-      return CSR_64_AllRegs_NoSSE_SaveList;
+        return CSR_ARRAY(CSR_64_AllRegs_SaveList);
+      return CSR_ARRAY(CSR_64_AllRegs_NoSSE_SaveList);
     } else {
       if (HasAVX512)
-        return CSR_32_AllRegs_AVX512_SaveList;
+        return CSR_ARRAY(CSR_32_AllRegs_AVX512_SaveList);
       if (HasAVX)
-        return CSR_32_AllRegs_AVX_SaveList;
+        return CSR_ARRAY(CSR_32_AllRegs_AVX_SaveList);
       if (HasSSE)
-        return CSR_32_AllRegs_SSE_SaveList;
-      return CSR_32_AllRegs_SaveList;
+        return CSR_ARRAY(CSR_32_AllRegs_SSE_SaveList);
+      return CSR_ARRAY(CSR_32_AllRegs_SaveList);
     }
   default:
     break;
   }
@@ -388,17 +392,40 @@
     bool IsSwiftCC = Subtarget.getTargetLowering()->supportSwiftError() &&
                      F.getAttributes().hasAttrSomewhere(Attribute::SwiftError);
     if (IsSwiftCC)
-      return IsWin64 ? CSR_Win64_SwiftError_SaveList
-                     : CSR_64_SwiftError_SaveList;
+      return IsWin64 ? CSR_ARRAY(CSR_Win64_SwiftError_SaveList)
+                     : CSR_ARRAY(CSR_64_SwiftError_SaveList);
 
     if (IsWin64)
-      return HasSSE ? CSR_Win64_SaveList : CSR_Win64_NoSSE_SaveList;
+      return HasSSE ? CSR_ARRAY(CSR_Win64_SaveList)
+                    : CSR_ARRAY(CSR_Win64_NoSSE_SaveList);
 
     if (CallsEHReturn)
-      return CSR_64EHRet_SaveList;
-    return CSR_64_SaveList;
+      return CSR_ARRAY(CSR_64EHRet_SaveList);
+    return CSR_ARRAY(CSR_64_SaveList);
   }
 
-  return CallsEHReturn ? CSR_32EHRet_SaveList : CSR_32_SaveList;
+  return CallsEHReturn ? CSR_ARRAY(CSR_32EHRet_SaveList)
+                       : CSR_ARRAY(CSR_32_SaveList);
+#undef CSR_ARRAY
+}
+
+const MCPhysReg *
+X86RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const {
+  const MCPhysReg *StaticCSR;
+  int Size;
+  std::tie(StaticCSR, Size) = getStaticCalleeSavedRegs(MF);
+
+  if (CalleeSavedRegsExt.count(StaticCSR))
+    return CalleeSavedRegsExt[StaticCSR].get();
+
+  // Both lists end with a 0 terminator; drop the terminator of the static
+  // list so that the combined list keeps a single trailing 0.
+  int ExtSize =
+      sizeof(CSR_64_AMX_Ext_SaveList) / sizeof(CSR_64_AMX_Ext_SaveList[0]);
+  MCPhysReg *CSR = new MCPhysReg[Size + ExtSize - 1];
+  std::copy(StaticCSR, StaticCSR + Size - 1, CSR);
+  std::copy(CSR_64_AMX_Ext_SaveList, CSR_64_AMX_Ext_SaveList + ExtSize,
+            CSR + Size - 1);
+  CalleeSavedRegsExt[StaticCSR].reset(CSR);
+
+  return CSR;
 }
 
 const MCPhysReg *X86RegisterInfo::getCalleeSavedRegsViaCopy(
@@ -522,6 +549,10 @@
   return CSR_NoRegs_RegMask;
 }
 
+const uint32_t *X86RegisterInfo::getAMXPreservedMask() const {
+  return CSR_64_AMX_Ext_RegMask;
+}
+
 const uint32_t *X86RegisterInfo::getDarwinTLSCallPreservedMask() const {
   return CSR_64_TLS_Darwin_RegMask;
 }
diff --git a/llvm/test/CodeGen/X86/AMX/amx-preserve-cc.ll b/llvm/test/CodeGen/X86/AMX/amx-preserve-cc.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/X86/AMX/amx-preserve-cc.ll
@@ -0,0 +1,70 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-int8 -mattr=+avx512f -verify-machineinstrs | FileCheck %s
+
+@buf = dso_local global [3072 x i8] zeroinitializer, align 64
+
+define void @foo() "amxpreserve" {
+; CHECK-LABEL: foo:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    retq
+entry:
+  ret void
+}
+
+declare void @external()
+
+define dso_local void @test_api(i16 signext %0, i16 signext %1) nounwind {
+; CHECK-LABEL: test_api:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rbp
+; CHECK-NEXT:    pushq %r15
+; CHECK-NEXT:    pushq %r14
+; CHECK-NEXT:    pushq %r12
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    subq $64, %rsp
+; CHECK-NEXT:    movl %esi, %ebx
+; CHECK-NEXT:    movl %edi, %ebp
+; CHECK-NEXT:    vpxord %zmm0, %zmm0, %zmm0
+; CHECK-NEXT:    vmovdqu64 %zmm0, (%rsp)
+; CHECK-NEXT:    movb $1, (%rsp)
+; CHECK-NEXT:    movw $8, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb $8, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %bx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %bpl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %bx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %bpl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    ldtilecfg (%rsp)
+; CHECK-NEXT:    movl $buf, %eax
+; CHECK-NEXT:    movl $32, %r14d
+; CHECK-NEXT:    movw $8, %r15w
+; CHECK-NEXT:    tileloadd (%rax,%r14), %tmm0
+; CHECK-NEXT:    movl $buf+1024, %eax
+; CHECK-NEXT:    tileloadd (%rax,%r14), %tmm1
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    callq foo@PLT
+; CHECK-NEXT:    movl $buf+2048, %r12d
+; CHECK-NEXT:    tileloadd (%r12,%r14), %tmm2
+; CHECK-NEXT:    callq external@PLT
+; CHECK-NEXT:    tdpbssd %tmm1, %tmm0, %tmm2
+; CHECK-NEXT:    tilestored %tmm2, (%r12,%r14)
+; CHECK-NEXT:    addq $64, %rsp
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    popq %r12
+; CHECK-NEXT:    popq %r14
+; CHECK-NEXT:    popq %r15
+; CHECK-NEXT:    popq %rbp
+; CHECK-NEXT:    tilerelease
+; CHECK-NEXT:    retq
+  %3 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %0, i16 8, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 0), i64 32)
+  %4 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 1024), i64 32)
+  ; call void @foo() "amxpreserve"
+  call void @foo()
+  %5 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32)
+  call void @external() "amxpreserve"
+  %6 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %0, i16 %1, i16 8, x86_amx %5, x86_amx %3, x86_amx %4)
+  tail call void @llvm.x86.tilestored64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32, x86_amx %6)
+  ret void
+}
+
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
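
For reference, a minimal IR sketch of how the new attribute is expected to be used. The attribute spelling "amxpreserve" and the intrinsics follow the test above; @data, @helper, and @use_tile are illustrative names only, not part of the patch:

; A call site (or callee) carrying "amxpreserve" gets CSR_64_AMX_Ext merged
; into its call-preserved mask, so values kept in TMM registers are not
; treated as clobbered by the call.
@data = global [1024 x i8] zeroinitializer, align 64

declare void @helper() "amxpreserve"

define void @use_tile() {
  %t = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @data, i64 0, i64 0), i64 32)
  ; The tile in %t is expected to stay live in a TMM register across this call.
  call void @helper()
  call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @data, i64 0, i64 0), i64 32, x86_amx %t)
  ret void
}

declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)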