Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -6581,31 +6581,30 @@
   }
 
   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
+  // clamp (add c1, c2) to max shift.
   if (N0.getOpcode() == ISD::SRA) {
     SDLoc DL(N);
     EVT ShiftVT = N1.getValueType();
+    EVT ShiftSVT = ShiftVT.getScalarType();
+    SmallVector<SDValue, 16> ShiftValues;
 
-    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
-                                          ConstantSDNode *RHS) {
+    auto SumOfShifts = [&](ConstantSDNode *LHS, ConstantSDNode *RHS) {
       APInt c1 = LHS->getAPIntValue();
       APInt c2 = RHS->getAPIntValue();
       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
-      return (c1 + c2).uge(OpSizeInBits);
-    };
-    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
-      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0),
-                         DAG.getConstant(OpSizeInBits - 1, DL, ShiftVT));
-
-    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
-                                       ConstantSDNode *RHS) {
-      APInt c1 = LHS->getAPIntValue();
-      APInt c2 = RHS->getAPIntValue();
-      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
-      return (c1 + c2).ult(OpSizeInBits);
+      APInt Sum = c1 + c2;
+      unsigned ShiftSum =
+          Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
+      ShiftValues.push_back(DAG.getConstant(ShiftSum, DL, ShiftSVT));
+      return true;
     };
-    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
-      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
-      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), Sum);
+    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), SumOfShifts)) {
+      SDValue ShiftValue;
+      if (VT.isVector())
+        ShiftValue = DAG.getBuildVector(ShiftVT, DL, ShiftValues);
+      else
+        ShiftValue = ShiftValues[0];
+      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), ShiftValue);
     }
   }
 
Index: test/CodeGen/X86/combine-sra.ll
===================================================================
--- test/CodeGen/X86/combine-sra.ll
+++ test/CodeGen/X86/combine-sra.ll
@@ -120,24 +120,15 @@
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    psrad $27, %xmm1
 ; SSE-NEXT:    movdqa %xmm0, %xmm2
-; SSE-NEXT:    psrad $5, %xmm2
+; SSE-NEXT:    psrad $15, %xmm2
 ; SSE-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    movdqa %xmm0, %xmm1
-; SSE-NEXT:    psrad $31, %xmm1
-; SSE-NEXT:    psrad $1, %xmm0
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
-; SSE-NEXT:    movdqa %xmm0, %xmm1
-; SSE-NEXT:    psrad $10, %xmm1
-; SSE-NEXT:    pblendw {{.*#+}} xmm1 = xmm1[0,1,2,3],xmm0[4,5,6,7]
 ; SSE-NEXT:    psrad $31, %xmm0
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
+; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_ashr_ashr3:
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vpsravd {{.*}}(%rip), %xmm0, %xmm0
-; AVX-NEXT:    vpsravd {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = ashr <4 x i32> %x, <i32  1, i32  5, i32 50, i32 27>
   %2 = ashr <4 x i32> %1, <i32 33, i32 10, i32 33, i32  0>