diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -227,17 +227,6 @@
 
   std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;
 
-  // Returns true if all uses of OrigMI only depend on the lower \p NBits bits
-  // of its output.
-  bool hasAllNBitUsers(const MachineInstr &MI, const MachineRegisterInfo &MRI,
-                       unsigned NBits) const;
-  // Returns true if all uses of OrigMI only depend on the lower word of its
-  // output, so we can transform OrigMI to the corresponding W-version.
-  bool hasAllWUsers(const MachineInstr &MI,
-                    const MachineRegisterInfo &MRI) const {
-    return hasAllNBitUsers(MI, MRI, 32);
-  }
-
 protected:
   const RISCVSubtarget &STI;
 };
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -2614,226 +2614,6 @@
   }
 }
 
-// Checks if all users only demand the lower \p OrigBits of the original
-// instruction's result.
-// TODO: handle multiple interdependent transformations
-bool RISCVInstrInfo::hasAllNBitUsers(const MachineInstr &OrigMI,
-                                     const MachineRegisterInfo &MRI,
-                                     unsigned OrigBits) const {
-
-  SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
-  SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
-
-  Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
-
-  while (!Worklist.empty()) {
-    auto P = Worklist.pop_back_val();
-    const MachineInstr *MI = P.first;
-    unsigned Bits = P.second;
-
-    if (!Visited.insert(P).second)
-      continue;
-
-    // Only handle instructions with one def.
-    if (MI->getNumExplicitDefs() != 1)
-      return false;
-
-    for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) {
-      const MachineInstr *UserMI = UserOp.getParent();
-      unsigned OpIdx = UserOp.getOperandNo();
-
-      switch (UserMI->getOpcode()) {
-      default:
-        return false;
-
-      case RISCV::ADDIW:
-      case RISCV::ADDW:
-      case RISCV::DIVUW:
-      case RISCV::DIVW:
-      case RISCV::MULW:
-      case RISCV::REMUW:
-      case RISCV::REMW:
-      case RISCV::SLLIW:
-      case RISCV::SLLW:
-      case RISCV::SRAIW:
-      case RISCV::SRAW:
-      case RISCV::SRLIW:
-      case RISCV::SRLW:
-      case RISCV::SUBW:
-      case RISCV::ROLW:
-      case RISCV::RORW:
-      case RISCV::RORIW:
-      case RISCV::CLZW:
-      case RISCV::CTZW:
-      case RISCV::CPOPW:
-      case RISCV::SLLI_UW:
-      case RISCV::FMV_W_X:
-      case RISCV::FCVT_H_W:
-      case RISCV::FCVT_H_WU:
-      case RISCV::FCVT_S_W:
-      case RISCV::FCVT_S_WU:
-      case RISCV::FCVT_D_W:
-      case RISCV::FCVT_D_WU:
-        if (Bits >= 32)
-          break;
-        return false;
-      case RISCV::SEXT_B:
-      case RISCV::PACKH:
-        if (Bits >= 8)
-          break;
-        return false;
-      case RISCV::SEXT_H:
-      case RISCV::FMV_H_X:
-      case RISCV::ZEXT_H_RV32:
-      case RISCV::ZEXT_H_RV64:
-      case RISCV::PACKW:
-        if (Bits >= 16)
-          break;
-        return false;
-
-      case RISCV::PACK:
-        if (Bits >= (STI.getXLen() / 2))
-          break;
-        return false;
-
-      case RISCV::SRLI: {
-        // If we are shifting right by less than Bits, and users don't demand
-        // any bits that were shifted into [Bits-1:0], then we can consider this
-        // as an N-Bit user.
-        unsigned ShAmt = UserMI->getOperand(2).getImm();
-        if (Bits > ShAmt) {
-          Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
-          break;
-        }
-        return false;
-      }
-
-      // these overwrite higher input bits, otherwise the lower word of output
-      // depends only on the lower word of input. So check their uses read W.
-      case RISCV::SLLI:
-        if (Bits >= (STI.getXLen() - UserMI->getOperand(2).getImm()))
-          break;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-      case RISCV::ANDI: {
-        uint64_t Imm = UserMI->getOperand(2).getImm();
-        if (Bits >= (unsigned)llvm::bit_width(Imm))
-          break;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-      }
-      case RISCV::ORI: {
-        uint64_t Imm = UserMI->getOperand(2).getImm();
-        if (Bits >= (unsigned)llvm::bit_width(~Imm))
-          break;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-      }
-
-      case RISCV::SLL:
-      case RISCV::BSET:
-      case RISCV::BCLR:
-      case RISCV::BINV:
-        // Operand 2 is the shift amount which uses log2(xlen) bits.
-        if (OpIdx == 2) {
-          if (Bits >= Log2_32(STI.getXLen()))
-            break;
-          return false;
-        }
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-
-      case RISCV::SRA:
-      case RISCV::SRL:
-      case RISCV::ROL:
-      case RISCV::ROR:
-        // Operand 2 is the shift amount which uses 6 bits.
-        if (OpIdx == 2 && Bits >= Log2_32(STI.getXLen()))
-          break;
-        return false;
-
-      case RISCV::ADD_UW:
-      case RISCV::SH1ADD_UW:
-      case RISCV::SH2ADD_UW:
-      case RISCV::SH3ADD_UW:
-        // Operand 1 is implicitly zero extended.
-        if (OpIdx == 1 && Bits >= 32)
-          break;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-
-      case RISCV::BEXTI:
-        if (UserMI->getOperand(2).getImm() >= Bits)
-          return false;
-        break;
-
-      case RISCV::SB:
-        // The first argument is the value to store.
-        if (OpIdx == 0 && Bits >= 8)
-          break;
-        return false;
-      case RISCV::SH:
-        // The first argument is the value to store.
-        if (OpIdx == 0 && Bits >= 16)
-          break;
-        return false;
-      case RISCV::SW:
-        // The first argument is the value to store.
-        if (OpIdx == 0 && Bits >= 32)
-          break;
-        return false;
-
-      // For these, lower word of output in these operations, depends only on
-      // the lower word of input. So, we check all uses only read lower word.
-      case RISCV::COPY:
-      case RISCV::PHI:
-
-      case RISCV::ADD:
-      case RISCV::ADDI:
-      case RISCV::AND:
-      case RISCV::MUL:
-      case RISCV::OR:
-      case RISCV::SUB:
-      case RISCV::XOR:
-      case RISCV::XORI:
-
-      case RISCV::ANDN:
-      case RISCV::BREV8:
-      case RISCV::CLMUL:
-      case RISCV::ORC_B:
-      case RISCV::ORN:
-      case RISCV::SH1ADD:
-      case RISCV::SH2ADD:
-      case RISCV::SH3ADD:
-      case RISCV::XNOR:
-      case RISCV::BSETI:
-      case RISCV::BCLRI:
-      case RISCV::BINVI:
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-
-      case RISCV::PseudoCCMOVGPR:
-        // Either operand 4 or operand 5 is returned by this instruction. If
-        // only the lower word of the result is used, then only the lower word
-        // of operand 4 and 5 is used.
-        if (OpIdx != 4 && OpIdx != 5)
-          return false;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-
-      case RISCV::VT_MASKC:
-      case RISCV::VT_MASKCN:
-        if (OpIdx != 1)
-          return false;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-      }
-    }
-  }
-
-  return true;
-}
-
 // Returns true if this is the sext.w pattern, addiw rd, rs1, 0.
 bool RISCV::isSEXT_W(const MachineInstr &MI) {
   return MI.getOpcode() == RISCV::ADDIW && MI.getOperand(1).isReg() &&
diff --git a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
--- a/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
+++ b/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp
@@ -54,9 +54,9 @@
 
   bool runOnMachineFunction(MachineFunction &MF) override;
   bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
-                         MachineRegisterInfo &MRI);
+                         const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
   bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
