Index: llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
===================================================================
--- llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
+++ llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -2487,17 +2487,18 @@
 
   /// Attempt to match a unary predicate against a scalar/splat constant or
   /// every element of a constant BUILD_VECTOR.
-  /// If AllowUndef is true, then UNDEF element will pass nullptr to Match.
+  /// If AllowUndefs is true, then UNDEF elements will pass nullptr to Match.
   bool matchUnaryPredicate(SDValue Op,
                            std::function<bool(ConstantSDNode *)> Match,
                            bool AllowUndefs = false);
 
   /// Attempt to match a binary predicate against a pair of scalar/splat
   /// constants or every element of a pair of constant BUILD_VECTORs.
+  /// If AllowUndefs is true, then UNDEF elements will pass nullptr to Match.
   bool matchBinaryPredicate(
       SDValue LHS, SDValue RHS,
-      std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match);
-
+      std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match,
+      bool AllowUndefs = false);
 } // end namespace ISD
 
 } // end namespace llvm
Index: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -296,26 +296,33 @@
 
 bool ISD::matchBinaryPredicate(
     SDValue LHS, SDValue RHS,
-    std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match) {
+    std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match,
+    bool AllowUndefs) {
   if (LHS.getValueType() != RHS.getValueType())
     return false;
 
+  // TODO: Add support for scalar UNDEF cases?
   if (auto *LHSCst = dyn_cast<ConstantSDNode>(LHS))
     if (auto *RHSCst = dyn_cast<ConstantSDNode>(RHS))
       return Match(LHSCst, RHSCst);
 
+  // TODO: Add support for vector UNDEF cases?
   if (ISD::BUILD_VECTOR != LHS.getOpcode() ||
       ISD::BUILD_VECTOR != RHS.getOpcode())
     return false;
 
   EVT SVT = LHS.getValueType().getScalarType();
   for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) {
-    auto *LHSCst = dyn_cast<ConstantSDNode>(LHS.getOperand(i));
-    auto *RHSCst = dyn_cast<ConstantSDNode>(RHS.getOperand(i));
-    if (!LHSCst || !RHSCst)
+    SDValue LHSOp = LHS.getOperand(i);
+    SDValue RHSOp = RHS.getOperand(i);
+    bool LHSUndef = AllowUndefs && LHSOp.isUndef();
+    bool RHSUndef = AllowUndefs && RHSOp.isUndef();
+    auto *LHSCst = dyn_cast<ConstantSDNode>(LHSOp);
+    auto *RHSCst = dyn_cast<ConstantSDNode>(RHSOp);
+    if ((!LHSCst && !LHSUndef) || (!RHSCst && !RHSUndef))
       return false;
-    if (LHSCst->getValueType(0) != SVT ||
-        LHSCst->getValueType(0) != RHSCst->getValueType(0))
+    if (LHSOp.getValueType() != SVT ||
+        LHSOp.getValueType() != RHSOp.getValueType())
       return false;
     if (!Match(LHSCst, RHSCst))
       return false;