diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -7277,6 +7277,7 @@
 // Otherwise if matching a general funnel shift, it should be clear.
 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
                            SelectionDAG &DAG, bool IsRotate) {
+  const auto &TLI = DAG.getTargetLoweringInfo();
   // If EltSize is a power of 2 then:
   //
   //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
@@ -7312,15 +7313,13 @@
   // NOTE: We can only do this when matching an AND and not a general
   // funnel shift.
   unsigned MaskLoBits = 0;
-  if (IsRotate && Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) {
-    if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) {
-      KnownBits Known = DAG.computeKnownBits(Neg.getOperand(0));
-      unsigned Bits = Log2_64(EltSize);
-      if (NegC->getAPIntValue().getActiveBits() <= Bits &&
-          ((NegC->getAPIntValue() | Known.Zero).countTrailingOnes() >= Bits)) {
-        Neg = Neg.getOperand(0);
-        MaskLoBits = Bits;
-      }
+  if (IsRotate && isPowerOf2_64(EltSize)) {
+    unsigned Bits = Log2_64(EltSize);
+    APInt DemandedBits = APInt::getLowBitsSet(EltSize, Bits);
+    if (SDValue Inner =
+            TLI.SimplifyMultipleUseDemandedBits(Neg, DemandedBits, DAG)) {
+      Neg = Inner;
+      MaskLoBits = Bits;
     }
   }
@@ -7334,13 +7333,11 @@
   // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with
   // Pos'. The truncation is redundant for the purpose of the equality.
-  if (MaskLoBits && Pos.getOpcode() == ISD::AND) {
-    if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) {
-      KnownBits Known = DAG.computeKnownBits(Pos.getOperand(0));
-      if (PosC->getAPIntValue().getActiveBits() <= MaskLoBits &&
-          ((PosC->getAPIntValue() | Known.Zero).countTrailingOnes() >=
-           MaskLoBits))
-        Pos = Pos.getOperand(0);
+  if (MaskLoBits) {
+    APInt DemandedBits = APInt::getLowBitsSet(EltSize, MaskLoBits);
+    if (SDValue Inner =
+            TLI.SimplifyMultipleUseDemandedBits(Pos, DemandedBits, DAG)) {
+      Pos = Inner;
     }
   }
diff --git a/llvm/test/CodeGen/RISCV/rotl-rotr.ll b/llvm/test/CodeGen/RISCV/rotl-rotr.ll
--- a/llvm/test/CodeGen/RISCV/rotl-rotr.ll
+++ b/llvm/test/CodeGen/RISCV/rotl-rotr.ll
@@ -341,18 +341,12 @@
 ;
 ; RV32ZBB-LABEL: rotl_32_mask_1:
 ; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    sll a2, a0, a1
-; RV32ZBB-NEXT:    neg a1, a1
-; RV32ZBB-NEXT:    srl a0, a0, a1
-; RV32ZBB-NEXT:    or a0, a2, a0
+; RV32ZBB-NEXT:    rol a0, a0, a1
 ; RV32ZBB-NEXT:    ret
 ;
 ; RV64ZBB-LABEL: rotl_32_mask_1:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    sllw a2, a0, a1
-; RV64ZBB-NEXT:    negw a1, a1
-; RV64ZBB-NEXT:    srlw a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rolw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = and i32 %y, 63
   %b = shl i32 %x, %a
@@ -416,18 +410,12 @@
 ;
 ; RV32ZBB-LABEL: rotr_32_mask_1:
 ; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    srl a2, a0, a1
-; RV32ZBB-NEXT:    neg a1, a1
-; RV32ZBB-NEXT:    sll a0, a0, a1
-; RV32ZBB-NEXT:    or a0, a2, a0
+; RV32ZBB-NEXT:    ror a0, a0, a1
 ; RV32ZBB-NEXT:    ret
 ;
 ; RV64ZBB-LABEL: rotr_32_mask_1:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    srlw a2, a0, a1
-; RV64ZBB-NEXT:    negw a1, a1
-; RV64ZBB-NEXT:    sllw a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rorw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = and i32 %y, 63
   %b = lshr i32 %x, %a
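
Note (not part of the patch): the test hunks above are truncated after the first two body
instructions, so for context here is a minimal IR sketch of the kind of function these checks
exercise; everything past the `shl` is an assumed completion, not the verbatim test body. A
32-bit rotate only demands the low 5 bits of its amount, but the mask 63 has 6 active bits, so
the old matcher, which required a constant mask with getActiveBits() <= Log2_64(EltSize),
rejected it. SimplifyMultipleUseDemandedBits instead asks only for the 5 demanded bits, sees
that the `and` with 63 leaves all of them untouched, and looks through it on both the Pos and
Neg sides, so the shl/lshr/or idiom matches ISD::ROTL and Zbb emits rol/rolw (and ror/rorw for
the mirrored rotr case).

  define i32 @rotl_32_mask_1(i32 %x, i32 %y) nounwind {
    %a = and i32 %y, 63   ; from the test: mask wider than the 5 demanded amount bits
    %b = shl i32 %x, %a   ; from the test
    %c = sub i32 0, %y    ; assumed completion from here down
    %d = and i32 %c, 63
    %e = lshr i32 %x, %d
    %f = or i32 %b, %e    ; shl/lshr/or rotate idiom, now folded to rol
    ret i32 %f
  }

Under this shape, Pos = (and %y, 63) simplifies to %y and Neg = (and (sub 0, %y), 63)
simplifies to (sub 0, %y), which is exactly the Neg == 0 - Pos equality matchRotateSub checks,
consistent with the pre-change sll/neg/srl/or output in the old CHECK lines.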