-                      MachineRegisterInfo &MRI);
+                      const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesCFG();
@@ -76,6 +76,233 @@
   return new RISCVOptWInstrs();
 }
 
+// Checks if all users only demand the lower \p OrigBits of the original
+// instruction's result.
+// TODO: handle multiple interdependent transformations
+static bool hasAllNBitUsers(const MachineInstr &OrigMI,
+                            const RISCVSubtarget &ST,
+                            const MachineRegisterInfo &MRI, unsigned OrigBits) {
+
+  SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
+  SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
+
+  Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
+
+  while (!Worklist.empty()) {
+    auto P = Worklist.pop_back_val();
+    const MachineInstr *MI = P.first;
+    unsigned Bits = P.second;
+
+    if (!Visited.insert(P).second)
+      continue;
+
+    // Only handle instructions with one def.
+    if (MI->getNumExplicitDefs() != 1)
+      return false;
+
+    for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) {
+      const MachineInstr *UserMI = UserOp.getParent();
+      unsigned OpIdx = UserOp.getOperandNo();
+
+      switch (UserMI->getOpcode()) {
+      default:
+        return false;
+
+      case RISCV::ADDIW:
+      case RISCV::ADDW:
+      case RISCV::DIVUW:
+      case RISCV::DIVW:
+      case RISCV::MULW:
+      case RISCV::REMUW:
+      case RISCV::REMW:
+      case RISCV::SLLIW:
+      case RISCV::SLLW:
+      case RISCV::SRAIW:
+      case RISCV::SRAW:
+      case RISCV::SRLIW:
+      case RISCV::SRLW:
+      case RISCV::SUBW:
+      case RISCV::ROLW:
+      case RISCV::RORW:
+      case RISCV::RORIW:
+      case RISCV::CLZW:
+      case RISCV::CTZW:
+      case RISCV::CPOPW:
+      case RISCV::SLLI_UW:
+      case RISCV::FMV_W_X:
+      case RISCV::FCVT_H_W:
+      case RISCV::FCVT_H_WU:
+      case RISCV::FCVT_S_W:
+      case RISCV::FCVT_S_WU:
+      case RISCV::FCVT_D_W:
+      case RISCV::FCVT_D_WU:
+        if (Bits >= 32)
+          break;
+        return false;
+      case RISCV::SEXT_B:
+      case RISCV::PACKH:
+        if (Bits >= 8)
+          break;
+        return false;
+      case RISCV::SEXT_H:
+      case RISCV::FMV_H_X:
+      case RISCV::ZEXT_H_RV32:
+      case RISCV::ZEXT_H_RV64:
+      case RISCV::PACKW:
+        if (Bits >= 16)
+          break;
+        return false;
+
+      case RISCV::PACK:
+        if (Bits >= (ST.getXLen() / 2))
+          break;
+        return false;
+
+      case RISCV::SRLI: {
+        // If we are shifting right by less than Bits, and users don't demand
+        // any bits that were shifted into [Bits-1:0], then we can consider this
+        // as an N-Bit user.
+        unsigned ShAmt = UserMI->getOperand(2).getImm();
+        if (Bits > ShAmt) {
+          Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
+          break;
+        }
+        return false;
+      }
+
+      // these overwrite higher input bits, otherwise the lower word of output
+      // depends only on the lower word of input. So check their uses read W.
+      case RISCV::SLLI:
+        if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm()))
+          break;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+      case RISCV::ANDI: {
+        uint64_t Imm = UserMI->getOperand(2).getImm();
+        if (Bits >= (unsigned)llvm::bit_width(Imm))
+          break;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+      }
+      case RISCV::ORI: {
+        uint64_t Imm = UserMI->getOperand(2).getImm();
+        if (Bits >= (unsigned)llvm::bit_width(~Imm))
+          break;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+      }
+
+      case RISCV::SLL:
+      case RISCV::BSET:
+      case RISCV::BCLR:
+      case RISCV::BINV:
+        // Operand 2 is the shift amount which uses log2(xlen) bits.
+        if (OpIdx == 2) {
+          if (Bits >= Log2_32(ST.getXLen()))
+            break;
+          return false;
+        }
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+
+      case RISCV::SRA:
+      case RISCV::SRL:
+      case RISCV::ROL:
+      case RISCV::ROR:
+        // Operand 2 is the shift amount which uses log2(xlen) bits.
+        if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
+          break;
+        return false;
+
+      case RISCV::ADD_UW:
+      case RISCV::SH1ADD_UW:
+      case RISCV::SH2ADD_UW:
+      case RISCV::SH3ADD_UW:
+        // Operand 1 is implicitly zero extended.
+        if (OpIdx == 1 && Bits >= 32)
+          break;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+
+      case RISCV::BEXTI:
+        if (UserMI->getOperand(2).getImm() >= Bits)
+          return false;
+        break;
+
+      case RISCV::SB:
+        // The first argument is the value to store.
+        if (OpIdx == 0 && Bits >= 8)
+          break;
+        return false;
+      case RISCV::SH:
+        // The first argument is the value to store.
+        if (OpIdx == 0 && Bits >= 16)
+          break;
+        return false;
+      case RISCV::SW:
+        // The first argument is the value to store.
+        if (OpIdx == 0 && Bits >= 32)
+          break;
+        return false;
+
+      // For these, lower word of output in these operations, depends only on
+      // the lower word of input. So, we check all uses only read lower word.
+      case RISCV::COPY:
+      case RISCV::PHI:
+
+      case RISCV::ADD:
+      case RISCV::ADDI:
+      case RISCV::AND:
+      case RISCV::MUL:
+      case RISCV::OR:
+      case RISCV::SUB:
+      case RISCV::XOR:
+      case RISCV::XORI:
+
+      case RISCV::ANDN:
+      case RISCV::BREV8:
+      case RISCV::CLMUL:
+      case RISCV::ORC_B:
+      case RISCV::ORN:
+      case RISCV::SH1ADD:
+      case RISCV::SH2ADD:
+      case RISCV::SH3ADD:
+      case RISCV::XNOR:
+      case RISCV::BSETI:
+      case RISCV::BCLRI:
+      case RISCV::BINVI:
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+
+      case RISCV::PseudoCCMOVGPR:
+        // Either operand 4 or operand 5 is returned by this instruction. If
+        // only the lower word of the result is used, then only the lower word
+        // of operand 4 and 5 is used.
+        if (OpIdx != 4 && OpIdx != 5)
+          return false;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+
+      case RISCV::VT_MASKC:
+      case RISCV::VT_MASKCN:
+        if (OpIdx != 1)
+          return false;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+      }
+    }
+  }
+
+  return true;
+}
+
+// Returns true if all uses of OrigMI only depend on the lower word of its
+// output, so we can transform OrigMI to the corresponding W-version.
+static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
+                         const MachineRegisterInfo &MRI) {
+  return hasAllNBitUsers(OrigMI, ST, MRI, 32);
+}
+
 // This function returns true if the machine instruction always outputs a value
 // where bits 63:32 match bit 31.
 static bool isSignExtendingOpW(const MachineInstr &MI,
@@ -110,8 +337,8 @@
   return false;
 }
 
