diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h --- a/clang/include/clang/AST/Type.h +++ b/clang/include/clang/AST/Type.h @@ -2332,6 +2332,9 @@ bool isRVVType(unsigned Bitwidth, bool IsFloat) const; + bool isRVVTupleType() const; + bool isRVVTupleType(unsigned NumGroups) const; + /// Return the implicit lifetime for this type, which must not be dependent. Qualifiers::ObjCLifetime getObjCARCImplicitLifetime() const; @@ -7279,6 +7282,25 @@ return Ret; } +inline bool Type::isRVVTupleType() const { +#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \ + IsFP) \ + (isSpecificBuiltinType(BuiltinType::Id) && NF != 1) || + return +#include "clang/Basic/RISCVVTypes.def" + false; // end of boolean or operation. +} + +inline bool Type::isRVVTupleType(unsigned NumGroups) const { + bool Ret = false; +#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \ + IsFP) \ + if (NF == NumGroups) \ + Ret |= isSpecificBuiltinType(BuiltinType::Id); +#include "clang/Basic/RISCVVTypes.def" + return Ret; +} + inline bool Type::isTemplateTypeParmType() const { return isa(CanonicalType); } diff --git a/clang/lib/CodeGen/Targets/RISCV.cpp b/clang/lib/CodeGen/Targets/RISCV.cpp --- a/clang/lib/CodeGen/Targets/RISCV.cpp +++ b/clang/lib/CodeGen/Targets/RISCV.cpp @@ -8,6 +8,7 @@ #include "ABIInfoImpl.h" #include "TargetInfo.h" +#include "llvm/TargetParser/RISCVTargetParser.h" using namespace clang; using namespace clang::CodeGen; @@ -19,6 +20,9 @@ namespace { class RISCVABIInfo : public DefaultABIInfo { private: + using ArgRegPair = std::pair; + using ArgRegPairs = llvm::SmallVector; + // Size of the integer ('x') registers in bits. unsigned XLen; // Size of the floating point ('f') registers in bits. 
Note that the target @@ -27,11 +31,15 @@ unsigned FLen; static const int NumArgGPRs = 8; static const int NumArgFPRs = 8; + static const int NumArgVRs = 16; bool detectFPCCEligibleStructHelper(QualType Ty, CharUnits CurOff, llvm::Type *&Field1Ty, CharUnits &Field1Off, llvm::Type *&Field2Ty, CharUnits &Field2Off) const; + unsigned + computeMaxAssignedRegs(ArgRegPairs &RVVArgRegPairs, + std::vector> &MaxRegs) const; public: RISCVABIInfo(CodeGen::CodeGenTypes &CGT, unsigned XLen, unsigned FLen) @@ -41,6 +49,9 @@ // non-virtual, but computeInfo is virtual, so we overload it. void computeInfo(CGFunctionInfo &FI) const override; + ArgRegPairs calculateRVVArgVRegs(CGFunctionInfo &FI) const; + void classifyRVVArgumentType(ArgRegPairs RVVArgRegPairs) const; + ABIArgInfo classifyArgumentType(QualType Ty, bool IsFixed, int &ArgGPRsLeft, int &ArgFPRsLeft) const; ABIArgInfo classifyReturnType(QualType RetTy) const; @@ -92,9 +103,98 @@ int ArgNum = 0; for (auto &ArgInfo : FI.arguments()) { bool IsFixed = ArgNum < NumFixedArgs; + ArgNum++; + + if (ArgInfo.type.getTypePtr()->isRVVType()) + continue; + ArgInfo.info = classifyArgumentType(ArgInfo.type, IsFixed, ArgGPRsLeft, ArgFPRsLeft); - ArgNum++; + } + + classifyRVVArgumentType(calculateRVVArgVRegs(FI)); +} + +// Calculate total vregs each RVV argument needs. +RISCVABIInfo::ArgRegPairs +RISCVABIInfo::calculateRVVArgVRegs(CGFunctionInfo &FI) const { + RISCVABIInfo::ArgRegPairs RVVArgRegPairs; + for (auto &ArgInfo : FI.arguments()) { + const QualType &Ty = ArgInfo.type; + if (!Ty->isRVVType()) + continue; + + // Calculate the registers needed for each RVV type. NOTE(review): only IsFloat=false is queried below, so FP element types presumably fall through to 64 - confirm. + unsigned ElemSize = Ty->isRVVType(8, false) ? 8 + : Ty->isRVVType(16, false) ? 16 + : Ty->isRVVType(32, false) ? 32 + : 64; + unsigned ElemCount = Ty->isRVVType(1) ? 1 + : Ty->isRVVType(2) ? 2 + : Ty->isRVVType(4) ? 4 + : Ty->isRVVType(8) ? 8 + : Ty->isRVVType(16) ? 16 + : Ty->isRVVType(32) ? 
32 + : 64; + unsigned RegsPerGroup = + std::max((ElemSize * ElemCount) / llvm::RISCV::RVVBitsPerBlock, 1U); + + unsigned NumGroups = 1; + if (Ty->isRVVTupleType()) + // Get the number of groups (NF) for each RVV type. + NumGroups = Ty->isRVVTupleType(2) ? 2 + : Ty->isRVVTupleType(3) ? 3 + : Ty->isRVVTupleType(4) ? 4 + : Ty->isRVVTupleType(5) ? 5 + : Ty->isRVVTupleType(6) ? 6 + : Ty->isRVVTupleType(7) ? 7 + : 8; + + RVVArgRegPairs.push_back( + std::make_pair(&ArgInfo, NumGroups * RegsPerGroup)); + } + + return RVVArgRegPairs; +} + +// Dynamic programming approach for finding the best vector register usages. +// We can reduce the problem to a 0/1 knapsack problem with: +// 1. capacity == NumArgVRs +// 2. weight == value == total VRs needed +unsigned RISCVABIInfo::computeMaxAssignedRegs( + ArgRegPairs &RVVArgRegPairs, + std::vector> &MaxRegs) const { + for (unsigned i = 1; i <= RVVArgRegPairs.size(); ++i) { + unsigned RegsNeeded = RVVArgRegPairs[i - 1].second; + for (unsigned j = 1; j <= NumArgVRs; ++j) + if (j < RegsNeeded) + MaxRegs[i][j] = MaxRegs[i - 1][j]; + else + MaxRegs[i][j] = std::max(RegsNeeded + MaxRegs[i - 1][j - RegsNeeded], + MaxRegs[i - 1][j]); + } + + return MaxRegs[RVVArgRegPairs.size()][NumArgVRs]; +} + +void RISCVABIInfo::classifyRVVArgumentType(ArgRegPairs RVVArgRegPairs) const { + unsigned ToBeAssigned = RVVArgRegPairs.size(); + std::vector> MaxRegs( + ToBeAssigned + 1, std::vector(NumArgVRs + 1, 0)); + computeMaxAssignedRegs(RVVArgRegPairs, MaxRegs); + + // Walk back through MaxRegs to determine which argument is passed by + // register. 
+ unsigned RegsLeft = NumArgVRs; + while (ToBeAssigned--) { + auto *ArgInfo = RVVArgRegPairs[ToBeAssigned].first; + if (!RegsLeft || + MaxRegs[ToBeAssigned + 1][RegsLeft] == MaxRegs[ToBeAssigned][RegsLeft]) + ArgInfo->info = getNaturalAlignIndirect(ArgInfo->type, /*ByVal=*/false); + else { + ArgInfo->info = ABIArgInfo::getDirect(); + RegsLeft -= RVVArgRegPairs[ToBeAssigned].second; + } } } diff --git a/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c b/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c new file mode 100644 --- /dev/null +++ b/clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c @@ -0,0 +1,26 @@ +// REQUIRES: riscv-registered-target +// RUN: %clang_cc1 -triple riscv64 -target-feature +v \ +// RUN: -emit-llvm %s -o - | FileCheck -check-prefix=CHECK-LLVM %s + +#include + +// CHECK-LLVM: void @call1( %v0, %v1.coerce0, %v1.coerce1, %v2, %v3) +void call1(vint32m2_t v0, vint32m4x2_t v1, vint32m4_t v2, vint32m1_t v3) {} + +// CHECK-LLVM: void @call2( %v0.coerce0, %v0.coerce1, %v0.coerce2, %v1.coerce0, %v1.coerce1, %v2, ptr noundef %0) +void call2(vint32m1x3_t v0, vint32m4x2_t v1, vint32m4_t v2, vint32m2_t v3) {} + +// CHECK-LLVM: void @call3( %v0.coerce0, %v0.coerce1, ptr noundef %0, %v2.coerce0, %v2.coerce1) +void call3(vint32m4x2_t v0, vint32m1_t v1, vint32m4x2_t v2) {} + +// CHECK-LLVM: void @call4( %v0, ptr noundef %0, %v2) +void call4(vint32m8_t v0, vint32m1_t v1, vint32m8_t v2) {} + +// CHECK-LLVM: void @call5(ptr noundef %0, %v1, ptr noundef %1, %v3) +void call5(vint32m1_t v0, vint32m8_t v1, vint32m1_t v2, vint32m8_t v3) {} + +// CHECK-LLVM: void @call6( %v0, %v1, %v2, %v3) +void call6(vint8mf8_t v0, vint8m8_t v1, vint32m1_t v2, vint8mf8_t v3) {} + +// CHECK-LLVM: void @call7(ptr noundef %0, %v1, %v2, ptr noundef %1) +void call7(vint8mf8_t v0, vint8m8_t v1, vint32m8_t v2, vint8mf8_t v3) {}