Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -479,9 +479,6 @@
     /// returns false.
     bool findBetterNeighborChains(StoreSDNode *St);
 
-    /// Match "(X shl/srl V1) & V2" where V2 may not be present.
-    bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask);
-
     /// Holds a pointer to an LSBaseSDNode as well as information on where it
     /// is located in a sequence of memory operations connected by a chain.
     struct MemOpLink {
@@ -4977,16 +4974,19 @@
   return SDValue();
 }
 
-/// Match "(X shl/srl V1) & V2" where V2 may not be present.
-bool DAGCombiner::MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) {
-  if (Op.getOpcode() == ISD::AND) {
-    if (DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
-      Mask = Op.getOperand(1);
-      Op = Op.getOperand(0);
-    } else {
-      return false;
-    }
+static SDValue stripConstantMask(SelectionDAG &DAG, SDValue Op, SDValue &Mask) {
+  if (Op.getOpcode() == ISD::AND &&
+      DAG.isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
+    Mask = Op.getOperand(1);
+    return Op.getOperand(0);
   }
+  return Op;
+}
+
+/// Match "(X shl/srl V1) & V2" where V2 may not be present.
+static bool matchRotateHalf(SelectionDAG &DAG, SDValue Op, SDValue &Shift,
+                            SDValue &Mask) {
+  Op = stripConstantMask(DAG, Op, Mask);
 
   if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SHL) {
     Shift = Op;
@@ -4996,6 +4996,128 @@
   return false;
 }
 
+/// Helper function for visitOR to extract the needed side of a rotate idiom
+/// from a shl/srl/mul/udiv. This is meant to handle cases where
+/// InstCombine merged some outside op with one of the shifts from
+/// the rotate pattern.
+/// \returns An empty \c SDValue if the needed shift couldn't be extracted.
+/// Otherwise, returns an expansion of \p ExtractFrom based on the following
+/// patterns:
+///
+/// (or (mul v c0) (shrl (mul v c1) c2)):
+///     expands (mul v c0) -> (shl (mul v c1) c3)
+///
+/// (or (udiv v c0) (shl (udiv v c1) c2)):
+///     expands (udiv v c0) -> (shrl (udiv v c1) c3)
+///
+/// (or (shl v c0) (shrl (shl v c1) c2)):
+///     expands (shl v c0) -> (shl (shl v c1) c3)
+///
+/// (or (shrl v c0) (shl (shrl v c1) c2)):
+///     expands (shrl v c0) -> (shrl (shrl v c1) c3)
+///
+/// Such that in all cases, c3+c2==bitwidth(op v c1).
+static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
+                                     SDValue ExtractFrom, SDValue &Mask,
+                                     const SDLoc &DL) {
+  assert(OppShift && ExtractFrom && "Empty SDValue");
+  assert(
+      (OppShift.getOpcode() == ISD::SHL || OppShift.getOpcode() == ISD::SRL) &&
+      "Existing shift must be valid as a rotate half");
+
+  // Preconditions:
+  //
+  //    (or (op0 v c0) (shiftl/r (op0 v c1) c2))
+  //
+  // op0 must be the same opcode on both sides and the same value type.
+  ExtractFrom = stripConstantMask(DAG, ExtractFrom, Mask);
+  SDValue OppShiftLHS = OppShift.getOperand(0);
+  EVT ShiftedVT = OppShiftLHS.getValueType();
+  if (OppShiftLHS.getOpcode() != ExtractFrom.getOpcode() ||
+      ShiftedVT != ExtractFrom.getValueType())
+    return SDValue();
+
+  // Find the opcode of the needed shift to be extracted from (op0 v c0).
+  unsigned Opcode = ISD::DELETED_NODE;
+  bool IsMulOrDiv = false;
+  // Set Opcode and IsMulOrDiv if the extract opcode matches the needed shift
+  // opcode or its arithmetic (mul or udiv) variant.
+  auto SelectOpcode = [&](unsigned NeededShift, unsigned MulOrDivVariant) {
+    IsMulOrDiv = ExtractFrom.getOpcode() == MulOrDivVariant;
+    if (!IsMulOrDiv && ExtractFrom.getOpcode() != NeededShift)
+      return false;
+    Opcode = NeededShift;
+    return true;
+  };
+  bool GotOpcode =
+      ((OppShift.getOpcode() == ISD::SRL && SelectOpcode(ISD::SHL, ISD::MUL)) ||
+       (OppShift.getOpcode() == ISD::SHL && SelectOpcode(ISD::SRL, ISD::UDIV)));
+  // op0 must be either the needed shift opcode or the mul/udiv equivalent that
+  // the needed shift can be extracted from. Both cases must have v as the LHS
+  // input.
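+  // For example (an illustrative sketch based on the i32 mul case from the
+  // tests below): given (or (srl (mul v 9) 25) (mul v 1152)), OppShift is the
+  // SRL half, so the needed opcode is SHL, and ExtractFrom is (mul v 1152),
+  // the MUL variant of SHL. The code below then verifies that
+  // 1152 == 9 << (32 - 25), i.e. that (mul v 1152) can be rewritten as
+  // (shl (mul v 9) 7), which completes the rotate
+  // (or (srl X 25) (shl X 7)) with X = (mul v 9).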
+  if (!GotOpcode || OppShiftLHS.getOperand(0) != ExtractFrom.getOperand(0))
+    return SDValue();
+
+  // Amount of the existing shift.
+  ConstantSDNode *OppShiftCst = isConstOrConstSplat(OppShift.getOperand(1));
+  // Constant mul/udiv/shift amount from the RHS of the shift's LHS op.
+  ConstantSDNode *OppLHSCst = isConstOrConstSplat(OppShiftLHS.getOperand(1));
+  // Constant mul/udiv/shift amount from the RHS of the ExtractFrom op.
+  ConstantSDNode *ExtractFromCst =
+      isConstOrConstSplat(ExtractFrom.getOperand(1));
+  // Check that we have constant values.
+  if (!OppShiftCst || !OppShiftCst->getAPIntValue() ||
+      !OppLHSCst || !OppLHSCst->getAPIntValue() ||
+      !ExtractFromCst || !ExtractFromCst->getAPIntValue())
+    return SDValue();
+
+  // Compute the shift amount we need to extract to complete the rotate.
+  const unsigned VTWidth = ShiftedVT.getScalarSizeInBits();
+  APInt NeededShiftAmt = VTWidth - OppShiftCst->getAPIntValue();
+  if (NeededShiftAmt.isNegative())
+    return SDValue();
+  // Normalize the bitwidth of the two mul/udiv/shift constant operands.
+  APInt ExtractFromAmt = ExtractFromCst->getAPIntValue();
+  APInt OppLHSAmt = OppLHSCst->getAPIntValue();
+  zeroExtendToMatch(ExtractFromAmt, OppLHSAmt);
+
+  // Now try to extract the needed shift from the ExtractFrom op and see if the
+  // result matches up with the existing shift's LHS op.
+  if (IsMulOrDiv) {
+    // Op to extract from is a mul or udiv by a constant.
+    // Check:
+    //     c2 / (1 << (bitwidth(op0 v c0) - c1)) == c0
+    //     c2 % (1 << (bitwidth(op0 v c0) - c1)) == 0
+    const APInt ExtractDiv = APInt::getOneBitSet(ExtractFromAmt.getBitWidth(),
+                                                 NeededShiftAmt.getZExtValue());
+    APInt ResultAmt;
+    APInt Rem;
+    APInt::udivrem(ExtractFromAmt, ExtractDiv, ResultAmt, Rem);
+    if (Rem != 0 || ResultAmt != OppLHSAmt)
+      return SDValue();
+  } else {
+    // Op to extract from is a shift by a constant.
+    // Check:
+    //     c2 - (bitwidth(op0 v c0) - c1) == c0
+    if (OppLHSAmt != ExtractFromAmt - NeededShiftAmt.zextOrTrunc(
+                                          ExtractFromAmt.getBitWidth()))
+      return SDValue();
+  }
+
+  // Return the expanded shift op that should allow a rotate to be formed.
+  EVT ShiftVT = OppShift.getOperand(1).getValueType();
+  EVT ResVT = ExtractFrom.getValueType();
+  SDValue NewShiftNode = DAG.getConstant(NeededShiftAmt, DL, ShiftVT);
+  return DAG.getNode(Opcode, DL, ResVT, OppShiftLHS, NewShiftNode);
+}
+
 // Return true if we can prove that, whenever Neg and Pos are both in the
 // range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
@@ -5159,17 +5281,43 @@
     }
   }
 
   // Match "(X shl/srl V1) & V2" where V2 may not be present.
   SDValue LHSShift;   // The shift.
   SDValue LHSMask;    // AND value if any.
-  if (!MatchRotateHalf(LHS, LHSShift, LHSMask))
-    return nullptr; // Not part of a rotate.
+  matchRotateHalf(DAG, LHS, LHSShift, LHSMask);
 
   SDValue RHSShift;   // The shift.
   SDValue RHSMask;    // AND value if any.
-  if (!MatchRotateHalf(RHS, RHSShift, RHSMask))
-    return nullptr; // Not part of a rotate.
+  matchRotateHalf(DAG, RHS, RHSShift, RHSMask);
 
+  // If neither side matched a rotate half, bail.
+  if (!LHSShift && !RHSShift)
+    return nullptr;
+
+  // InstCombine may have combined a constant shl, srl, mul, or udiv with one
+  // side of the rotate, so try to handle that here. In all cases we need to
+  // pass the matched shift from the opposite side to compute the opcode and
+  // needed shift amount to extract. We still want to do this if both sides
+  // matched a rotate half because one half may be a potential overshift that
+  // can be broken down (i.e. if InstCombine merged two shl or srl ops into a
+  // single one).
+
+  // Have LHS side of the rotate, try to extract the needed shift from the RHS.
+  if (LHSShift)
+    if (SDValue NewRHSShift =
+            extractShiftForRotate(DAG, LHSShift, RHS, RHSMask, DL))
+      RHSShift = NewRHSShift;
+  // Have RHS side of the rotate, try to extract the needed shift from the LHS.
+  if (RHSShift)
+    if (SDValue NewLHSShift =
+            extractShiftForRotate(DAG, RHSShift, LHS, LHSMask, DL))
+      LHSShift = NewLHSShift;
+
+  // If a side is still missing, nothing else we can do.
+  if (!RHSShift || !LHSShift)
+    return nullptr;
+
+  // At this point we've matched or extracted a shift op on each side.
+
   if (LHSShift.getOperand(0) != RHSShift.getOperand(0))
     return nullptr; // Not shifting the same value.
Index: test/CodeGen/AArch64/rotate-extract.ll
===================================================================
--- test/CodeGen/AArch64/rotate-extract.ll
+++ test/CodeGen/AArch64/rotate-extract.ll
@@ -0,0 +1,145 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-unknown-unknown | FileCheck %s
+
+; Check that under certain conditions we can factor out a rotate
+; from the following idioms:
+;   (a*c0) >> s1 | (a*c1)
+;   (a/c0) << s1 | (a/c1)
+; This targets cases where instcombine has folded a shl/srl/mul/udiv
+; with one of the shifts from the rotate idiom
+
+define i64 @ror_extract_shl(i64 %i) nounwind {
+; CHECK-LABEL: ror_extract_shl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl x8, x0, #3
+; CHECK-NEXT:    ror x0, x8, #57
+; CHECK-NEXT:    ret
+  %lhs_mul = shl i64 %i, 3
+  %rhs_mul = shl i64 %i, 10
+  %lhs_shift = lshr i64 %lhs_mul, 57
+  %out = or i64 %lhs_shift, %rhs_mul
+  ret i64 %out
+}
+
+define i32 @ror_extract_shrl(i32 %i) nounwind {
+; CHECK-LABEL: ror_extract_shrl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsr w8, w0, #3
+; CHECK-NEXT:    ror w0, w8, #4
+; CHECK-NEXT:    ret
+  %lhs_div = lshr i32 %i, 7
+  %rhs_div = lshr i32 %i, 3
+  %rhs_shift = shl i32 %rhs_div, 28
+  %out = or i32 %lhs_div, %rhs_shift
+  ret i32 %out
+}
+
+define i32 @ror_extract_mul(i32 %i) nounwind {
+; CHECK-LABEL: ror_extract_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add w8, w0, w0, lsl #3
+; CHECK-NEXT:    ror w0, w8, #25
+; CHECK-NEXT:    ret
+  %lhs_mul = mul i32 %i, 9
+  %rhs_mul = mul i32 %i, 1152
+  %lhs_shift = lshr i32 %lhs_mul, 25
+  %out = or i32 %lhs_shift, %rhs_mul
+  ret i32 %out
+}
+
+define i64 @ror_extract_udiv(i64 %i) nounwind {
+; CHECK-LABEL: ror_extract_udiv:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov x8, #-6148914691236517206
+; CHECK-NEXT:    movk x8, #43691
+; CHECK-NEXT:    umulh x8, x0, x8
+; CHECK-NEXT:    lsr x8, x8, #1
+; CHECK-NEXT:    ror x0, x8, #4
+; CHECK-NEXT:    ret
+  %lhs_div = udiv i64 %i, 3
+  %rhs_div = udiv i64 %i, 48
+  %lhs_shift = shl i64 %lhs_div, 60
+  %out = or i64 %lhs_shift, %rhs_div
+  ret i64 %out
+}
+
+define i64 @ror_extract_mul_with_mask(i64 %i) nounwind {
+; CHECK-LABEL: ror_extract_mul_with_mask:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, x0, lsl #3
+; CHECK-NEXT:    ror x8, x8, #57
+; CHECK-NEXT:    and x0, x8, #0xff
+; CHECK-NEXT:    ret
+  %lhs_mul = mul i64 %i, 1152
+  %rhs_mul = mul i64 %i, 9
+  %lhs_and = and i64 %lhs_mul, 160
+  %rhs_shift = lshr i64 %rhs_mul, 57
+  %out = or i64 %lhs_and, %rhs_shift
+  ret i64 %out
+}
+
+; Result would undershift
+define i64 @no_extract_shl(i64 %i) nounwind {
+; CHECK-LABEL: no_extract_shl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl x8, x0, #10
+; CHECK-NEXT:    bfxil x8, x0, #52, #7
+; CHECK-NEXT:    mov x0, x8
+; CHECK-NEXT:    ret
+  %lhs_mul = shl i64 %i, 5
+  %rhs_mul = shl i64 %i, 10
+  %lhs_shift = lshr i64 %lhs_mul, 57
+  %out = or i64 %lhs_shift, %rhs_mul
+  ret i64 %out
+}
+
+; Result would overshift
+define i32 @no_extract_shrl(i32 %i) nounwind {
+; CHECK-LABEL: no_extract_shrl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsr w8, w0, #3
+; CHECK-NEXT:    lsr w0, w0, #9
+; CHECK-NEXT:    bfi w0, w8, #28, #4
+; CHECK-NEXT:    ret
+  %lhs_div = lshr i32 %i, 3
+  %rhs_div = lshr i32 %i, 9
+  %lhs_shift = shl i32 %lhs_div, 28
+  %out = or i32 %lhs_shift, %rhs_div
+  ret i32 %out
+}
+
+; Can factor 128 from 2304, but result is 18 instead of 9
+define i64 @no_extract_mul(i64 %i) nounwind {
+; CHECK-LABEL: no_extract_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, x0, lsl #3
+; CHECK-NEXT:    lsr x0, x8, #57
+; CHECK-NEXT:    bfi x0, x8, #8, #56
+; CHECK-NEXT:    ret
+  %lhs_mul = mul i64 %i, 2304
+  %rhs_mul = mul i64 %i, 9
+  %rhs_shift = lshr i64 %rhs_mul, 57
+  %out = or i64 %lhs_mul, %rhs_shift
+  ret i64 %out
+}
+
+; Can't evenly factor 16 from 49
+define i32 @no_extract_udiv(i32 %i) nounwind {
+; CHECK-LABEL: no_extract_udiv:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #43691
+; CHECK-NEXT:    mov w9, #33437
+; CHECK-NEXT:    movk w8, #43690, lsl #16
+; CHECK-NEXT:    movk w9, #21399, lsl #16
+; CHECK-NEXT:    umull x8, w0, w8
+; CHECK-NEXT:    umull x9, w0, w9
+; CHECK-NEXT:    lsr x8, x8, #33
+; CHECK-NEXT:    lsr x9, x9, #32
+; CHECK-NEXT:    extr w0, w8, w9, #4
+; CHECK-NEXT:    ret
+  %lhs_div = udiv i32 %i, 3
+  %rhs_div = udiv i32 %i, 49
+  %lhs_shift = shl i32 %lhs_div, 28
+  %out = or i32 %lhs_shift, %rhs_div
+  ret i32 %out
+}
Index: test/CodeGen/X86/rotate-extract.ll
===================================================================
--- test/CodeGen/X86/rotate-extract.ll
+++ test/CodeGen/X86/rotate-extract.ll
@@ -0,0 +1,163 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown | FileCheck %s
+
+; Check that under certain conditions we can factor out a rotate
+; from the following idioms:
+;   (a*c0) >> s1 | (a*c1)
+;   (a/c0) << s1 | (a/c1)
+; This targets cases where instcombine has folded a shl/srl/mul/udiv
+; with one of the shifts from the rotate idiom
+
+define i64 @rolq_extract_shl(i64 %i) nounwind {
+; CHECK-LABEL: rolq_extract_shl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    leaq (,%rdi,8), %rax
+; CHECK-NEXT:    rolq $7, %rax
+; CHECK-NEXT:    retq
+  %lhs_mul = shl i64 %i, 3
+  %rhs_mul = shl i64 %i, 10
+  %lhs_shift = lshr i64 %lhs_mul, 57
+  %out = or i64 %lhs_shift, %rhs_mul
+  ret i64 %out
+}
+
+define i16 @rolw_extract_shrl(i16 %i) nounwind {
+; CHECK-LABEL: rolw_extract_shrl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movzwl %di, %eax
+; CHECK-NEXT:    shrl $3, %eax
+; CHECK-NEXT:    rolw $12, %ax
+; CHECK-NEXT:    # kill: def $ax killed $ax killed $eax
+; CHECK-NEXT:    retq
+  %lhs_div = lshr i16 %i, 7
+  %rhs_div = lshr i16 %i, 3
+  %rhs_shift = shl i16 %rhs_div, 12
+  %out = or i16 %lhs_div, %rhs_shift
+  ret i16 %out
+}
+
+define i32 @roll_extract_mul(i32 %i) nounwind {
+; CHECK-LABEL: roll_extract_mul:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    # kill: def $edi killed $edi def $rdi
+; CHECK-NEXT:    leal (%rdi,%rdi,8), %eax
+; CHECK-NEXT:    roll $7, %eax
+; CHECK-NEXT:    retq
+  %lhs_mul = mul i32 %i, 9
+  %rhs_mul = mul i32 %i, 1152
+  %lhs_shift = lshr i32 %lhs_mul, 25
+  %out = or i32 %lhs_shift, %rhs_mul
+  ret i32 %out
+}
+
+define i8 @rolb_extract_udiv(i8 %i) nounwind {
+; CHECK-LABEL: rolb_extract_udiv:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movzbl %dil, %eax
+; CHECK-NEXT:    imull $171, %eax, %eax
+; CHECK-NEXT:    shrl $9, %eax
+; CHECK-NEXT:    rolb $4, %al
+; CHECK-NEXT:    # kill: def $al killed $al killed $eax
+; CHECK-NEXT:    retq
+  %lhs_div = udiv i8 %i, 3
+  %rhs_div = udiv i8 %i, 48
+  %lhs_shift = shl i8 %lhs_div, 4
+  %out = or i8 %lhs_shift, %rhs_div
+  ret i8 %out
+}
+
+define i64 @rolq_extract_mul_with_mask(i64 %i) nounwind {
+; CHECK-LABEL: rolq_extract_mul_with_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    leaq (%rdi,%rdi,8), %rax
+; CHECK-NEXT:    rolq $7, %rax
+; CHECK-NEXT:    movzbl %al, %eax
+; CHECK-NEXT:    retq
+  %lhs_mul = mul i64 %i, 1152
+  %rhs_mul = mul i64 %i, 9
+  %lhs_and = and i64 %lhs_mul, 160
+  %rhs_shift = lshr i64 %rhs_mul, 57
+  %out = or i64 %lhs_and, %rhs_shift
+  ret i64 %out
+}
+
+; Result would undershift
+define i64 @no_extract_shl(i64 %i) nounwind {
+; CHECK-LABEL: no_extract_shl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movq %rdi, %rax
+; CHECK-NEXT:    shlq $5, %rax
+; CHECK-NEXT:    shlq $10, %rdi
+; CHECK-NEXT:    shrq $57, %rax
+; CHECK-NEXT:    leaq (%rax,%rdi), %rax
+; CHECK-NEXT:    retq
+  %lhs_mul = shl i64 %i, 5
+  %rhs_mul = shl i64 %i, 10
+  %lhs_shift = lshr i64 %lhs_mul, 57
+  %out = or i64 %lhs_shift, %rhs_mul
+  ret i64 %out
+}
+
+; Result would overshift
+define i32 @no_extract_shrl(i32 %i) nounwind {
+; CHECK-LABEL: no_extract_shrl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    # kill: def $edi killed $edi def $rdi
+; CHECK-NEXT:    movl %edi, %eax
+; CHECK-NEXT:    andl $-8, %eax
+; CHECK-NEXT:    shll $25, %eax
+; CHECK-NEXT:    shrl $9, %edi
+; CHECK-NEXT:    leal (%rdi,%rax), %eax
+; CHECK-NEXT:    retq
+  %lhs_div = lshr i32 %i, 3
+  %rhs_div = lshr i32 %i, 9
+  %lhs_shift = shl i32 %lhs_div, 28
+  %out = or i32 %lhs_shift, %rhs_div
+  ret i32 %out
+}
+
+; Can factor 128 from 2304, but result is 18 instead of 9
+define i16 @no_extract_mul(i16 %i) nounwind {
+; CHECK-LABEL: no_extract_mul:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    # kill: def $edi killed $edi def $rdi
+; CHECK-NEXT:    leal (%rdi,%rdi,8), %eax
+; CHECK-NEXT:    # kill: def $edi killed $edi killed $rdi def $rdi
+; CHECK-NEXT:    shll $8, %edi
+; CHECK-NEXT:    leal (%rdi,%rdi,8), %ecx
+; CHECK-NEXT:    movzwl %ax, %eax
+; CHECK-NEXT:    shrl $9, %eax
+; CHECK-NEXT:    orl %ecx, %eax
+; CHECK-NEXT:    # kill: def $ax killed $ax killed $eax
+; CHECK-NEXT:    retq
+  %lhs_mul = mul i16 %i, 2304
+  %rhs_mul = mul i16 %i, 9
+  %rhs_shift = lshr i16 %rhs_mul, 9
+  %out = or i16 %lhs_mul, %rhs_shift
+  ret i16 %out
+}
+
+; Can't evenly factor 16 from 49
+define i8 @no_extract_udiv(i8 %i) nounwind {
+; CHECK-LABEL: no_extract_udiv:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movzbl %dil, %eax
+; CHECK-NEXT:    imull $171, %eax, %ecx
+; CHECK-NEXT:    shrl $8, %ecx
+; CHECK-NEXT:    shlb $3, %cl
+; CHECK-NEXT:    andb $-16, %cl
+; CHECK-NEXT:    imull $79, %eax, %edx
+; CHECK-NEXT:    shrl $8, %edx
+; CHECK-NEXT:    subb %dl, %al
+; CHECK-NEXT:    shrb %al
+; CHECK-NEXT:    addb %dl, %al
+; CHECK-NEXT:    shrb $5, %al
+; CHECK-NEXT:    orb %cl, %al
+; CHECK-NEXT:    # kill: def $al killed $al killed $eax
+; CHECK-NEXT:    retq
+  %lhs_div = udiv i8 %i, 3
+  %rhs_div = udiv i8 %i, 49
+  %lhs_shift = shl i8 %lhs_div, 4
+  %out = or i8 %lhs_shift, %rhs_div
+  ret i8 %out
+}