diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -7277,6 +7277,7 @@
 // Otherwise if matching a general funnel shift, it should be clear.
 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
                            SelectionDAG &DAG, bool IsRotate) {
+  const auto &TLI = DAG.getTargetLoweringInfo();
   // If EltSize is a power of 2 then:
   //
   //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
@@ -7308,19 +7309,19 @@
   // always invokes undefined behavior for 32-bit X.
   //
   // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
+  // This allows us to peek through any operations that only affect Mask's
+  // un-demanded bits.
   //
-  // NOTE: We can only do this when matching an AND and not a general
-  // funnel shift.
+  // NOTE: We can only do this when matching operations that won't modify the
+  // Log2(EltSize) least significant bits, and not for a general funnel shift.
   unsigned MaskLoBits = 0;
-  if (IsRotate && Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) {
-    if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) {
-      KnownBits Known = DAG.computeKnownBits(Neg.getOperand(0));
-      unsigned Bits = Log2_64(EltSize);
-      if (NegC->getAPIntValue().getActiveBits() <= Bits &&
-          ((NegC->getAPIntValue() | Known.Zero).countTrailingOnes() >= Bits)) {
-        Neg = Neg.getOperand(0);
-        MaskLoBits = Bits;
-      }
+  if (IsRotate && isPowerOf2_64(EltSize)) {
+    unsigned Bits = Log2_64(EltSize);
+    APInt DemandedBits = APInt::getLowBitsSet(EltSize, Bits);
+    if (SDValue Inner =
+            TLI.SimplifyMultipleUseDemandedBits(Neg, DemandedBits, DAG)) {
+      Neg = Inner;
+      MaskLoBits = Bits;
     }
   }
 
@@ -7332,15 +7333,14 @@
     return false;
   SDValue NegOp1 = Neg.getOperand(1);
 
-  // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with
-  // Pos'. The truncation is redundant for the purpose of the equality.
-  if (MaskLoBits && Pos.getOpcode() == ISD::AND) {
-    if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) {
-      KnownBits Known = DAG.computeKnownBits(Pos.getOperand(0));
-      if (PosC->getAPIntValue().getActiveBits() <= MaskLoBits &&
-          ((PosC->getAPIntValue() | Known.Zero).countTrailingOnes() >=
-           MaskLoBits))
-        Pos = Pos.getOperand(0);
+  // On the RHS of [A], if Pos is the result of an operation on Pos' that
+  // won't affect Mask's demanded bits, just replace Pos with Pos'. These
+  // operations are redundant for the purpose of the equality.
+  if (MaskLoBits) {
+    APInt DemandedBits = APInt::getLowBitsSet(EltSize, MaskLoBits);
+    if (SDValue Inner =
+            TLI.SimplifyMultipleUseDemandedBits(Pos, DemandedBits, DAG)) {
+      Pos = Inner;
     }
   }
 
diff --git a/llvm/test/CodeGen/RISCV/rotl-rotr.ll b/llvm/test/CodeGen/RISCV/rotl-rotr.ll
--- a/llvm/test/CodeGen/RISCV/rotl-rotr.ll
+++ b/llvm/test/CodeGen/RISCV/rotl-rotr.ll
@@ -341,18 +341,12 @@
 ;
 ; RV32ZBB-LABEL: rotl_32_mask_and_63_and_31:
 ; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    sll a2, a0, a1
-; RV32ZBB-NEXT:    neg a1, a1
-; RV32ZBB-NEXT:    srl a0, a0, a1
-; RV32ZBB-NEXT:    or a0, a2, a0
+; RV32ZBB-NEXT:    rol a0, a0, a1
 ; RV32ZBB-NEXT:    ret
 ;
 ; RV64ZBB-LABEL: rotl_32_mask_and_63_and_31:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    sllw a2, a0, a1
-; RV64ZBB-NEXT:    negw a1, a1
-; RV64ZBB-NEXT:    srlw a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rolw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = and i32 %y, 63
   %b = shl i32 %x, %a
@@ -384,20 +378,12 @@
 ;
 ; RV32ZBB-LABEL: rotl_32_mask_or_64_or_32:
 ; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    ori a2, a1, 64
-; RV32ZBB-NEXT:    sll a2, a0, a2
-; RV32ZBB-NEXT:    neg a1, a1
-; RV32ZBB-NEXT:    ori a1, a1, 32
-; RV32ZBB-NEXT:    srl a0, a0, a1
-; RV32ZBB-NEXT:    or a0, a2, a0
+; RV32ZBB-NEXT:    rol a0, a0, a1
 ; RV32ZBB-NEXT:    ret
 ;
 ; RV64ZBB-LABEL: rotl_32_mask_or_64_or_32:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    sllw a2, a0, a1
-; RV64ZBB-NEXT:    negw a1, a1
-; RV64ZBB-NEXT:    srlw a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rolw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = or i32 %y, 64
   %b = shl i32 %x, %a
@@ -461,18 +447,12 @@
 ;
 ; RV32ZBB-LABEL: rotr_32_mask_and_63_and_31:
 ; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    srl a2, a0, a1
-; RV32ZBB-NEXT:    neg a1, a1
-; RV32ZBB-NEXT:    sll a0, a0, a1
-; RV32ZBB-NEXT:    or a0, a2, a0
+; RV32ZBB-NEXT:    ror a0, a0, a1
 ; RV32ZBB-NEXT:    ret
 ;
 ; RV64ZBB-LABEL: rotr_32_mask_and_63_and_31:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    srlw a2, a0, a1
-; RV64ZBB-NEXT:    negw a1, a1
-; RV64ZBB-NEXT:    sllw a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rorw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = and i32 %y, 63
   %b = lshr i32 %x, %a
@@ -504,20 +484,12 @@
 ;
 ; RV32ZBB-LABEL: rotr_32_mask_or_64_or_32:
 ; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    ori a2, a1, 64
-; RV32ZBB-NEXT:    srl a2, a0, a2
-; RV32ZBB-NEXT:    neg a1, a1
-; RV32ZBB-NEXT:    ori a1, a1, 32
-; RV32ZBB-NEXT:    sll a0, a0, a1
-; RV32ZBB-NEXT:    or a0, a2, a0
+; RV32ZBB-NEXT:    ror a0, a0, a1
 ; RV32ZBB-NEXT:    ret
 ;
 ; RV64ZBB-LABEL: rotr_32_mask_or_64_or_32:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    srlw a2, a0, a1
-; RV64ZBB-NEXT:    negw a1, a1
-; RV64ZBB-NEXT:    sllw a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rorw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = or i32 %y, 64
   %b = lshr i32 %x, %a
@@ -718,10 +690,7 @@
 ;
 ; RV64ZBB-LABEL: rotl_64_mask_and_127_and_63:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    sll a2, a0, a1
-; RV64ZBB-NEXT:    neg a1, a1
-; RV64ZBB-NEXT:    srl a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rol a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = and i64 %y, 127
   %b = shl i64 %x, %a
@@ -761,12 +730,7 @@
 ;
 ; RV64ZBB-LABEL: rotl_64_mask_or_128_or_64:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    ori a2, a1, 128
-; RV64ZBB-NEXT:    sll a2, a0, a2
-; RV64ZBB-NEXT:    neg a1, a1
-; RV64ZBB-NEXT:    ori a1, a1, 64
-; RV64ZBB-NEXT:    srl a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rol a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = or i64 %y, 128
   %b = shl i64 %x, %a
@@ -967,10 +931,7 @@
 ;
 ; RV64ZBB-LABEL: rotr_64_mask_and_127_and_63:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    srl a2, a0, a1
-; RV64ZBB-NEXT:    neg a1, a1
-; RV64ZBB-NEXT:    sll a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    ror a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = and i64 %y, 127
   %b = lshr i64 %x, %a
@@ -1010,12 +971,7 @@
 ;
 ; RV64ZBB-LABEL: rotr_64_mask_or_128_or_64:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    ori a2, a1, 128
-; RV64ZBB-NEXT:    srl a2, a0, a2
-; RV64ZBB-NEXT:    neg a1, a1
-; RV64ZBB-NEXT:    ori a1, a1, 64
-; RV64ZBB-NEXT:    sll a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    ror a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = or i64 %y, 128
   %b = lshr i64 %x, %a
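Note (not part of the patch): below is a minimal LLVM IR sketch of the rotate-by-subtraction pattern that MatchRotate/matchRotateSub recognizes; the function and value names are illustrative only. Both shift amounts are wrapped in AND masks that only clear bits outside the Log2(32) = 5 demanded bits, so the masks are redundant for the equality Neg == EltSize - Pos. The old code could only strip such constant AND masks; with the SimplifyMultipleUseDemandedBits calls above, any operation that leaves the low 5 bits untouched is looked through as well, which is why amounts like the "or %y, 64" / "or with 32" in rotl_32_mask_or_64_or_32 now also fold. On a target with native rotates (e.g. RV32ZBB) this is expected to lower to a single rol, matching the checks above.

; Sketch only: a 32-bit rotate-left by %y written as two opposite shifts,
; in the canonical form matchRotateSub looks for.
define i32 @rotl_by_sub_sketch(i32 %x, i32 %y) {
  ; Pos = %y masked to its low 5 bits (the only bits a 32-bit shift demands).
  %pos = and i32 %y, 31
  %shl = shl i32 %x, %pos
  ; Neg = (0 - %y) masked the same way; modulo 32 this equals 32 - %y,
  ; i.e. EltSize - Pos, which is what the combine checks for.
  %neg = sub i32 0, %y
  %negmask = and i32 %neg, 31
  %shr = lshr i32 %x, %negmask
  ; OR of the two opposite shifts becomes a rotate once the masks are stripped.
  %rot = or i32 %shl, %shr
  ret i32 %rot
}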