Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -480,9 +480,6 @@
     /// returns false.
     bool findBetterNeighborChains(StoreSDNode *St);
 
-    /// Match "(X shl/srl V1) & V2" where V2 may not be present.
-    bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask);
-
     /// Holds a pointer to an LSBaseSDNode as well as information on where it
     /// is located in a sequence of memory operations connected by a chain.
     struct MemOpLink {
@@ -4804,16 +4801,19 @@
   return SDValue();
 }
 
-/// Match "(X shl/srl V1) & V2" where V2 may not be present.
-bool DAGCombiner::MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) {
-  if (Op.getOpcode() == ISD::AND) {
-    if (DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
-      Mask = Op.getOperand(1);
-      Op = Op.getOperand(0);
-    } else {
-      return false;
-    }
+static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) {
+  if (Op.getOpcode() == ISD::AND &&
+      DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
+    Mask = Op.getOperand(1);
+    return Op.getOperand(0);
   }
+  return Op;
+}
+
+/// Match "(X shl/srl V1) & V2" where V2 may not be present.
+static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift,
+                            SDValue &Mask) {
+  Op = stripConstantMask(DAG, Op, Mask);
 
   if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
     Shift = Op;
@@ -4823,6 +4823,110 @@
   return false;
 }
 
+/// Helper function for visitOR to extract one side of a rotate idiom
+/// from a mul or udiv operation. This is meant to handle cases where
+/// InstCombine merged some outside mul or udiv with one of the shifts from
+/// the rotate pattern.
+///
+/// Attempts to expand:
+///     (or (srl (mul v c0) c1) (mul v c2)) ->
+///     (or (srl (mul v c0) c1) (shl (mul v c0) c3))
+///   and
+///     (or (udiv v c0) (shl (udiv v c1) c2)) ->
+///     (or (srl (udiv v c1) c3) (shl (udiv v c1) c2))
+///
+/// \p OppShift is the rotate half that was already matched (a SHL or SRL).
+/// \p MulOrDiv is the opposite operand of the OR; if it is an AND with a
+/// constant, the AND is looked through and its constant is written to
+/// \p Mask (even when the extraction subsequently fails).
+/// Returns the new shift node, or an empty SDValue if \p MulOrDiv cannot be
+/// rewritten as the missing half of the rotate.
+static SDValue extractShiftFromMulOrUDiv(SelectionDAG &DAG, SDValue OppShift,
+                                         SDValue MulOrDiv, SDValue &Mask,
+                                         const SDLoc &DL) {
+  assert(OppShift && MulOrDiv && "Empty SDValue");
+  assert((OppShift.getOpcode() == ISD::SHL ||
+          OppShift.getOpcode() == ISD::SRL) &&
+         "Existing shift must be valid as a rotate half");
+
+  MulOrDiv = stripConstantMask(DAG, MulOrDiv, Mask);
+  SDValue ShiftLHS = OppShift.getOperand(0);
+  EVT ShiftedVT = ShiftLHS.getValueType();
+
+  // Preconditions:
+  //
+  //   or( (shiftl/r (op0 v c0) c1) (op1 v c2) )
+  //
+  // op1 is a mul or udiv and is the same as op0
+  // op1 and op0 have v as the LHS input and produce the same value type
+  if ((ShiftLHS.getOpcode() != ISD::MUL && ShiftLHS.getOpcode() != ISD::UDIV) ||
+      ShiftLHS.getOpcode() != MulOrDiv.getOpcode() ||
+      ShiftLHS.getOperand(0) != MulOrDiv.getOperand(0) ||
+      ShiftedVT != MulOrDiv.getValueType())
+    return SDValue();
+
+  unsigned Opcode = 0;
+  // If both are muls and the opposing shift is an srl, we can extract a shl.
+  if (MulOrDiv.getOpcode() == ISD::MUL && OppShift.getOpcode() == ISD::SRL)
+    Opcode = ISD::SHL;
+  // If both are udivs and the opposing shift is a shl, we can extract an srl.
+  else if (MulOrDiv.getOpcode() == ISD::UDIV &&
+           OppShift.getOpcode() == ISD::SHL)
+    Opcode = ISD::SRL;
+  // Can't extract.
+  else
+    return SDValue();
+
+  // Amount of the existing shift.
+  ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
+  // Constant mul/div amount from the RHS of the shift's LHS op.
+  ConstantSDNode *OppMulOrDivCst = isConstOrConstSplat(ShiftLHS.getOperand(1));
+  // Constant mul/div amount from the RHS of the MulOrDiv op.
+  ConstantSDNode *MulOrDivCst = isConstOrConstSplat(MulOrDiv.getOperand(1));
+  // Check that we have constant values.
+  if (!OppShiftCst || !OppMulOrDivCst || !MulOrDivCst)
+    return SDValue();
+
+  // The existing shift amount must be non-zero and in range. A shift by
+  // bitwidth or more has an undefined result, and computing VTWidth - c1 for
+  // it would wrap and trip the assertion in APInt::getOneBitSet below.
+  const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
+  const APInt &OppShiftAmt = OppShiftCst->getAPIntValue();
+  if (!OppShiftAmt || OppShiftAmt.uge(VTWidth))
+    return SDValue();
+  const unsigned NewShiftAmt = VTWidth - OppShiftAmt.getZExtValue();
+
+  // Splat constants taken from a BUILD_VECTOR may be wider than the element
+  // type (BUILD_VECTOR operands are implicitly truncated), so normalize both
+  // mul/div constants to the element width before comparing or dividing.
+  const APInt MulOrDivAmt = MulOrDivCst->getAPIntValue().zextOrTrunc(VTWidth);
+  const APInt OppMulOrDivAmt =
+      OppMulOrDivCst->getAPIntValue().zextOrTrunc(VTWidth);
+  if (!MulOrDivAmt || !OppMulOrDivAmt)
+    return SDValue();
+
+  // Check:
+  //     c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
+  //     c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
+  const APInt ExtractDiv = APInt::getOneBitSet(VTWidth, NewShiftAmt);
+  APInt NewMulOrDivAmt;
+  APInt Rem;
+  APInt::udivrem(MulOrDivAmt, ExtractDiv, NewMulOrDivAmt, Rem);
+  if (Rem != 0 || NewMulOrDivAmt != OppMulOrDivAmt)
+    return SDValue();
+
+  // Expand:
+  //
+  //   or( (shiftl/r (op0 v c0) c1) (op0 v c2) ) ->
+  //   or( (shiftl/r (op0 v c0) c1) (shiftr/l (op0 v c0) c3) )
+  //
+  // such that c1 + c3 == bitwidth(op0 v c0) and a rotate can be formed.
+  EVT ShiftVT = OppShift.getOperand(1).getValueType();
+  EVT ResVT = MulOrDiv.getValueType();
+  SDValue NewShiftNode = DAG.getConstant(NewShiftAmt, DL, ShiftVT);
+  return DAG.getNode(Opcode, DL, ResVT, ShiftLHS, NewShiftNode);
+}
+
 // Return true if we can prove that, whenever Neg and Pos are both in the
 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
