[RISCV] Select SRAIW for (srl (sext_inreg X, i32), C) when only the lower
32 bits of the result are used.

SimplifyDemandedBits prefers srl over sra, so an i32 ashr can reach isel as
an i64 srl of a sext_inreg'd value. When every user only reads the lower 32
bits, that srl can be selected as a single sraiw, removing the sext.w.

To support this, the overflowingbinopw PatFrag is renamed to binop_allwusers
(it now also wraps srl, which is not an overflowing operation),
hasAllNBitUsers accepts ISD::SRL nodes, and the PatFrag's description is
updated to cover the new use.

diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -1536,6 +1536,7 @@
 bool RISCVDAGToDAGISel::hasAllNBitUsers(SDNode *Node, unsigned Bits) const {
   assert((Node->getOpcode() == ISD::ADD || Node->getOpcode() == ISD::SUB ||
           Node->getOpcode() == ISD::MUL || Node->getOpcode() == ISD::SHL ||
+          Node->getOpcode() == ISD::SRL ||
           Node->getOpcode() == ISD::SIGN_EXTEND_INREG ||
           isa<ConstantSDNode>(Node)) &&
          "Unexpected opcode");
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -1256,7 +1256,7 @@
 
-// PatFrag to allow ADDW/SUBW/MULW/SLLW to be selected from i64 add/sub/mul/shl
-// if only the lower 32 bits of their result is used.
-class overflowingbinopw<SDPatternOperator operator>
+// PatFrag to allow W instructions such as ADDW/SUBW/MULW/SLLIW/SRAIW to be
+// selected from i64 ops if only the lower 32 bits of their result are used.
+class binop_allwusers<SDPatternOperator operator>
     : PatFrag<(ops node:$lhs, node:$rhs),
               (operator node:$lhs, node:$rhs), [{
   return hasAllWUsers(Node);
@@ -1291,12 +1291,17 @@
 def : PatGprGpr<shiftopw<riscv_srlw>, SRLW>;
 def : PatGprGpr<shiftopw<riscv_sraw>, SRAW>;
 
-// Select W instructions without sext_inreg if only the lower 32 bits of the
-// result are used.
-def : PatGprGpr<overflowingbinopw<add>, ADDW>;
-def : PatGprSimm12<overflowingbinopw<add>, ADDIW>;
-def : PatGprGpr<overflowingbinopw<sub>, SUBW>;
-def : PatGprImm<overflowingbinopw<shl>, SLLIW, uimm5>;
+// Select W instructions if only the lower 32 bits of the result are used.
+def : PatGprGpr<binop_allwusers<add>, ADDW>;
+def : PatGprSimm12<binop_allwusers<add>, ADDIW>;
+def : PatGprGpr<binop_allwusers<sub>, SUBW>;
+def : PatGprImm<binop_allwusers<shl>, SLLIW, uimm5>;
+
+// If this is a shr of a value sign extended from i32, and all the users only
+// use the lower 32 bits, we can use an sraiw to remove the sext_inreg. This
+// occurs because SimplifyDemandedBits prefers srl over sra.
+def : Pat<(binop_allwusers<srl> (sext_inreg GPR:$rs1, i32), uimm5:$shamt),
+          (SRAIW GPR:$rs1, uimm5:$shamt)>;
 
 /// Loads
 
@@ -1339,7 +1344,7 @@
 
 let Predicates = [IsRV64] in {
 // Select W instructions if only the lower 32-bits of the result are used.
-def : Pat<(overflowingbinopw<add> GPR:$rs1, (AddiPair:$rs2)),
+def : Pat<(binop_allwusers<add> GPR:$rs1, (AddiPair:$rs2)),
           (ADDIW (ADDIW GPR:$rs1, (AddiPairImmB AddiPair:$rs2)),
                  (AddiPairImmA AddiPair:$rs2))>;
 }
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
@@ -72,9 +72,8 @@
 } // Predicates = [HasStdExtM]
 
 let Predicates = [HasStdExtM, IsRV64] in {
-// Select W instructions without sext_inreg if only the lower 32-bits of the
-// result are used.
-def : PatGprGpr<overflowingbinopw<mul>, MULW>;
+// Select W instructions if only the lower 32-bits of the result are used.
+def : PatGprGpr<binop_allwusers<mul>, MULW>;
 
 def : PatGprGpr<riscv_divw, DIVW>;
 def : PatGprGpr<riscv_divuw, DIVUW>;
diff --git a/llvm/test/CodeGen/RISCV/rv64i-exhaustive-w-insts.ll b/llvm/test/CodeGen/RISCV/rv64i-exhaustive-w-insts.ll
--- a/llvm/test/CodeGen/RISCV/rv64i-exhaustive-w-insts.ll
+++ b/llvm/test/CodeGen/RISCV/rv64i-exhaustive-w-insts.ll
@@ -1964,8 +1964,7 @@ ;
 ;
 ; RV64ZBA-LABEL: zext_sraiw_aext:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sext.w a0, a0
-; RV64ZBA-NEXT:    srli a0, a0, 7
+; RV64ZBA-NEXT:    sraiw a0, a0, 7
 ; RV64ZBA-NEXT:    zext.w a0, a0
 ; RV64ZBA-NEXT:    ret
   %1 = ashr i32 %a, 7
@@ -1999,8 +1998,7 @@ ;
 ;
 ; RV64ZBA-LABEL: zext_sraiw_zext:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    sext.w a0, a0
-; RV64ZBA-NEXT:    srli a0, a0, 9
+; RV64ZBA-NEXT:    sraiw a0, a0, 9
 ; RV64ZBA-NEXT:    zext.w a0, a0
 ; RV64ZBA-NEXT:    ret
   %1 = ashr i32 %a, 9
diff --git a/llvm/test/CodeGen/RISCV/srem-lkk.ll b/llvm/test/CodeGen/RISCV/srem-lkk.ll
--- a/llvm/test/CodeGen/RISCV/srem-lkk.ll
+++ b/llvm/test/CodeGen/RISCV/srem-lkk.ll
@@ -53,7 +53,7 @@
 ; RV64IM-NEXT:    srli a1, a1, 32
 ; RV64IM-NEXT:    addw a1, a1, a0
 ; RV64IM-NEXT:    srliw a2, a1, 31
-; RV64IM-NEXT:    srli a1, a1, 6
+; RV64IM-NEXT:    sraiw a1, a1, 6
 ; RV64IM-NEXT:    addw a1, a1, a2
 ; RV64IM-NEXT:    addi a2, zero, 95
 ; RV64IM-NEXT:    mulw a1, a1, a2