diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -57,7 +57,6 @@
   bool selectSExti32(SDValue N, SDValue &Val);
   bool selectZExti32(SDValue N, SDValue &Val);
 
-  bool MatchSRLIW(SDNode *N) const;
   bool MatchSLLIUW(SDNode *N) const;
 
   bool selectVLOp(SDValue N, SDValue &VL);
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -1143,27 +1143,6 @@
   return false;
 }
 
-// Match (srl (and val, mask), imm) where the result would be a
-// zero-extended 32-bit integer. i.e. the mask is 0xffffffff or the result
-// is equivalent to this (SimplifyDemandedBits may have removed lower bits
-// from the mask that aren't necessary due to the right-shifting).
-bool RISCVDAGToDAGISel::MatchSRLIW(SDNode *N) const {
-  assert(N->getOpcode() == ISD::SRL);
-  assert(N->getOperand(0).getOpcode() == ISD::AND);
-  assert(isa<ConstantSDNode>(N->getOperand(1)));
-  assert(isa<ConstantSDNode>(N->getOperand(0).getOperand(1)));
-
-  // The IsRV64 predicate is checked after PatFrag predicates so we can get
-  // here even on RV32.
-  if (!Subtarget->is64Bit())
-    return false;
-
-  SDValue And = N->getOperand(0);
-  uint64_t ShAmt = N->getConstantOperandVal(1);
-  uint64_t Mask = And.getConstantOperandVal(1);
-  return (Mask | maskTrailingOnes<uint64_t>(ShAmt)) == 0xffffffff;
-}
-
 // Check that it is a SLLIUW (Shift Logical Left Immediate Unsigned i32
 // on RV64).
 // SLLIUW is the same as SLLI except for the fact that it clears the bits
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4872,16 +4872,36 @@
   // Clear all non-demanded bits initially.
   APInt ShrunkMask = Mask & DemandedBits;
 
+  // Try to make a smaller immediate by setting undemanded bits.
+
+  APInt ExpandedMask = Mask | ~DemandedBits;
+
+  auto IsLegalMask = [ShrunkMask, ExpandedMask](const APInt &Mask) -> bool {
+    return ShrunkMask.isSubsetOf(Mask) && Mask.isSubsetOf(ExpandedMask);
+  };
+  auto UseMask = [Mask, Op, VT, &TLO](const APInt &NewMask) -> bool {
+    if (NewMask == Mask)
+      return true;
+    SDLoc DL(Op);
+    SDValue NewC = TLO.DAG.getConstant(NewMask, DL, VT);
+    SDValue NewOp = TLO.DAG.getNode(ISD::AND, DL, VT, Op.getOperand(0), NewC);
+    return TLO.CombineTo(Op, NewOp);
+  };
+
   // If the shrunk mask fits in sign extended 12 bits, let the target
   // independent code apply it.
   if (ShrunkMask.isSignedIntN(12))
     return false;
 
-  // Try to make a smaller immediate by setting undemanded bits.
+  // Try to preserve (and X, 0xffffffff), the (zext_inreg X, i32) pattern.
+  if (VT == MVT::i64) {
+    APInt NewMask = APInt(64, 0xffffffff);
+    if (IsLegalMask(NewMask))
+      return UseMask(NewMask);
+  }
 
-  // We need to be able to make a negative number through a combination of mask
-  // and undemanded bits.
-  APInt ExpandedMask = Mask | ~DemandedBits;
+  // For the remaining optimizations, we need to be able to make a negative
+  // number through a combination of mask and undemanded bits.
   if (!ExpandedMask.isNegative())
     return false;
 
@@ -4899,18 +4919,8 @@
     return false;
 
   // Sanity check that our new mask is a subset of the demanded mask.
-  assert(NewMask.isSubsetOf(ExpandedMask));
-
-  // If we aren't changing the mask, just return true to keep it and prevent
-  // the caller from optimizing.
-  if (NewMask == Mask)
-    return true;
-
-  // Replace the constant with the new mask.
-  SDLoc DL(Op);
-  SDValue NewC = TLO.DAG.getConstant(NewMask, DL, VT);
-  SDValue NewOp = TLO.DAG.getNode(ISD::AND, DL, VT, Op.getOperand(0), NewC);
-  return TLO.CombineTo(Op, NewOp);
+  assert(IsLegalMask(NewMask));
+  return UseMask(NewMask);
 }
 
 void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -847,11 +847,6 @@
 }]>;
 def zexti32 : ComplexPattern<i64, 1, "selectZExti32">;
 
-def SRLIWPat : PatFrag<(ops node:$A, node:$B),
-                       (srl (and node:$A, imm), node:$B), [{
-  return MatchSRLIW(N);
-}]>;
-
 // Check that it is a SLLIUW (Shift Logical Left Immediate Unsigned i32
 // on RV64). Also used to optimize the same sequence without SLLIUW.
 def SLLIUWPat : PatFrag<(ops node:$A, node:$B),
@@ -1164,7 +1159,7 @@
           (SUBW GPR:$rs1, GPR:$rs2)>;
 def : Pat<(sext_inreg (shl GPR:$rs1, uimm5:$shamt), i32),
           (SLLIW GPR:$rs1, uimm5:$shamt)>;
-def : Pat<(i64 (SRLIWPat GPR:$rs1, uimm5:$shamt)),
+def : Pat<(i64 (srl (and GPR:$rs1, 0xffffffff), uimm5:$shamt)),
           (SRLIW GPR:$rs1, uimm5:$shamt)>;
 def : Pat<(i64 (srl (shl GPR:$rs1, (i64 32)), uimm6gt32:$shamt)),
           (SRLIW GPR:$rs1, (ImmSub32 uimm6gt32:$shamt))>;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoB.td b/llvm/lib/Target/RISCV/RISCVInstrInfoB.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoB.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoB.td
@@ -871,6 +871,6 @@
                            i32)),
           (PACKW GPR:$rs1, GPR:$rs2)>;
 def : Pat<(i64 (or (and (assertsexti32 GPR:$rs2), 0xFFFFFFFFFFFF0000),
-                   (SRLIWPat GPR:$rs1, (i64 16)))),
+                   (srl (and GPR:$rs1, 0xFFFFFFFF), (i64 16)))),
           (PACKUW GPR:$rs1, GPR:$rs2)>;
 } // Predicates = [HasStdExtZbp, IsRV64]
diff --git a/llvm/test/CodeGen/RISCV/alu32.ll b/llvm/test/CodeGen/RISCV/alu32.ll
--- a/llvm/test/CodeGen/RISCV/alu32.ll
+++ b/llvm/test/CodeGen/RISCV/alu32.ll
@@ -129,8 +129,8 @@
   ret i32 %1
 }
 
-; FIXME: This should use srliw on RV64, but SimplifyDemandedBits breaks the
-; (and X, 0xffffffff) that type legalization inserts.
+; This makes sure SimplifyDemandedBits doesn't prevent us from matching SRLIW
+; on RV64.
 define i32 @srli_demandedbits(i32 %0) {
 ; RV32I-LABEL: srli_demandedbits:
 ; RV32I:       # %bb.0:
@@ -140,11 +140,7 @@
 ;
 ; RV64I-LABEL: srli_demandedbits:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    addi a1, zero, 1
-; RV64I-NEXT:    slli a1, a1, 32
-; RV64I-NEXT:    addi a1, a1, -16
-; RV64I-NEXT:    and a0, a0, a1
-; RV64I-NEXT:    srli a0, a0, 3
+; RV64I-NEXT:    srliw a0, a0, 3
 ; RV64I-NEXT:    ori a0, a0, 1
 ; RV64I-NEXT:    ret
   %2 = lshr i32 %0, 3
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -126,34 +126,26 @@
   ret i64 %and
 }
 
-; FIXME: This can use zext.w, but we need targetShrinkDemandedConstant to
-; to adjust the immediate.
+; This makes sure targetShrinkDemandedConstant changes the and immediate to
+; allow zext.w or slli+srli.
 define i64 @zextw_demandedbits_i64(i64 %0) {
 ; RV64I-LABEL: zextw_demandedbits_i64:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    addi a1, zero, 1
-; RV64I-NEXT:    slli a1, a1, 32
-; RV64I-NEXT:    addi a1, a1, -2
-; RV64I-NEXT:    and a0, a0, a1
 ; RV64I-NEXT:    ori a0, a0, 1
+; RV64I-NEXT:    slli a0, a0, 32
+; RV64I-NEXT:    srli a0, a0, 32
 ; RV64I-NEXT:    ret
 ;
 ; RV64IB-LABEL: zextw_demandedbits_i64:
 ; RV64IB:       # %bb.0:
-; RV64IB-NEXT:    addi a1, zero, 1
-; RV64IB-NEXT:    slli a1, a1, 32
-; RV64IB-NEXT:    addi a1, a1, -2
-; RV64IB-NEXT:    and a0, a0, a1
 ; RV64IB-NEXT:    ori a0, a0, 1
+; RV64IB-NEXT:    zext.w a0, a0
 ; RV64IB-NEXT:    ret
 ;
 ; RV64IBA-LABEL: zextw_demandedbits_i64:
 ; RV64IBA:       # %bb.0:
-; RV64IBA-NEXT:    addi a1, zero, 1
-; RV64IBA-NEXT:    slli a1, a1, 32
-; RV64IBA-NEXT:    addi a1, a1, -2
-; RV64IBA-NEXT:    and a0, a0, a1
 ; RV64IBA-NEXT:    ori a0, a0, 1
+; RV64IBA-NEXT:    zext.w a0, a0
 ; RV64IBA-NEXT:    ret
   %2 = and i64 %0, 4294967294
   %3 = or i64 %2, 1
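
A minimal standalone sketch of the mask-legality rule the new IsLegalMask/UseMask lambdas in targetShrinkDemandedConstant implement: plain uint64_t stands in for llvm::APInt, the helper name isLegalMask is local to this sketch, and the Mask/DemandedBits values are assumptions chosen to mirror the srli_demandedbits (alu32.ll) and zextw_demandedbits_i64 (rv64zba.ll) tests updated above.

#include <cstdint>
#include <cstdio>

// Same rule as the IsLegalMask lambda above: a replacement mask is legal if it
// keeps every demanded bit of the original mask (ShrunkMask) and sets nothing
// outside Mask | ~DemandedBits (ExpandedMask).
static bool isLegalMask(uint64_t Mask, uint64_t Demanded, uint64_t NewMask) {
  uint64_t ShrunkMask = Mask & Demanded;
  uint64_t ExpandedMask = Mask | ~Demanded;
  return (ShrunkMask & ~NewMask) == 0 && (NewMask & ~ExpandedMask) == 0;
}

int main() {
  const uint64_t ZExt32 = 0xffffffffULL;

  // srli_demandedbits: type legalization inserts (and X, 0xffffffff) before
  // the srl by 3, and the trailing ori covers bit 0 of the result, so only
  // bits 34:4 of the AND are demanded. 0xffffffff stays legal -> SRLIW.
  uint64_t SrlDemanded = 0xfffffffeULL << 3;
  std::printf("srliw case keeps 0xffffffff: %d\n",
              isLegalMask(ZExt32, SrlDemanded, ZExt32));

  // zextw_demandedbits_i64: the and mask is 0xfffffffe and the ori makes
  // bit 0 undemanded, so widening to 0xffffffff is legal -> zext.w or
  // slli+srli.
  std::printf("zext.w case widens to 0xffffffff: %d\n",
              isLegalMask(0xfffffffeULL, ~1ULL, ZExt32));
  return 0;
}

Under these assumed inputs both checks print 1: the full 0xffffffff mask is a legal replacement in each case, which is why the srl can match SRLIW directly and why the and/or pair can become ori plus zext.w (or slli+srli without Zba), without needing the removed SRLIWPat PatFrag.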