diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -73,6 +73,15 @@
                        "use for creating a floating-point immediate value"),
               cl::init(2));
 
+// Intentionally not static: RISCVRegisterInfo.cpp refers to this option
+// through an extern declaration.
+cl::opt<bool>
+    DeduceVectorCC(DEBUG_TYPE "-deduce-vector-cc", cl::Hidden,
+                   cl::desc("Automatically turn on vector calling convention "
+                            "for every function that has RVV argument/return "
+                            "type."),
+                   cl::init(false));
+
 RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                                          const RISCVSubtarget &STI)
     : TargetLowering(TM), Subtarget(STI) {
@@ -15713,7 +15722,7 @@
   SDValue Chain = CLI.Chain;
   SDValue Callee = CLI.Callee;
   bool &IsTailCall = CLI.IsTailCall;
-  CallingConv::ID CallConv = CLI.CallConv;
+  CallingConv::ID &CallConv = CLI.CallConv;
   bool IsVarArg = CLI.IsVarArg;
   EVT PtrVT = getPointerTy(DAG.getDataLayout());
   MVT XLenVT = Subtarget.getXLenVT();
@@ -15731,6 +15740,31 @@
                       CallConv == CallingConv::Fast ? RISCV::CC_RISCV_FastCC
                                                     : RISCV::CC_RISCV);
 
+  // Assign locations to each value returned by this call. This is done here,
+  // rather than after CALLSEQ_END, so the deduction below can inspect them.
+  SmallVector<CCValAssign, 16> RVLocs;
+  CCState RetCCInfo(CallConv, IsVarArg, MF, RVLocs, *DAG.getContext());
+  analyzeInputArgs(MF, RetCCInfo, Ins, /*IsRet=*/true, RISCV::CC_RISCV);
+
+  // Check callee args/returns for RVV registers and set calling convention
+  // accordingly. CallConv aliases CLI.CallConv, so the rest of call lowering
+  // observes the deduced convention.
+  if (DeduceVectorCC &&
+      (CallConv == CallingConv::C || CallConv == CallingConv::Fast)) {
+    auto HasRVVRegLoc = [](const CCValAssign &Loc) {
+      if (!Loc.isRegLoc())
+        return false;
+
+      const auto RegClasses = {&RISCV::VRRegClass, &RISCV::VRM2RegClass,
+                               &RISCV::VRM4RegClass, &RISCV::VRM8RegClass};
+      return any_of(RegClasses, [&](const auto *RC) {
+        return RC->contains(Loc.getLocReg());
+      });
+    };
+    if (any_of(RVLocs, HasRVVRegLoc) || any_of(ArgLocs, HasRVVRegLoc))
+      CallConv = CallingConv::RISCV_VectorCall;
+  }
+
   // Check if it's really possible to do a tail call.
   if (IsTailCall)
     IsTailCall = isEligibleForTailCallOptimization(ArgCCInfo, CLI, MF, ArgLocs);
@@ -15977,11 +16011,6 @@
   Chain = DAG.getCALLSEQ_END(Chain, NumBytes, 0, Glue, DL);
   Glue = Chain.getValue(1);
 
-  // Assign locations to each value returned by this call.
-  SmallVector<CCValAssign, 16> RVLocs;
-  CCState RetCCInfo(CallConv, IsVarArg, MF, RVLocs, *DAG.getContext());
-  analyzeInputArgs(MF, RetCCInfo, Ins, /*IsRet=*/true, RISCV::CC_RISCV);
-
   // Copy all of the result registers out of their specified physreg.
   for (auto &VA : RVLocs) {
     // Copy the value out
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
@@ -29,6 +29,9 @@
 
 using namespace llvm;
 
+// Defined in RISCVISelLowering.cpp.
+extern cl::opt<bool> DeduceVectorCC;
+
 static cl::opt<bool>
     DisableRegAllocHints("riscv-disable-regalloc-hints", cl::Hidden,
                          cl::init(false),
@@ -67,7 +70,9 @@
   }
 
   bool HasVectorCSR =
-      MF->getFunction().getCallingConv() == CallingConv::RISCV_VectorCall;
+      MF->getFunction().getCallingConv() == CallingConv::RISCV_VectorCall ||
+      (DeduceVectorCC &&
+       MF->getInfo<RISCVMachineFunctionInfo>()->isVectorCall());
 
   switch (Subtarget.getTargetABI()) {
   default: