diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -10552,6 +10552,7 @@
 // passed with CCValAssign::Indirect.
 static SDValue unpackFromRegLoc(SelectionDAG &DAG, SDValue Chain,
                                 const CCValAssign &VA, const SDLoc &DL,
+                                const ISD::InputArg &In,
                                 const RISCVTargetLowering &TLI) {
   MachineFunction &MF = DAG.getMachineFunction();
   MachineRegisterInfo &RegInfo = MF.getRegInfo();
@@ -10562,6 +10563,20 @@
   RegInfo.addLiveIn(VA.getLocReg(), VReg);
   Val = DAG.getCopyFromReg(Chain, DL, VReg, LocVT);
 
+  // If input is sign extended from 32 bits, note it for the SExtWRemoval pass.
+  if (In.isOrigArg()) {
+    Argument *OrigArg = MF.getFunction().getArg(In.getOrigArgIndex());
+    if (OrigArg->getType()->isIntegerTy()) {
+      unsigned BitWidth = OrigArg->getType()->getIntegerBitWidth();
+      // An input zero extended from i31 can also be considered sign extended.
+      if ((BitWidth <= 32 && In.Flags.isSExt()) ||
+          (BitWidth < 32 && In.Flags.isZExt())) {
+        RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
+        RVFI->addSExt32Register(VReg);
+      }
+    }
+  }
+
   if (VA.getLocInfo() == CCValAssign::Indirect)
     return Val;
 
@@ -10874,7 +10889,7 @@
     if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64)
       ArgValue = unpackF64OnRV32DSoftABI(DAG, Chain, VA, DL);
     else if (VA.isRegLoc())
-      ArgValue = unpackFromRegLoc(DAG, Chain, VA, DL, *this);
+      ArgValue = unpackFromRegLoc(DAG, Chain, VA, DL, Ins[i], *this);
     else
       ArgValue = unpackFromMemLoc(DAG, Chain, VA, DL);
 
diff --git a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h
--- a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h
@@ -66,6 +66,9 @@
   /// Size of stack frame to save callee saved registers
   unsigned CalleeSavedStackSize = 0;
 
+  /// Registers that have been sign extended from i32.
+  SmallVector<Register, 8> SExt32Registers;
+
 public:
   RISCVMachineFunctionInfo(const MachineFunction &MF) {}
 
@@ -118,6 +121,11 @@
   void setCalleeSavedStackSize(unsigned Size) { CalleeSavedStackSize = Size; }
 
   void initializeBaseYamlFields(const yaml::RISCVMachineFunctionInfo &YamlMFI);
+
+  /// Record that \p Reg is known to hold a value sign extended from i32.
+  void addSExt32Register(Register Reg);
+  /// Return true if \p Reg was previously recorded by addSExt32Register.
+  bool isSExt32Register(Register Reg) const;
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp
@@ -35,3 +35,11 @@
   VarArgsFrameIndex = YamlMFI.VarArgsFrameIndex;
   VarArgsSaveSize = YamlMFI.VarArgsSaveSize;
 }
+
+void RISCVMachineFunctionInfo::addSExt32Register(Register Reg) {
+  SExt32Registers.push_back(Reg);
+}
+
+bool RISCVMachineFunctionInfo::isSExt32Register(Register Reg) const {
+  return is_contained(SExt32Registers, Reg);
+}
diff --git a/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp b/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp
--- a/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp
+++ b/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp
@@ -11,6 +11,7 @@
 //===---------------------------------------------------------------------===//
 
 #include "RISCV.h"
+#include "RISCVMachineFunctionInfo.h"
 #include "RISCVSubtarget.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
@@ -315,9 +316,23 @@
       // Unknown opcode, give up.
       return false;
     case RISCV::COPY: {
-      Register SrcReg = MI->getOperand(1).getReg();
+      const MachineFunction *MF = MI->getMF();
+      const RISCVMachineFunctionInfo *RVFI =
+          MF->getInfo<RISCVMachineFunctionInfo>();
+
+      // If this is the entry block and the register is livein, see if we know
+      // it is sign extended. Compare machine blocks directly: blocks created
+      // by custom inserters reuse the entry's IR BasicBlock, and an MBB's IR
+      // block may be null, so comparing IR blocks is not an exact test.
+      if (MI->getParent() == &MF->front()) {
+        Register VReg = MI->getOperand(0).getReg();
+        if (MF->getRegInfo().isLiveIn(VReg))
+          return RVFI->isSExt32Register(VReg);
+      }
 
-      // TODO: Handle arguments and returns from calls?
+      // TODO: Handle returns from calls?
+
+      Register SrcReg = MI->getOperand(1).getReg();
 
       // If this is a copy from another register, check its source instruction.
       if (!SrcReg.isVirtual())
diff --git a/llvm/test/CodeGen/RISCV/select-cc.ll b/llvm/test/CodeGen/RISCV/select-cc.ll
--- a/llvm/test/CodeGen/RISCV/select-cc.ll
+++ b/llvm/test/CodeGen/RISCV/select-cc.ll
@@ -114,26 +114,22 @@
 ; RV64I-NEXT:    mv a0, a2
 ; RV64I-NEXT:  .LBB0_12:
 ; RV64I-NEXT:    lw a2, 0(a1)
-; RV64I-NEXT:    sext.w a3, a0
-; RV64I-NEXT:    blt a2, a3, .LBB0_14
+; RV64I-NEXT:    blt a2, a0, .LBB0_14
 ; RV64I-NEXT:  # %bb.13:
 ; RV64I-NEXT:    mv a0, a2
 ; RV64I-NEXT:  .LBB0_14:
 ; RV64I-NEXT:    lw a2, 0(a1)
-; RV64I-NEXT:    sext.w a3, a0
-; RV64I-NEXT:    bge a3, a2, .LBB0_16
+; RV64I-NEXT:    bge a0, a2, .LBB0_16
 ; RV64I-NEXT:  # %bb.15:
 ; RV64I-NEXT:    mv a0, a2
 ; RV64I-NEXT:  .LBB0_16:
 ; RV64I-NEXT:    lw a2, 0(a1)
-; RV64I-NEXT:    sext.w a3, a0
-; RV64I-NEXT:    blt a3, a2, .LBB0_18
+; RV64I-NEXT:    blt a0, a2, .LBB0_18
 ; RV64I-NEXT:  # %bb.17:
 ; RV64I-NEXT:    mv a0, a2
 ; RV64I-NEXT:  .LBB0_18:
 ; RV64I-NEXT:    lw a2, 0(a1)
-; RV64I-NEXT:    sext.w a3, a0
-; RV64I-NEXT:    bge a2, a3, .LBB0_20
+; RV64I-NEXT:    bge a2, a0, .LBB0_20
 ; RV64I-NEXT:  # %bb.19:
 ; RV64I-NEXT:    mv a0, a2
 ; RV64I-NEXT:  .LBB0_20:
@@ -159,7 +155,6 @@
 ; RV64I-NEXT:  # %bb.27:
 ; RV64I-NEXT:    mv a0, a1
 ; RV64I-NEXT:  .LBB0_28:
-; RV64I-NEXT:    sext.w a0, a0
 ; RV64I-NEXT:    ret
   %val1 = load volatile i32, i32* %b
   %tst1 = icmp eq i32 %a, %val1