@@ -4986,16 +5090,42 @@
     }
   }
 
+  bool NeedRotLHS = true;
+  bool NeedRotRHS = true;
+
   // Match "(X shl/srl V1) & V2" where V2 may not be present.
   SDValue LHSShift;   // The shift.
   SDValue LHSMask;    // AND value if any.
-  if (!MatchRotateHalf(LHS, LHSShift, LHSMask))
-    return nullptr; // Not part of a rotate.
+  if (matchRotateHalf(DAG, LHS, LHSShift, LHSMask))
+    NeedRotLHS = false;
 
   SDValue RHSShift;   // The shift.
   SDValue RHSMask;    // AND value if any.
-  if (!MatchRotateHalf(RHS, RHSShift, RHSMask))
-    return nullptr; // Not part of a rotate.
+  if (matchRotateHalf(DAG, RHS, RHSShift, RHSMask))
+    NeedRotRHS = false;
+
+  // If neither side matched a rotate half, bail
+  if (NeedRotLHS && NeedRotRHS)
+    return nullptr;
+
+  // Iff one side matched a rotate half, we may be able to extract the
+  // needed shift from the opposite side if it is a constant mul or udiv.
+  // We also need to pass the matched shift to compute the needed shift amount
+  // to extract (example: if the side with the shift is (shl v c0), and v has a
+  // bitwidth of 32, then the needed shift to extract from the opposite side is
+  // computed to be (srl v (32 - c0)))
+
+  // Have LHS side of the rotate, try to extract the needed shift from the RHS
+  if (NeedRotRHS)
+    RHSShift = extractShiftFromMulOrUDiv(DAG, LHSShift, RHS, RHSMask, DL);
+  // Have RHS side of the rotate, try to extract the needed shift from the LHS
+  if (NeedRotLHS)
+    LHSShift = extractShiftFromMulOrUDiv(DAG, RHSShift, LHS, LHSMask, DL);
+  // If a side is still missing, nothing else we can do
+  if (!RHSShift || !LHSShift)
+    return nullptr;
+
+  // At this point we've matched a shift op on each side
 
   if (LHSShift.getOperand(0) != RHSShift.getOperand(0))
     return nullptr; // Not shifting the same value.
