[SelectionDAG] Remove the AllowShrink parameter of ISD::isConstantSplatVector.

After this change ISD::isConstantSplatVector only succeeds when the splat is
exactly one vector element wide (EltSize == SplatBitSize), so a successful
match yields a SplatValue with the element bit width. Callers that passed
/*AllowShrink*/false are updated, and DAGCombiner::visitMUL replaces its
IsFullSplat guard with asserts on the splat width.

Index: include/llvm/CodeGen/SelectionDAGNodes.h
===================================================================
--- include/llvm/CodeGen/SelectionDAGNodes.h
+++ include/llvm/CodeGen/SelectionDAGNodes.h
@@ -85,10 +85,7 @@
 
   /// If N is a BUILD_VECTOR node whose elements are all the same constant or
   /// undefined, return true and return the constant value in \p SplatValue.
-  /// This sets \p SplatValue to the smallest possible splat unless AllowShrink
-  /// is set to false.
-  bool isConstantSplatVector(const SDNode *N, APInt &SplatValue,
-                             bool AllowShrink = true);
+  bool isConstantSplatVector(const SDNode *N, APInt &SplatValue);
 
   /// Return true if the specified node is a BUILD_VECTOR where all of the
   /// elements are ~0 or undef.
Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2588,6 +2588,12 @@
 
     N0IsConst = ISD::isConstantSplatVector(N0.getNode(), ConstValue0);
     N1IsConst = ISD::isConstantSplatVector(N1.getNode(), ConstValue1);
+    assert((!N0IsConst ||
+            ConstValue0.getBitWidth() == VT.getScalarSizeInBits()) &&
+           "Splat APInt should be element width");
+    assert((!N1IsConst ||
+            ConstValue1.getBitWidth() == VT.getScalarSizeInBits()) &&
+           "Splat APInt should be element width");
   } else {
     N0IsConst = isa<ConstantSDNode>(N0);
     if (N0IsConst) {
@@ -2613,12 +2619,8 @@
   // fold (mul x, 0) -> 0
   if (N1IsConst && ConstValue1.isNullValue())
     return N1;
-  // We require a splat of the entire scalar bit width for non-contiguous
-  // bit patterns.
-  bool IsFullSplat =
-    ConstValue1.getBitWidth() == VT.getScalarSizeInBits();
   // fold (mul x, 1) -> x
-  if (N1IsConst && ConstValue1.isOneValue() && IsFullSplat)
+  if (N1IsConst && ConstValue1.isOneValue())
     return N0;
 
   if (SDValue NewSel = foldBinOpIntoSelect(N))
@@ -2643,8 +2645,7 @@
     return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc);
   }
   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
-  if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2() &&
-      IsFullSplat) {
+  if (N1IsConst && !N1IsOpaqueConst && (-ConstValue1).isPowerOf2()) {
     unsigned Log2Val = (-ConstValue1).logBase2();
     SDLoc DL(N);
     // FIXME: If the input is something that is easily negated (e.g. a
Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -125,8 +125,7 @@
 //                              ISD Namespace
 //===----------------------------------------------------------------------===//
 
-bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal,
-                                bool AllowShrink) {
+bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) {
   auto *BV = dyn_cast<BuildVectorSDNode>(N);
   if (!BV)
     return false;
@@ -135,10 +134,9 @@
   unsigned SplatBitSize;
   bool HasUndefs;
   unsigned EltSize = N->getValueType(0).getVectorElementType().getSizeInBits();
-  unsigned MinSplatBits = AllowShrink ? 0 : EltSize;
   return BV->isConstantSplat(SplatVal, SplatUndef, SplatBitSize, HasUndefs,
-                             MinSplatBits) &&
-         EltSize >= SplatBitSize;
+                             EltSize) &&
+         EltSize == SplatBitSize;
 }
 
 // FIXME: AllOnes and AllZeros duplicate a lot of code. Could these be
Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -29567,8 +29567,7 @@
   // In SetLT case, The second operand of the comparison can be either 1 or 0.
   APInt SplatVal;
   if ((CC == ISD::SETLT) &&
-      !((ISD::isConstantSplatVector(SetCC.getOperand(1).getNode(), SplatVal,
-                                    /*AllowShrink*/false) &&
+      !((ISD::isConstantSplatVector(SetCC.getOperand(1).getNode(), SplatVal) &&
          SplatVal.isOneValue()) ||
         (ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()))))
     return false;
@@ -32084,8 +32083,7 @@
     return SDValue();
 
   APInt SplatVal;
-  if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal,
-                                  /*AllowShrink*/false) ||
+  if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal) ||
       !SplatVal.isMask())
     return SDValue();
 
@@ -32669,8 +32667,7 @@
          "Unexpected types for truncate operation");
 
   APInt C;
-  if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C,
-                                 /*AllowShrink*/false)) {
+  if (ISD::isConstantSplatVector(In.getOperand(1).getNode(), C)) {
     // C should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
     // the element size of the destination type.
     return C.isMask(VT.getScalarSizeInBits()) ? In.getOperand(0) :
@@ -35377,7 +35374,7 @@
   SDNode *N1 = N->getOperand(1).getNode();
 
   APInt SplatVal;
-  if (!ISD::isConstantSplatVector(N1, SplatVal, /*AllowShrink*/false) ||
+  if (!ISD::isConstantSplatVector(N1, SplatVal) ||
       !SplatVal.isOneValue())
     return SDValue();
 