-static bool isSignExtendedW(Register SrcReg, const MachineRegisterInfo &MRI,
-                            const RISCVInstrInfo &TII,
+static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
+                            const MachineRegisterInfo &MRI,
                             SmallPtrSetImpl<MachineInstr *> &FixableDef) {
 
   SmallPtrSet<const MachineInstr *, 4> Visited;
@@ -300,7 +527,7 @@
     case RISCV::LWU:
     case RISCV::MUL:
     case RISCV::SUB:
-      if (TII.hasAllWUsers(*MI, MRI)) {
+      if (hasAllWUsers(*MI, ST, MRI)) {
         FixableDef.insert(MI);
         break;
       }
@@ -335,6 +562,7 @@
 
 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
                                         const RISCVInstrInfo &TII,
+                                        const RISCVSubtarget &ST,
                                         MachineRegisterInfo &MRI) {
   if (DisableSExtWRemoval)
     return false;
@@ -355,8 +583,8 @@
     // If all users only use the lower bits, this sext.w is redundant.
     // Or if all definitions reaching MI sign-extend their output,
     // then sext.w is redundant.
-    if (!TII.hasAllWUsers(*MI, MRI) &&
-        !isSignExtendedW(SrcReg, MRI, TII, FixableDefs))
+    if (!hasAllWUsers(*MI, ST, MRI) &&
+        !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
       continue;
 
     Register DstReg = MI->getOperand(0).getReg();
@@ -388,6 +616,7 @@
 
 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
                                      const RISCVInstrInfo &TII,
+                                     const RISCVSubtarget &ST,
                                      MachineRegisterInfo &MRI) {
   if (DisableStripWSuffix)
     return false;
@@ -406,7 +635,7 @@
       case RISCV::SLLIW: Opc = RISCV::SLLI; break;
       }
 
-      if (TII.hasAllWUsers(MI, MRI)) {
+      if (hasAllWUsers(MI, ST, MRI)) {
         MI.setDesc(TII.get(Opc));
         MadeChange = true;
       }
@@ -428,8 +657,8 @@
     return false;
 
   bool MadeChange = false;
-  MadeChange |= removeSExtWInstrs(MF, TII, MRI);
-  MadeChange |= stripWSuffixes(MF, TII, MRI);
+  MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
+  MadeChange |= stripWSuffixes(MF, TII, ST, MRI);
 
   return MadeChange;
 }