Index: test/CodeGen/AArch64/rotate-extract.ll
===================================================================
--- test/CodeGen/AArch64/rotate-extract.ll
+++ test/CodeGen/AArch64/rotate-extract.ll
@@ -0,0 +1,93 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-unknown-unknown | FileCheck %s
+
+; Check that, under certain conditions, a rotate can be factored out
+; of the following idioms:
+;   (a*c0) >> s1 | (a*c1)   where c1 == c0 << (64 - s1)
+;   (a/c0) << s1 | (a/c1)   where c1 == c0 << (64 - s1)
+; These arise when instcombine has folded a mul or udiv into one of
+; the shifts of the plain rotate idiom.
+
+define i64 @ror_extract_mul(i64 %i) nounwind {
+; CHECK-LABEL: ror_extract_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, x0, lsl #3
+; CHECK-NEXT:    ror x0, x8, #57
+; CHECK-NEXT:    ret
+  %lhs_mul = mul i64 %i, 9
+  %rhs_mul = mul i64 %i, 1152
+  %lhs_shift = lshr i64 %lhs_mul, 57
+  %out = or i64 %lhs_shift, %rhs_mul
+  ret i64 %out
+}
+
+define i64 @ror_extract_udiv(i64 %i) nounwind {
+; CHECK-LABEL: ror_extract_udiv:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov x8, #-6148914691236517206
+; CHECK-NEXT:    movk x8, #43691
+; CHECK-NEXT:    umulh x8, x0, x8
+; CHECK-NEXT:    lsr x8, x8, #1
+; CHECK-NEXT:    ror x0, x8, #4
+; CHECK-NEXT:    ret
+  %lhs_div = udiv i64 %i, 3
+  %rhs_div = udiv i64 %i, 48
+  %lhs_shift = shl i64 %lhs_div, 60
+  %out = or i64 %lhs_shift, %rhs_div
+  ret i64 %out
+}
+
+define i64 @ror_extract_mul_with_mask(i64 %i) nounwind {
+; CHECK-LABEL: ror_extract_mul_with_mask:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, x0, lsl #3
+; CHECK-NEXT:    ror x8, x8, #57
+; CHECK-NEXT:    and x0, x8, #0xff
+; CHECK-NEXT:    ret
+  %lhs_mul = mul i64 %i, 9
+  %rhs_mul = mul i64 %i, 1152
+  %rhs_and = and i64 %rhs_mul, 160
+  %lhs_shift = lshr i64 %lhs_mul, 57
+  %out = or i64 %lhs_shift, %rhs_and
+  ret i64 %out
+}
+
+; 2304 is a multiple of 128 (2^(64-57)), but 2304/128 is 18, not 9: no rotate
+define i64 @no_extract_mul(i64 %i) nounwind {
+; CHECK-LABEL: no_extract_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, x0, lsl #3
+; CHECK-NEXT:    lsr x0, x8, #57
+; CHECK-NEXT:    bfi x0, x8, #8, #56
+; CHECK-NEXT:    ret
+  %lhs_mul = mul i64 %i, 9
+  %rhs_mul = mul i64 %i, 2304
+  %lhs_shift = lshr i64 %lhs_mul, 57
+  %out = or i64 %lhs_shift, %rhs_mul
+  ret i64 %out
+}
+
+; 49 is not a multiple of 16 (2^(64-60)), so no rotate can be extracted
+define i64 @no_extract_udiv(i64 %i) nounwind {
+; CHECK-LABEL: no_extract_udiv:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov x9, #38787
+; CHECK-NEXT:    movk x9, #61523, lsl #16
+; CHECK-NEXT:    movk x9, #2674, lsl #32
+; CHECK-NEXT:    movk x9, #20062, lsl #48
+; CHECK-NEXT:    mov x8, #-6148914691236517206
+; CHECK-NEXT:    umulh x9, x0, x9
+; CHECK-NEXT:    movk x8, #43691
+; CHECK-NEXT:    sub x10, x0, x9
+; CHECK-NEXT:    umulh x8, x0, x8
+; CHECK-NEXT:    add x9, x9, x10, lsr #1
+; CHECK-NEXT:    lsr x8, x8, #1
+; CHECK-NEXT:    lsr x0, x9, #5
+; CHECK-NEXT:    bfi x0, x8, #60, #4
+; CHECK-NEXT:    ret
+  %lhs_div = udiv i64 %i, 3
+  %rhs_div = udiv i64 %i, 49
+  %lhs_shift = shl i64 %lhs_div, 60
+  %out = or i64 %lhs_shift, %rhs_div
+  ret i64 %out
+}
Index: test/CodeGen/X86/rotate-extract.ll
===================================================================
--- test/CodeGen/X86/rotate-extract.ll
+++ test/CodeGen/X86/rotate-extract.ll
@@ -0,0 +1,98 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown | FileCheck %s
+
+; Check that, under certain conditions, a rotate can be factored out
+; of the following idioms:
+;   (a*c0) >> s1 | (a*c1)   where c1 == c0 << (64 - s1)
+;   (a/c0) << s1 | (a/c1)   where c1 == c0 << (64 - s1)
+; These arise when instcombine has folded a mul or udiv into one of
+; the shifts of the plain rotate idiom.
+
+define i64 @rolq_extract_mul(i64 %i) nounwind {
+; CHECK-LABEL: rolq_extract_mul:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    leaq (%rdi,%rdi,8), %rax
+; CHECK-NEXT:    rolq $7, %rax
+; CHECK-NEXT:    retq
+  %lhs_mul = mul i64 %i, 9
+  %rhs_mul = mul i64 %i, 1152
+  %lhs_shift = lshr i64 %lhs_mul, 57
+  %out = or i64 %lhs_shift, %rhs_mul
+  ret i64 %out
+}
+
+define i64 @rolq_extract_udiv(i64 %i) nounwind {
+; CHECK-LABEL: rolq_extract_udiv:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movabsq $-6148914691236517205, %rcx # imm = 0xAAAAAAAAAAAAAAAB
+; CHECK-NEXT:    movq %rdi, %rax
+; CHECK-NEXT:    mulq %rcx
+; CHECK-NEXT:    shrq %rdx
+; CHECK-NEXT:    rolq $60, %rdx
+; CHECK-NEXT:    movq %rdx, %rax
+; CHECK-NEXT:    retq
+  %lhs_div = udiv i64 %i, 3
+  %rhs_div = udiv i64 %i, 48
+  %lhs_shift = shl i64 %lhs_div, 60
+  %out = or i64 %lhs_shift, %rhs_div
+  ret i64 %out
+}
+
+define i64 @rolq_extract_mul_with_mask(i64 %i) nounwind {
+; CHECK-LABEL: rolq_extract_mul_with_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    leaq (%rdi,%rdi,8), %rax
+; CHECK-NEXT:    rolq $7, %rax
+; CHECK-NEXT:    movzbl %al, %eax
+; CHECK-NEXT:    retq
+  %lhs_mul = mul i64 %i, 9
+  %rhs_mul = mul i64 %i, 1152
+  %rhs_and = and i64 %rhs_mul, 160
+  %lhs_shift = lshr i64 %lhs_mul, 57
+  %out = or i64 %lhs_shift, %rhs_and
+  ret i64 %out
+}
+
+; 2304 is a multiple of 128 (2^(64-57)), but 2304/128 is 18, not 9: no rotate
+define i64 @no_extract_mul(i64 %i) nounwind {
+; CHECK-LABEL: no_extract_mul:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    leaq (%rdi,%rdi,8), %rax
+; CHECK-NEXT:    shlq $8, %rdi
+; CHECK-NEXT:    leaq (%rdi,%rdi,8), %rcx
+; CHECK-NEXT:    shrq $57, %rax
+; CHECK-NEXT:    orq %rcx, %rax
+; CHECK-NEXT:    retq
+  %lhs_mul = mul i64 %i, 9
+  %rhs_mul = mul i64 %i, 2304
+  %lhs_shift = lshr i64 %lhs_mul, 57
+  %out = or i64 %lhs_shift, %rhs_mul
+  ret i64 %out
+}
+
+; 49 is not a multiple of 16 (2^(64-60)), so no rotate can be extracted
+define i64 @no_extract_udiv(i64 %i) nounwind {
+; CHECK-LABEL: no_extract_udiv:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movq %rdi, %rcx
+; CHECK-NEXT:    movabsq $-6148914691236517205, %rdx # imm = 0xAAAAAAAAAAAAAAAB
+; CHECK-NEXT:    movq %rdi, %rax
+; CHECK-NEXT:    mulq %rdx
+; CHECK-NEXT:    movq %rdx, %rsi
+; CHECK-NEXT:    andq $-2, %rsi
+; CHECK-NEXT:    shlq $59, %rsi
+; CHECK-NEXT:    movabsq $5646962471543740291, %rdx # imm = 0x4E5E0A72F0539783
+; CHECK-NEXT:    movq %rdi, %rax
+; CHECK-NEXT:    mulq %rdx
+; CHECK-NEXT:    subq %rdx, %rcx
+; CHECK-NEXT:    shrq %rcx
+; CHECK-NEXT:    addq %rdx, %rcx
+; CHECK-NEXT:    shrq $5, %rcx
+; CHECK-NEXT:    leaq (%rcx,%rsi), %rax
+; CHECK-NEXT:    retq
+  %lhs_div = udiv i64 %i, 3
+  %rhs_div = udiv i64 %i, 49
+  %lhs_shift = shl i64 %lhs_div, 60
+  %out = or i64 %lhs_shift, %rhs_div
+  ret i64 %out
+}