Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5652,32 +5652,17 @@
 }
 
   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
-  if (N0.getOpcode() == ISD::SRA) {
+  // clamp (add c1, c2) to max shift.
+  if (N0.getOpcode() == ISD::SRA &&
+      isConstantOrConstantVector(N1, /* NoOpaques */ true) &&
+      isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true) &&
+      N1.getValueType() == N0.getOperand(1).getValueType()) {
     SDLoc DL(N);
     EVT ShiftVT = N1.getValueType();
-
-    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
-                                          ConstantSDNode *RHS) {
-      APInt c1 = LHS->getAPIntValue();
-      APInt c2 = RHS->getAPIntValue();
-      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
-      return (c1 + c2).uge(OpSizeInBits);
-    };
-    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
-      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0),
-                         DAG.getConstant(OpSizeInBits - 1, DL, ShiftVT));
-
-    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
-                                       ConstantSDNode *RHS) {
-      APInt c1 = LHS->getAPIntValue();
-      APInt c2 = RHS->getAPIntValue();
-      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
-      return (c1 + c2).ult(OpSizeInBits);
-    };
-    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
-      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
-      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), Sum);
-    }
+    SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
+    SDValue Limit = DAG.getConstant(OpSizeInBits - 1, DL, ShiftVT);
+    SDValue Clamp = DAG.getNode(ISD::UMIN, DL, ShiftVT, Sum, Limit);
+    return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), Clamp);
   }
 
   // fold (sra (shl X, m), (sub result_size, n))
Index: test/CodeGen/X86/combine-sra.ll
===================================================================
--- test/CodeGen/X86/combine-sra.ll
+++ test/CodeGen/X86/combine-sra.ll
@@ -131,24 +131,15 @@
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    psrad $27, %xmm1
 ; SSE-NEXT:    movdqa %xmm0, %xmm2
-; SSE-NEXT:    psrad $5, %xmm2
+; SSE-NEXT:    psrad $15, %xmm2
 ; SSE-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    movdqa %xmm0, %xmm1
-; SSE-NEXT:    psrad $31, %xmm1
-; SSE-NEXT:    psrad $1, %xmm0
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
-; SSE-NEXT:    movdqa %xmm0, %xmm1
-; SSE-NEXT:    psrad $10, %xmm1
-; SSE-NEXT:    pblendw {{.*#+}} xmm1 = xmm1[0,1,2,3],xmm0[4,5,6,7]
 ; SSE-NEXT:    psrad $31, %xmm0
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
+; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_ashr_ashr3:
 ; AVX:       # BB#0:
 ; AVX-NEXT:    vpsravd {{.*}}(%rip), %xmm0, %xmm0
-; AVX-NEXT:    vpsravd {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = ashr <4 x i32> %x,
   %2 = ashr <4 x i32> %1,
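
Note: a minimal standalone C++ sketch (not part of the patch; the helper name
sraClamped is hypothetical) of why the UMIN clamp preserves semantics. For an
arithmetic right shift, any amount >= the bit width produces the same result as
shifting by BitWidth - 1, since every result bit is a copy of the sign bit, so
replacing c1 + c2 with umin(c1 + c2, OpSizeInBits - 1) is always safe. Assumes
a two's-complement target where >> on a signed integer is arithmetic (this is
guaranteed from C++20 onward).

  #include <cassert>
  #include <cstdint>

  // Hypothetical helper mirroring the fold: arithmetic right shift with
  // the amount clamped to BitWidth - 1, like the UMIN node built above.
  int32_t sraClamped(int32_t x, uint32_t amt) {
    const uint32_t BitWidth = 32;
    uint32_t Clamped = amt < BitWidth - 1 ? amt : BitWidth - 1;
    return x >> Clamped; // arithmetic shift for signed x
  }

  int main() {
    int32_t x = -123456789;
    // Stacked shifts whose amounts sum past the bit width:
    // c1 + c2 = 27 + 10 = 37 >= 32. Each individual C++ shift stays
    // below 32, so the expression itself is well defined.
    int32_t Stacked = (x >> 27) >> 10;
    // The fold rewrites this as one shift by umin(37, 31) = 31.
    assert(Stacked == sraClamped(x, 37));
    // An in-range sum degenerates to a plain add: 5 + 10 = 15 < 32.
    assert(((x >> 5) >> 10) == sraClamped(x, 15));
    return 0;
  }

This is also what the updated tests show: the second vpsravd disappears because
the two variable shifts now fold into a single shift by the clamped sum.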