diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -202,18 +202,66 @@
     return VF == 1 ? 1 : ST->getMaxInterleaveFactor();
   }
 
-  // TODO: We should define RISC-V's own register classes.
-  //       e.g. register class for FPR.
+  /// Register classes reported to the cost model: scalar integer (GPRRC),
+  /// scalar floating-point (FPRRC) and vector (VRRC) registers.
+  enum RISCVRegisterClass { GPRRC, FPRRC, VRRC };
+
+  /// Returns how many registers of class \p ClassID the cost model may use,
+  /// or 0 when the subtarget does not implement that register file.
   unsigned getNumberOfRegisters(unsigned ClassID) const {
-    bool Vector = (ClassID == 1);
-    if (Vector) {
-      if (ST->hasVInstructions())
+    switch (ClassID) {
+    case RISCVRegisterClass::GPRRC:
+      // 31 = 32 GPR - x0 (zero register)
+      // FIXME: Should we exclude fixed registers like SP, TP or GP?
+      return 31;
+    case RISCVRegisterClass::FPRRC:
+      if (ST->hasStdExtZfh() || ST->hasStdExtF() || ST->hasStdExtD())
         return 32;
       return 0;
+    case RISCVRegisterClass::VRRC:
+      // Although there are 32 vector registers, v0 is special in that it is the
+      // only register that can be used to hold a mask.
+      // FIXME: Should we conservatively return 31 as the number of usable
+      // vector registers?
+      return ST->hasVInstructions() ? 32 : 0;
     }
-    // 31 = 32 GPR - x0 (zero register)
-    // FIXME: Should we exclude fixed registers like SP, TP or GP?
-    return 31;
+    llvm_unreachable("unknown register class");
+  }
+
+  /// Maps a value to the register class that holds it. Vector values map to
+  /// VRRC; scalar half/float/double values map to FPRRC when the matching
+  /// extension (Zfh/F/D) is available; everything else, including the
+  /// unknown-type case (\p Ty == nullptr), maps to GPRRC.
+  unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
+    if (Vector)
+      return RISCVRegisterClass::VRRC;
+    if (!Ty)
+      return RISCVRegisterClass::GPRRC;
+
+    // Zfh provides IEEE binary16 (half); bfloat16 is a different format that
+    // Zfh does not operate on, so only half is classified as FPR here.
+    Type *ScalarTy = Ty->getScalarType();
+    if ((ScalarTy->isHalfTy() && ST->hasStdExtZfh()) ||
+        (ScalarTy->isFloatTy() && ST->hasStdExtF()) ||
+        (ScalarTy->isDoubleTy() && ST->hasStdExtD())) {
+      return RISCVRegisterClass::FPRRC;
+    }
+
+    return RISCVRegisterClass::GPRRC;
+  }
+
+  /// Returns a printable name for \p ClassID, as shown in debug output such
+  /// as LoopVectorize's "LV(REG): RegisterClass:" register-usage report.
+  const char *getRegisterClassName(unsigned ClassID) const {
+    switch (ClassID) {
+    case RISCVRegisterClass::GPRRC:
+      return "RISCV::GPRRC";
+    case RISCVRegisterClass::FPRRC:
+      return "RISCV::FPRRC";
+    case RISCVRegisterClass::VRRC:
+      return "RISCV::VRRC";
+    }
+    llvm_unreachable("unknown register class");
   }
 };
 
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/reg-usage.ll b/llvm/test/Transforms/LoopVectorize/RISCV/reg-usage.ll
--- a/llvm/test/Transforms/LoopVectorize/RISCV/reg-usage.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/reg-usage.ll
@@ -19,25 +19,25 @@
 define void @add(float* noalias nocapture readonly %src1, float* noalias nocapture readonly %src2, i32 signext %size, float* noalias nocapture writeonly %result) {
 ; CHECK-LABEL: add
 ; CHECK-LMUL1: LV(REG): Found max usage: 2 item
-; CHECK-LMUL1-NEXT: LV(REG): RegisterClass: Generic::ScalarRC, 2 registers
-; CHECK-LMUL1-NEXT: LV(REG): RegisterClass: Generic::VectorRC, 2 registers
+; CHECK-LMUL1-NEXT: LV(REG): RegisterClass: RISCV::GPRRC, 2 registers
+; CHECK-LMUL1-NEXT: LV(REG): RegisterClass: RISCV::VRRC, 2 registers
 ; CHECK-LMUL1-NEXT: LV(REG): Found invariant usage: 1 item
-; CHECK-LMUL1-NEXT: LV(REG): RegisterClass: Generic::VectorRC, 2 registers
+; CHECK-LMUL1-NEXT: LV(REG): RegisterClass: RISCV::VRRC, 2 registers
 ; CHECK-LMUL2: LV(REG): Found max usage: 2 item
-; CHECK-LMUL2-NEXT: LV(REG): RegisterClass: Generic::ScalarRC, 2 registers
-; CHECK-LMUL2-NEXT: LV(REG): RegisterClass: Generic::VectorRC, 4 registers
+; CHECK-LMUL2-NEXT: LV(REG): RegisterClass: RISCV::GPRRC, 2 registers
+; CHECK-LMUL2-NEXT: LV(REG): RegisterClass: RISCV::VRRC, 4 registers
 ; CHECK-LMUL2-NEXT: LV(REG): Found invariant usage: 1 item
-; CHECK-LMUL2-NEXT: LV(REG): RegisterClass: Generic::VectorRC, 4 registers
+; CHECK-LMUL2-NEXT: LV(REG): RegisterClass: RISCV::VRRC, 4 registers
 ; CHECK-LMUL4: LV(REG): Found max usage: 2 item
-; CHECK-LMUL4-NEXT: LV(REG): RegisterClass: Generic::ScalarRC, 2 registers
-; CHECK-LMUL4-NEXT: LV(REG): RegisterClass: Generic::VectorRC, 8 registers
+; CHECK-LMUL4-NEXT: LV(REG): RegisterClass: RISCV::GPRRC, 2 registers
+; CHECK-LMUL4-NEXT: LV(REG): RegisterClass: RISCV::VRRC, 8 registers
 ; CHECK-LMUL4-NEXT: LV(REG): Found invariant usage: 1 item
-; CHECK-LMUL4-NEXT: LV(REG): RegisterClass: Generic::VectorRC, 8 registers
+; CHECK-LMUL4-NEXT: LV(REG): RegisterClass: RISCV::VRRC, 8 registers
 ; CHECK-LMUL8: LV(REG): Found max usage: 2 item
-; CHECK-LMUL8-NEXT: LV(REG): RegisterClass: Generic::ScalarRC, 2 registers
-; CHECK-LMUL8-NEXT: LV(REG): RegisterClass: Generic::VectorRC, 16 registers
+; CHECK-LMUL8-NEXT: LV(REG): RegisterClass: RISCV::GPRRC, 2 registers
+; CHECK-LMUL8-NEXT: LV(REG): RegisterClass: RISCV::VRRC, 16 registers
 ; CHECK-LMUL8-NEXT: LV(REG): Found invariant usage: 1 item
-; CHECK-LMUL8-NEXT: LV(REG): RegisterClass: Generic::VectorRC, 16 registers
+; CHECK-LMUL8-NEXT: LV(REG): RegisterClass: RISCV::VRRC, 16 registers
 entry:
   %conv = zext i32 %size to i64