diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -865,7 +865,7 @@
   if (MF->getFunction().getCallingConv() == CallingConv::AArch64_VectorCall ||
       MF->getFunction().getCallingConv() ==
           CallingConv::AArch64_SVE_VectorCall ||
-      STI->getRegisterInfo()->hasSVEArgsOrReturn(MF)) {
+      MF->getInfo<AArch64FunctionInfo>()->isSVECC()) {
     auto *TS =
         static_cast<AArch64TargetStreamer *>(OutStreamer->getTargetStreamer());
     TS->emitDirectiveVariantPCS(CurrentFnSym);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5434,8 +5434,16 @@
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
   MachineFunction &MF = DAG.getMachineFunction();
+  const Function &F = MF.getFunction();
   MachineFrameInfo &MFI = MF.getFrameInfo();
-  bool IsWin64 = Subtarget->isCallingConvWin64(MF.getFunction().getCallingConv());
+  bool IsWin64 = Subtarget->isCallingConvWin64(F.getCallingConv());
+  AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+
+  SmallVector<ISD::OutputArg, 4> Outs;
+  GetReturnInfo(CallConv, F.getReturnType(), F.getAttributes(), Outs,
+                DAG.getTargetLoweringInfo(), MF.getDataLayout());
+  if (any_of(Outs, [](ISD::OutputArg &Out) { return Out.VT.isScalableVector(); }))
+    FuncInfo->setIsSVECC(true);

   // Assign locations to all of the incoming arguments.
   SmallVector<CCValAssign, 16> ArgLocs;
@@ -5449,7 +5457,7 @@
     // we use a special version of AnalyzeFormalArguments to pass in ValVT and
     // LocVT.
     unsigned NumArgs = Ins.size();
-    Function::const_arg_iterator CurOrigArg = MF.getFunction().arg_begin();
+    Function::const_arg_iterator CurOrigArg = F.arg_begin();
     unsigned CurArgIdx = 0;
     for (unsigned i = 0; i != NumArgs; ++i) {
       MVT ValVT = Ins[i].VT;
@@ -5520,11 +5528,13 @@
       else if (RegVT == MVT::f128 || RegVT.is128BitVector())
         RC = &AArch64::FPR128RegClass;
       else if (RegVT.isScalableVector() &&
-               RegVT.getVectorElementType() == MVT::i1)
+               RegVT.getVectorElementType() == MVT::i1) {
+        FuncInfo->setIsSVECC(true);
         RC = &AArch64::PPRRegClass;
-      else if (RegVT.isScalableVector())
+      } else if (RegVT.isScalableVector()) {
+        FuncInfo->setIsSVECC(true);
         RC = &AArch64::ZPRRegClass;
-      else
+      } else
         llvm_unreachable("RegVT not supported by FORMAL_ARGUMENTS Lowering");

       // Transform the arguments in physical registers into virtual ones.
@@ -5646,7 +5656,7 @@
       // i1 arguments are zero-extended to i8 by the caller. Emit a
       // hint to reflect this.
       if (Ins[i].isOrigArg()) {
-        Argument *OrigArg = MF.getFunction().getArg(Ins[i].getOrigArgIndex());
+        Argument *OrigArg = F.getArg(Ins[i].getOrigArgIndex());
         if (OrigArg->getType()->isIntegerTy(1)) {
           if (!Ins[i].Flags.isZExt()) {
             ArgValue = DAG.getNode(AArch64ISD::ASSERT_ZEXT_BOOL, DL,
@@ -5661,7 +5671,6 @@
   assert((ArgLocs.size() + ExtraArgLocs) == Ins.size());

   // varargs
-  AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
   if (isVarArg) {
     if (!Subtarget->isTargetDarwin() || IsWin64) {
       // The AAPCS variadic function ABI is identical to the non-variadic
@@ -5974,7 +5983,7 @@
   // The check for matching callee-saved regs will determine whether it is
   // eligible for TCO.
   if ((CallerCC == CallingConv::C || CallerCC == CallingConv::Fast) &&
-      AArch64RegisterInfo::hasSVEArgsOrReturn(&MF))
+      MF.getInfo<AArch64FunctionInfo>()->isSVECC())
     CallerCC = CallingConv::AArch64_SVE_VectorCall;

   bool CCMatch = CallerCC == CalleeCC;
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -177,6 +177,10 @@
   bool IsMTETagged = false;

+  /// The function has a Scalable Vector or Scalable Predicate register
+  /// argument or return type.
+  bool IsSVECC = false;
+
   /// True if the function need unwind information.
   mutable Optional<bool> NeedsDwarfUnwindInfo;

@@ -191,6 +195,9 @@
                 const DenseMap<MachineBasicBlock *, MachineBasicBlock *>
                     &Src2DstMBB) const override;

+  bool isSVECC() const { return IsSVECC; };
+  void setIsSVECC(bool s) { IsSVECC = s; };
+
   void initializeBaseYamlFields(const yaml::AArch64FunctionInfo &YamlMFI);

   unsigned getBytesInStackArgArea() const { return BytesInStackArgArea; }
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
@@ -42,8 +42,6 @@
   void UpdateCustomCallPreservedMask(MachineFunction &MF,
                                      const uint32_t **Mask) const;

-  static bool hasSVEArgsOrReturn(const MachineFunction *MF);
-
   /// Code Generation virtual methods...
   const MCPhysReg *getCalleeSavedRegs(const MachineFunction *MF) const override;
   const MCPhysReg *getDarwinCalleeSavedRegs(const MachineFunction *MF) const;
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -66,14 +66,6 @@
   return true;
 }

-bool AArch64RegisterInfo::hasSVEArgsOrReturn(const MachineFunction *MF) {
-  const Function &F = MF->getFunction();
-  return isa<ScalableVectorType>(F.getReturnType()) ||
-         any_of(F.args(), [](const Argument &Arg) {
-           return isa<ScalableVectorType>(Arg.getType());
-         });
-}
-
 const MCPhysReg *
 AArch64RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const {
   assert(MF && "Invalid MachineFunction pointer.");
@@ -111,7 +103,7 @@
     // This is for OSes other than Windows; Windows is a separate case further
     // above.
     return CSR_AArch64_AAPCS_X18_SaveList;
-  if (hasSVEArgsOrReturn(MF))
+  if (MF->getInfo<AArch64FunctionInfo>()->isSVECC())
     return CSR_AArch64_SVE_AAPCS_SaveList;
   return CSR_AArch64_AAPCS_SaveList;
 }
diff --git a/llvm/test/CodeGen/AArch64/sve-calling-convention-mixed.ll b/llvm/test/CodeGen/AArch64/sve-calling-convention-mixed.ll
--- a/llvm/test/CodeGen/AArch64/sve-calling-convention-mixed.ll
+++ b/llvm/test/CodeGen/AArch64/sve-calling-convention-mixed.ll
@@ -188,6 +188,193 @@
   ret double %x0
 }

+; Use AAVPCS, SVE register in z0-z7 used
+
+define void @aavpcs1(i32 %s0, i32 %s1, i32 %s2, i32 %s3, i32 %s4, i32 %s5, i32 %s6, <vscale x 4 x i32> %s7, <vscale x 4 x i32> %s8, <vscale x 4 x i32> %s9, <vscale x 4 x i32> %s10, <vscale x 4 x i32> %s11, <vscale x 4 x i32> %s12, <vscale x 4 x i32> %s13, <vscale x 4 x i32> %s14, <vscale x 4 x i32> %s15, <vscale x 4 x i32> %s16, i32 * %ptr) nounwind {
+; CHECK-LABEL: aavpcs1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp x8, x9, [sp]
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z3.s }, p0/z, [x8]
+; CHECK-NEXT:    ld1w { z24.s }, p0/z, [x7]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z1.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z2.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z4.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z5.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z6.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z7.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z24.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z3.s }, p0, [x9]
+; CHECK-NEXT:    ret
+entry:
+  %ptr1.bc = bitcast i32 * %ptr to <vscale x 4 x i32> *
+  store volatile <vscale x 4 x i32> %s7, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s8, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s9, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s11, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s12, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s13, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s14, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s15, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s16, <vscale x 4 x i32>* %ptr1.bc
+  ret void
+}
+
+; Use AAVPCS, SVE register in z0-z7 used
+
+define void @aavpcs2(float %s0, float %s1, float %s2, float %s3, float %s4, float %s5, float %s6, <vscale x 4 x float> %s7, <vscale x 4 x float> %s8, <vscale x 4 x float> %s9, <vscale x 4 x float> %s10, <vscale x 4 x float> %s11, <vscale x 4 x float> %s12, <vscale x 4 x float> %s13, <vscale x 4 x float> %s14, <vscale x 4 x float> %s15, <vscale x 4 x float> %s16, float * %ptr) nounwind {
+; CHECK-LABEL: aavpcs2:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp x8, x9, [sp]
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x8]
+; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x7]
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x6]
+; CHECK-NEXT:    ld1w { z3.s }, p0/z, [x5]
+; CHECK-NEXT:    ld1w { z4.s }, p0/z, [x4]
+; CHECK-NEXT:    ld1w { z5.s }, p0/z, [x3]
+; CHECK-NEXT:    ld1w { z6.s }, p0/z, [x1]
+; CHECK-NEXT:    ld1w { z24.s }, p0/z, [x0]
+; CHECK-NEXT:    st1w { z7.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z24.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z6.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z5.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z4.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z3.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z2.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z1.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x9]
+; CHECK-NEXT:    ret
+entry:
+  %ptr1.bc = bitcast float * %ptr to <vscale x 4 x float> *
+  store volatile <vscale x 4 x float> %s7, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s8, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s9, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s11, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s12, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s13, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s14, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s15, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s16, <vscale x 4 x float>* %ptr1.bc
+  ret void
+}
+
+; Use AAVPCS, no SVE register in z0-z7 used (floats occupy z0-z7) but predicate arg is used
+
+define void @aavpcs3(float %s0, float %s1, float %s2, float %s3, float %s4, float %s5, float %s6, float %s7, <vscale x 4 x float> %s8, <vscale x 4 x float> %s9, <vscale x 4 x float> %s10, <vscale x 4 x float> %s11, <vscale x 4 x float> %s12, <vscale x 4 x float> %s13, <vscale x 4 x float> %s14, <vscale x 4 x float> %s15, <vscale x 4 x float> %s16, <vscale x 4 x float> %s17, <vscale x 4 x i1> %p0, float * %ptr) nounwind {
+; CHECK-LABEL: aavpcs3:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr x8, [sp]
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x8]
+; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x7]
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x6]
+; CHECK-NEXT:    ld1w { z3.s }, p0/z, [x5]
+; CHECK-NEXT:    ld1w { z4.s }, p0/z, [x4]
+; CHECK-NEXT:    ld1w { z5.s }, p0/z, [x3]
+; CHECK-NEXT:    ld1w { z6.s }, p0/z, [x2]
+; CHECK-NEXT:    ld1w { z7.s }, p0/z, [x1]
+; CHECK-NEXT:    ld1w { z24.s }, p0/z, [x0]
+; CHECK-NEXT:    ldr x8, [sp, #16]
+; CHECK-NEXT:    st1w { z24.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z7.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z6.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z5.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z4.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z3.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z2.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z1.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x8]
+; CHECK-NEXT:    ret
+entry:
+  %ptr1.bc = bitcast float * %ptr to <vscale x 4 x float> *
+  store volatile <vscale x 4 x float> %s8, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s9, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s10, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s11, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s12, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s13, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s14, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s15, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s16, <vscale x 4 x float>* %ptr1.bc
+  ret void
+}
+
+; Use AAVPCS, SVE register in z0-z7 used (i32s don't occupy z0-z7)
+
+define void @aavpcs4(i32 %s0, i32 %s1, i32 %s2, i32 %s3, i32 %s4, i32 %s5, i32 %s6, i32 %s7, <vscale x 4 x i32> %s8, <vscale x 4 x i32> %s9, <vscale x 4 x i32> %s10, <vscale x 4 x i32> %s11, <vscale x 4 x i32> %s12, <vscale x 4 x i32> %s13, <vscale x 4 x i32> %s14, <vscale x 4 x i32> %s15, <vscale x 4 x i32> %s16, <vscale x 4 x i32> %s17, i32 * %ptr) nounwind {
+; CHECK-LABEL: aavpcs4:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr x8, [sp]
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ldr x9, [sp, #16]
+; CHECK-NEXT:    ld1w { z24.s }, p0/z, [x8]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z1.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z2.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z3.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z4.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z5.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z6.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z7.s }, p0, [x9]
+; CHECK-NEXT:    st1w { z24.s }, p0, [x9]
+; CHECK-NEXT:    ret
+entry:
+  %ptr1.bc = bitcast i32 * %ptr to <vscale x 4 x i32> *
+  store volatile <vscale x 4 x i32> %s8, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s9, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s10, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s11, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s12, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s13, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s14, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s15, <vscale x 4 x i32>* %ptr1.bc
+  store volatile <vscale x 4 x i32> %s16, <vscale x 4 x i32>* %ptr1.bc
+  ret void
+}
+
+; Use AAPCS, no SVE register in z0-z7 used (floats occupy z0-z7)
+
+define void @aapcs1(float %s0, float %s1, float %s2, float %s3, float %s4, float %s5, float %s6, float %s7, <vscale x 4 x float> %s8, <vscale x 4 x float> %s9, <vscale x 4 x float> %s10, <vscale x 4 x float> %s11, <vscale x 4 x float> %s12, <vscale x 4 x float> %s13, <vscale x 4 x float> %s14, <vscale x 4 x float> %s15, <vscale x 4 x float> %s16, <vscale x 4 x float> %s17, float * %ptr) nounwind {
+; CHECK-LABEL: aapcs1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr x8, [sp]
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x8]
+; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x7]
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x6]
+; CHECK-NEXT:    ld1w { z3.s }, p0/z, [x5]
+; CHECK-NEXT:    ld1w { z4.s }, p0/z, [x4]
+; CHECK-NEXT:    ld1w { z5.s }, p0/z, [x3]
+; CHECK-NEXT:    ld1w { z6.s }, p0/z, [x2]
+; CHECK-NEXT:    ld1w { z7.s }, p0/z, [x1]
+; CHECK-NEXT:    ld1w { z16.s }, p0/z, [x0]
+; CHECK-NEXT:    ldr x8, [sp, #16]
+; CHECK-NEXT:    st1w { z16.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z7.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z6.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z5.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z4.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z3.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z2.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z1.s }, p0, [x8]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x8]
+; CHECK-NEXT:    ret
+entry:
+  %ptr1.bc = bitcast float * %ptr to <vscale x 4 x float> *
+  store volatile <vscale x 4 x float> %s8, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s9, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s10, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s11, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s12, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s13, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s14, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s15, <vscale x 4 x float>* %ptr1.bc
+  store volatile <vscale x 4 x float> %s16, <vscale x 4 x float>* %ptr1.bc
+  ret void
+}
+
 declare float @callee1(float, <vscale x 8 x double>, <vscale x 8 x double>, <vscale x 2 x double>)
 declare float @callee2(i32, i32, i32, i32, i32, i32, i32, i32, float, <vscale x 8 x double>, <vscale x 8 x double>)
 declare float @callee3(float, float, <vscale x 8 x double>, <vscale x 6 x double>, <vscale x 2 x double>)
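
Not part of the patch, but a minimal IR sketch of the other case the new GetReturnInfo() check in LowerFormalArguments is there to cover: a function with no scalable arguments whose return type is scalable. The function name and body below are illustrative only and are not taken from the test file.

; Hypothetical example (not in sve-calling-convention-mixed.ll): only the
; return type is scalable. With this change, LowerFormalArguments sets
; IsSVECC from the GetReturnInfo() scan, so such a function still selects
; CSR_AArch64_SVE_AAPCS_SaveList and is marked with .variant_pcs, matching
; what hasSVEArgsOrReturn previously derived from the IR signature.
define <vscale x 4 x i32> @sve_ret_only(i32* %p) nounwind {
entry:
  %bc = bitcast i32* %p to <vscale x 4 x i32>*
  %v = load volatile <vscale x 4 x i32>, <vscale x 4 x i32>* %bc
  ret <vscale x 4 x i32> %v
}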