[SelectionDAG] isSplatValue - limit recursion depth

Add an `unsigned Depth` parameter (defaulting to 0, so existing callers are
unaffected) to SelectionDAG::isSplatValue(SDValue, const APInt &, APInt &).
The function now returns false once Depth reaches MaxRecursionDepth, and every
recursive call (binary ops, TRUNCATE/SIGN_EXTEND/ZERO_EXTEND, and the
subvector-extraction case) passes Depth + 1.

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1829,7 +1829,8 @@
   /// for \p DemandedElts.
   ///
   /// NOTE: The function will return true for a demanded splat of UNDEF values.
-  bool isSplatValue(SDValue V, const APInt &DemandedElts, APInt &UndefElts);
+  bool isSplatValue(SDValue V, const APInt &DemandedElts, APInt &UndefElts,
+                    unsigned Depth = 0);
 
   /// Test whether \p V has a splatted value.
   bool isSplatValue(SDValue V, bool AllowUndefs = false);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2355,13 +2355,16 @@
 /// sense to specify which elements are demanded or undefined, therefore
 /// they are simply ignored.
 bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
-                                APInt &UndefElts) {
+                                APInt &UndefElts, unsigned Depth) {
   EVT VT = V.getValueType();
   assert(VT.isVector() && "Vector type expected");
 
   if (!VT.isScalableVector() && !DemandedElts)
     return false; // No demanded elts, better to assume we don't know anything.
 
+  if (Depth >= MaxRecursionDepth)
+    return false; // Limit search depth.
+
   // Deal with some common cases here that work for both fixed and scalable
   // vector types.
   switch (V.getOpcode()) {
@@ -2376,8 +2379,8 @@
     APInt UndefLHS, UndefRHS;
     SDValue LHS = V.getOperand(0);
     SDValue RHS = V.getOperand(1);
-    if (isSplatValue(LHS, DemandedElts, UndefLHS) &&
-        isSplatValue(RHS, DemandedElts, UndefRHS)) {
+    if (isSplatValue(LHS, DemandedElts, UndefLHS, Depth + 1) &&
+        isSplatValue(RHS, DemandedElts, UndefRHS, Depth + 1)) {
       UndefElts = UndefLHS | UndefRHS;
       return true;
     }
@@ -2386,7 +2389,7 @@
   case ISD::TRUNCATE:
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
-    return isSplatValue(V.getOperand(0), DemandedElts, UndefElts);
+    return isSplatValue(V.getOperand(0), DemandedElts, UndefElts, Depth + 1);
   }
 
   // We don't support other cases than those above for scalable vectors at
@@ -2441,7 +2444,7 @@
     unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
     APInt UndefSrcElts;
     APInt DemandedSrcElts = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx);
-    if (isSplatValue(Src, DemandedSrcElts, UndefSrcElts)) {
+    if (isSplatValue(Src, DemandedSrcElts, UndefSrcElts, Depth + 1)) {
       UndefElts = UndefSrcElts.extractBits(NumElts, Idx);
       return true;
     }