diff --git a/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp b/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp --- a/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp +++ b/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp @@ -58,18 +58,6 @@ return new RISCVSExtWRemoval(); } -// add uses of MI to the Worklist -static void addUses(const MachineInstr &MI, - SmallVectorImpl &Worklist, - MachineRegisterInfo &MRI) { - for (auto &UserOp : MRI.reg_operands(MI.getOperand(0).getReg())) { - const auto *User = UserOp.getParent(); - if (User == &MI) // ignore the def, current MI - continue; - Worklist.push_back(User); - } -} - // returns true if all uses of OrigMI only depend on the lower word of its // output, so we can transform OrigMI to the corresponding W-version. // TODO: handle multiple interdependent transformations @@ -78,114 +66,115 @@ SmallPtrSet Visited; SmallVector Worklist; - Visited.insert(&OrigMI); - addUses(OrigMI, Worklist, MRI); + Worklist.push_back(&OrigMI); while (!Worklist.empty()) { const MachineInstr *MI = Worklist.pop_back_val(); - if (!Visited.insert(MI).second) { - // If we've looped back to OrigMI through a PHI cycle, we can't transform - // LD or LWU, because these operations use all 64 bits of input. - if (MI == &OrigMI) { - unsigned opcode = MI->getOpcode(); - if (opcode == RISCV::LD || opcode == RISCV::LWU) - return false; - } + if (!Visited.insert(MI).second) continue; - } - switch (MI->getOpcode()) { - case RISCV::ADDIW: - case RISCV::ADDW: - case RISCV::DIVUW: - case RISCV::DIVW: - case RISCV::MULW: - case RISCV::REMUW: - case RISCV::REMW: - case RISCV::SLLIW: - case RISCV::SLLW: - case RISCV::SRAIW: - case RISCV::SRAW: - case RISCV::SRLIW: - case RISCV::SRLW: - case RISCV::SUBW: - case RISCV::ROLW: - case RISCV::RORW: - case RISCV::RORIW: - case RISCV::CLZW: - case RISCV::CTZW: - case RISCV::CPOPW: - case RISCV::SLLI_UW: - case RISCV::FMV_H_X: - case RISCV::FMV_W_X: - case RISCV::FCVT_H_W: - case RISCV::FCVT_H_WU: - case RISCV::FCVT_S_W: - case RISCV::FCVT_S_WU: - case RISCV::FCVT_D_W: - case RISCV::FCVT_D_WU: - case RISCV::SEXT_B: - case RISCV::SEXT_H: - case RISCV::ZEXT_H_RV64: - continue; + // Only handle instructions with one def. + if (MI->getNumExplicitDefs() != 1) + return false; - // these overwrite higher input bits, otherwise the lower word of output - // depends only on the lower word of input. So check their uses read W. - case RISCV::SLLI: - if (MI->getOperand(2).getImm() >= 32) - continue; - addUses(*MI, Worklist, MRI); - continue; - case RISCV::ANDI: - if (isUInt<11>(MI->getOperand(2).getImm())) - continue; - addUses(*MI, Worklist, MRI); - continue; - case RISCV::ORI: - if (!isUInt<11>(MI->getOperand(2).getImm())) - continue; - addUses(*MI, Worklist, MRI); - continue; + for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) { + const MachineInstr *UserMI = UserOp.getParent(); - case RISCV::BEXTI: - if (MI->getOperand(2).getImm() >= 32) + switch (UserMI->getOpcode()) { + default: return false; - continue; - // For these, lower word of output in these operations, depends only on - // the lower word of input. So, we check all uses only read lower word. - case RISCV::COPY: - case RISCV::PHI: - - case RISCV::ADD: - case RISCV::ADDI: - case RISCV::AND: - case RISCV::MUL: - case RISCV::OR: - case RISCV::SLL: - case RISCV::SUB: - case RISCV::XOR: - case RISCV::XORI: - - case RISCV::ADD_UW: - case RISCV::ANDN: - case RISCV::CLMUL: - case RISCV::ORC_B: - case RISCV::ORN: - case RISCV::SH1ADD: - case RISCV::SH1ADD_UW: - case RISCV::SH2ADD: - case RISCV::SH2ADD_UW: - case RISCV::SH3ADD: - case RISCV::SH3ADD_UW: - case RISCV::XNOR: - addUses(*MI, Worklist, MRI); - continue; - default: - return false; + case RISCV::ADDIW: + case RISCV::ADDW: + case RISCV::DIVUW: + case RISCV::DIVW: + case RISCV::MULW: + case RISCV::REMUW: + case RISCV::REMW: + case RISCV::SLLIW: + case RISCV::SLLW: + case RISCV::SRAIW: + case RISCV::SRAW: + case RISCV::SRLIW: + case RISCV::SRLW: + case RISCV::SUBW: + case RISCV::ROLW: + case RISCV::RORW: + case RISCV::RORIW: + case RISCV::CLZW: + case RISCV::CTZW: + case RISCV::CPOPW: + case RISCV::SLLI_UW: + case RISCV::FMV_H_X: + case RISCV::FMV_W_X: + case RISCV::FCVT_H_W: + case RISCV::FCVT_H_WU: + case RISCV::FCVT_S_W: + case RISCV::FCVT_S_WU: + case RISCV::FCVT_D_W: + case RISCV::FCVT_D_WU: + case RISCV::SEXT_B: + case RISCV::SEXT_H: + case RISCV::ZEXT_H_RV64: + break; + + // these overwrite higher input bits, otherwise the lower word of output + // depends only on the lower word of input. So check their uses read W. + case RISCV::SLLI: + if (UserMI->getOperand(2).getImm() >= 32) + break; + Worklist.push_back(UserMI); + break; + case RISCV::ANDI: + if (isUInt<11>(UserMI->getOperand(2).getImm())) + break; + Worklist.push_back(UserMI); + break; + case RISCV::ORI: + if (!isUInt<11>(UserMI->getOperand(2).getImm())) + break; + Worklist.push_back(UserMI); + break; + + case RISCV::BEXTI: + if (UserMI->getOperand(2).getImm() >= 32) + return false; + break; + + // For these, lower word of output in these operations, depends only on + // the lower word of input. So, we check all uses only read lower word. + case RISCV::COPY: + case RISCV::PHI: + + case RISCV::ADD: + case RISCV::ADDI: + case RISCV::AND: + case RISCV::MUL: + case RISCV::OR: + case RISCV::SLL: + case RISCV::SUB: + case RISCV::XOR: + case RISCV::XORI: + + case RISCV::ADD_UW: + case RISCV::ANDN: + case RISCV::CLMUL: + case RISCV::ORC_B: + case RISCV::ORN: + case RISCV::SH1ADD: + case RISCV::SH1ADD_UW: + case RISCV::SH2ADD: + case RISCV::SH2ADD_UW: + case RISCV::SH3ADD: + case RISCV::SH3ADD_UW: + case RISCV::XNOR: + Worklist.push_back(UserMI); + break; + } } } + return true; }