diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1799,16 +1799,30 @@
         .addReg(VL, RegState::Kill)
         .addImm(ShiftAmount)
         .setMIFlag(Flag);
-  } else if ((NumOfVReg == 3 || NumOfVReg == 5 || NumOfVReg == 9) &&
-             STI.hasStdExtZba()) {
-    // We can use Zba SHXADD instructions for multiply in some cases.
-    // TODO: Generalize to SHXADD+SLLI.
+  } else if (STI.hasStdExtZba() &&
+             ((NumOfVReg % 3 == 0 && isPowerOf2_64(NumOfVReg / 3)) ||
+              (NumOfVReg % 5 == 0 && isPowerOf2_64(NumOfVReg / 5)) ||
+              (NumOfVReg % 9 == 0 && isPowerOf2_64(NumOfVReg / 9)))) {
+    // We can use Zba SHXADD+SLLI instructions for multiply in some cases.
     unsigned Opc;
-    switch (NumOfVReg) {
-    default: llvm_unreachable("Unexpected number of vregs");
-    case 3: Opc = RISCV::SH1ADD; break;
-    case 5: Opc = RISCV::SH2ADD; break;
-    case 9: Opc = RISCV::SH3ADD; break;
+    uint32_t ShiftAmount;
+    if (NumOfVReg % 3 == 0) {
+      Opc = RISCV::SH1ADD;
+      ShiftAmount = Log2_64(NumOfVReg / 3);
+    } else if (NumOfVReg % 5 == 0) {
+      Opc = RISCV::SH2ADD;
+      ShiftAmount = Log2_64(NumOfVReg / 5);
+    } else {
+      Opc = RISCV::SH3ADD;
+      ShiftAmount = Log2_64(NumOfVReg / 9);
+    }
+    if (ShiftAmount) {
+      Register ShiftResReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+      BuildMI(MBB, II, DL, get(RISCV::SLLI), ShiftResReg)
+          .addReg(VL, RegState::Kill)
+          .addImm(ShiftAmount)
+          .setMIFlag(Flag);
+      VL = ShiftResReg;
     }
     BuildMI(MBB, II, DL, get(Opc), VL)
         .addReg(VL, RegState::Kill)
diff --git a/llvm/test/CodeGen/RISCV/rvv/allocate-lmul-2-4-8.ll b/llvm/test/CodeGen/RISCV/rvv/allocate-lmul-2-4-8.ll
--- a/llvm/test/CodeGen/RISCV/rvv/allocate-lmul-2-4-8.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/allocate-lmul-2-4-8.ll
@@ -1,8 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv64 -mattr=+m,+v -verify-machineinstrs < %s \
-; RUN: | FileCheck %s
+; RUN: | FileCheck %s --check-prefixes=CHECK,NOZBA
 ; RUN: llc -mtriple=riscv64 -mattr=+m,+v,+zba -verify-machineinstrs < %s \
-; RUN: | FileCheck %s
+; RUN: | FileCheck %s --check-prefixes=CHECK,ZBA
 
 define void @lmul1() nounwind {
 ; CHECK-LABEL: lmul1:
@@ -210,22 +210,39 @@
 }
 
 define void @lmul4_and_2_x2_1() nounwind {
-; CHECK-LABEL: lmul4_and_2_x2_1:
-; CHECK: # %bb.0:
-; CHECK-NEXT: addi sp, sp, -48
-; CHECK-NEXT: sd ra, 40(sp) # 8-byte Folded Spill
-; CHECK-NEXT: sd s0, 32(sp) # 8-byte Folded Spill
-; CHECK-NEXT: addi s0, sp, 48
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: li a1, 12
-; CHECK-NEXT: mul a0, a0, a1
-; CHECK-NEXT: sub sp, sp, a0
-; CHECK-NEXT: andi sp, sp, -32
-; CHECK-NEXT: addi sp, s0, -48
-; CHECK-NEXT: ld ra, 40(sp) # 8-byte Folded Reload
-; CHECK-NEXT: ld s0, 32(sp) # 8-byte Folded Reload
-; CHECK-NEXT: addi sp, sp, 48
-; CHECK-NEXT: ret
+; NOZBA-LABEL: lmul4_and_2_x2_1:
+; NOZBA: # %bb.0:
+; NOZBA-NEXT: addi sp, sp, -48
+; NOZBA-NEXT: sd ra, 40(sp) # 8-byte Folded Spill
+; NOZBA-NEXT: sd s0, 32(sp) # 8-byte Folded Spill
+; NOZBA-NEXT: addi s0, sp, 48
+; NOZBA-NEXT: csrr a0, vlenb
+; NOZBA-NEXT: li a1, 12
+; NOZBA-NEXT: mul a0, a0, a1
+; NOZBA-NEXT: sub sp, sp, a0
+; NOZBA-NEXT: andi sp, sp, -32
+; NOZBA-NEXT: addi sp, s0, -48
+; NOZBA-NEXT: ld ra, 40(sp) # 8-byte Folded Reload
+; NOZBA-NEXT: ld s0, 32(sp) # 8-byte Folded Reload
+; NOZBA-NEXT: addi sp, sp, 48
+; NOZBA-NEXT: ret
+;
+; ZBA-LABEL: lmul4_and_2_x2_1:
+; ZBA: # %bb.0:
+; ZBA-NEXT: addi sp, sp, -48
+; ZBA-NEXT: sd ra, 40(sp) # 8-byte Folded Spill
+; ZBA-NEXT: sd s0, 32(sp) # 8-byte Folded Spill
+; ZBA-NEXT: addi s0, sp, 48
+; ZBA-NEXT: csrr a0, vlenb
+; ZBA-NEXT: slli a0, a0, 2
+; ZBA-NEXT: sh1add a0, a0, a0
+; ZBA-NEXT: sub sp, sp, a0
+; ZBA-NEXT: andi sp, sp, -32
+; ZBA-NEXT: addi sp, s0, -48
+; ZBA-NEXT: ld ra, 40(sp) # 8-byte Folded Reload
+; ZBA-NEXT: ld s0, 32(sp) # 8-byte Folded Reload
+; ZBA-NEXT: addi sp, sp, 48
+; ZBA-NEXT: ret
 %v1 = alloca
 %v3 = alloca
 %v2 = alloca