diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -21260,7 +21260,8 @@
   // ty1 extract_vector(ty2 splat(V))) -> ty1 splat(V)
   if (V.getOpcode() == ISD::SPLAT_VECTOR)
     if (DAG.isConstantValueOfAnyType(V.getOperand(0)) || V.hasOneUse())
-      return DAG.getSplatVector(NVT, SDLoc(N), V.getOperand(0));
+      if (!LegalOperations || TLI.isOperationLegal(ISD::SPLAT_VECTOR, NVT))
+        return DAG.getSplatVector(NVT, SDLoc(N), V.getOperand(0));
 
   // Try to move vector bitcast after extract_subv by scaling extraction index:
   // extract_subv (bitcast X), Index --> bitcast (extract_subv X, Index')
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -871,6 +871,7 @@
   setTargetDAGCombine(ISD::VECTOR_SPLICE);
   setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
   setTargetDAGCombine(ISD::CONCAT_VECTORS);
+  setTargetDAGCombine(ISD::EXTRACT_SUBVECTOR);
   setTargetDAGCombine(ISD::INSERT_SUBVECTOR);
   setTargetDAGCombine(ISD::STORE);
   if (Subtarget->supportsAddressTopByteIgnored())
@@ -14472,6 +14473,32 @@
                                  RHS));
 }
 
+// Combine an EXTRACT_SUBVECTOR whose result is a scalable i1 vector. Returns
+// an empty SDValue before operation legalization or when nothing matches;
+// otherwise an extract of a constant splat becomes a splat of the result type.
+static SDValue
+performExtractSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                               SelectionDAG &DAG) {
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1)
+    return SDValue();
+
+  SDValue V = N->getOperand(0);
+
+  // NOTE: This combine exists in DAGCombiner, but that version's legality check
+  // blocks this combine because the non-const case requires custom lowering.
+  //
+  // ty1 extract_vector(ty2 splat(const))) -> ty1 splat(const)
+  if (V.getOpcode() == ISD::SPLAT_VECTOR)
+    if (isa<ConstantSDNode>(V.getOperand(0)))
+      return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V.getOperand(0));
+
+  return SDValue();
+}
+
 static SDValue
 performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                               SelectionDAG &DAG) {
@@ -18181,6 +18208,8 @@
     return performSignExtendInRegCombine(N, DCI, DAG);
   case ISD::CONCAT_VECTORS:
     return performConcatVectorsCombine(N, DCI, DAG);
+  case ISD::EXTRACT_SUBVECTOR:
+    return performExtractSubvectorCombine(N, DCI, DAG);
   case ISD::INSERT_SUBVECTOR:
     return performInsertSubvectorCombine(N, DCI, DAG);
   case ISD::SELECT: