Index: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -873,6 +873,56 @@
          ISD::isBuildVectorOfConstantFPSDNodes(N);
 }
 
+// Attempt to match a unary predicate against a scalar/splat constant or
+// every element of a constant BUILD_VECTOR.
+static bool matchUnaryPredicate(SDValue Op,
+                                std::function<bool(ConstantSDNode *)> Match) {
+  if (auto *Cst = dyn_cast<ConstantSDNode>(Op))
+    return Match(Cst);
+
+  if (ISD::BUILD_VECTOR != Op.getOpcode())
+    return false;
+
+  EVT SVT = Op.getValueType().getScalarType();
+  for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
+    auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i));
+    if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst))
+      return false;
+  }
+  return true;
+}
+
+// Attempt to match a binary predicate against a pair of scalar/splat constants
+// or every element of a pair of constant BUILD_VECTORs.
+static bool matchBinaryPredicate(
+    SDValue LHS, SDValue RHS,
+    std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match) {
+  if (LHS.getValueType() != RHS.getValueType())
+    return false;
+
+  if (auto *LHSCst = dyn_cast<ConstantSDNode>(LHS))
+    if (auto *RHSCst = dyn_cast<ConstantSDNode>(RHS))
+      return Match(LHSCst, RHSCst);
+
+  if (ISD::BUILD_VECTOR != LHS.getOpcode() ||
+      ISD::BUILD_VECTOR != RHS.getOpcode())
+    return false;
+
+  EVT SVT = LHS.getValueType().getScalarType();
+  for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) {
+    auto *LHSCst = dyn_cast<ConstantSDNode>(LHS.getOperand(i));
+    auto *RHSCst = dyn_cast<ConstantSDNode>(RHS.getOperand(i));
+    if (!LHSCst || !RHSCst)
+      return false;
+    if (LHSCst->getValueType(0) != SVT ||
+        LHSCst->getValueType(0) != RHSCst->getValueType(0))
+      return false;
+    if (!Match(LHSCst, RHSCst))
+      return false;
+  }
+  return true;
+}
+
 SDValue DAGCombiner::ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                                     SDValue N1) {
   EVT VT = N0.getValueType();
@@ -4585,20 +4635,6 @@
   return nullptr;
 }
 
-// if Left + Right == Sum (constant or constant splat vector)
-static bool sumMatchConstant(SDValue Left, SDValue Right, unsigned Sum,
-                             SelectionDAG &DAG, const SDLoc &DL) {
-  EVT ShiftVT = Left.getValueType();
-  if (ShiftVT != Right.getValueType()) return false;
-
-  SDValue ShiftSum = DAG.FoldConstantArithmetic(ISD::ADD, DL, ShiftVT,
-                                                Left.getNode(), Right.getNode());
-  if (!ShiftSum) return false;
-
-  ConstantSDNode *CSum = isConstOrConstSplat(ShiftSum);
-  return CSum && CSum->getZExtValue() == Sum;
-}
-
 // MatchRotate - Handle an 'or' of two operands. If this is one of the many
 // idioms for rotate, and if the target supports rotation instructions, generate
 // a rot[lr].
@@ -4644,7 +4680,11 @@
 
   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
   // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
-  if (sumMatchConstant(LHSShiftAmt, RHSShiftAmt, EltSizeInBits, DAG, DL)) {
+  auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
+                                        ConstantSDNode *RHS) {
+    return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
+  };
+  if (matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
     SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
                               LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt);
 
@@ -5365,7 +5405,11 @@
   if (isNullConstantOrNullSplatConstant(N0))
     return N0;
   // fold (shl x, c >= size(x)) -> undef
-  if (N1C && N1C->getAPIntValue().uge(OpSizeInBits))
+  // NOTE: ALL vector elements must be too big to avoid partial UNDEFs.
+  auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) {
+    return Val->getAPIntValue().uge(OpSizeInBits);
+  };
+  if (matchUnaryPredicate(N1, MatchShiftTooBig))
     return DAG.getUNDEF(VT);
   // fold (shl x, 0) -> x
   if (N1C && N1C->isNullValue())
@@ -5392,20 +5436,29 @@
     return SDValue(N, 0);
 
   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
-  if (N1C && N0.getOpcode() == ISD::SHL) {
-    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
-      SDLoc DL(N);
-      APInt c1 = N0C1->getAPIntValue();
-      APInt c2 = N1C->getAPIntValue();
+  if (N0.getOpcode() == ISD::SHL) {
+    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
+                                          ConstantSDNode *RHS) {
+      APInt c1 = LHS->getAPIntValue();
+      APInt c2 = RHS->getAPIntValue();
       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+      return (c1 + c2).uge(OpSizeInBits);
+    };
+    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
+      return DAG.getConstant(0, SDLoc(N), VT);
 
-      APInt Sum = c1 + c2;
-      if (Sum.uge(OpSizeInBits))
-        return DAG.getConstant(0, DL, VT);
-
-      return DAG.getNode(
-          ISD::SHL, DL, VT, N0.getOperand(0),
-          DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
+    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
+                                       ConstantSDNode *RHS) {
+      APInt c1 = LHS->getAPIntValue();
+      APInt c2 = RHS->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+      return (c1 + c2).ult(OpSizeInBits);
+    };
+    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
+      SDLoc DL(N);
+      EVT ShiftVT = N1.getValueType();
+      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
+      return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
     }
   }
 
Index: llvm/trunk/test/CodeGen/X86/combine-shl.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/combine-shl.ll
+++ llvm/trunk/test/CodeGen/X86/combine-shl.ll
@@ -37,7 +37,6 @@
 ;
 ; AVX-LABEL: combine_vec_shl_outofrange1:
 ; AVX: # BB#0:
-; AVX-NEXT: vpsllvd {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT: retq
   %1 = shl <4 x i32> %x,
   ret <4 x i32> %1
@@ -153,7 +152,6 @@
 ; AVX-LABEL: combine_vec_shl_shl1:
 ; AVX: # BB#0:
 ; AVX-NEXT: vpsllvd {{.*}}(%rip), %xmm0, %xmm0
-; AVX-NEXT: vpsllvd {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT: retq
   %1 = shl <4 x i32> %x,
   %2 = shl <4 x i32> %1,
@@ -184,8 +182,7 @@
 ;
 ; AVX-LABEL: combine_vec_shl_shl_zero1:
 ; AVX: # BB#0:
-; AVX-NEXT: vpsllvd {{.*}}(%rip), %xmm0, %xmm0
-; AVX-NEXT: vpsllvd {{.*}}(%rip), %xmm0, %xmm0
+; AVX-NEXT: vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT: retq
   %1 = shl <4 x i32> %x,
   %2 = shl <4 x i32> %1,