Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp =================================================================== --- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp +++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp @@ -1101,7 +1101,7 @@ // (result) is 128-bit but the source is 256-bit wide. for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64, MVT::v4f32, MVT::v2f64 }) { - setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); + setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal); } // Custom lower several nodes for 256-bit types. @@ -1381,12 +1381,15 @@ setOperationAction(ISD::MGATHER, VT, Custom); setOperationAction(ISD::MSCATTER, VT, Custom); } + + setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v1i1, Legal); + // Extract subvector is special because the value type // (result) is 256-bit but the source is 512-bit wide. - // 128-bit was made Custom under AVX1. + // 128-bit was made Legal under AVX1. for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64, - MVT::v8f32, MVT::v4f64, MVT::v1i1 }) - setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); + MVT::v8f32, MVT::v4f64 }) + setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal); for (auto VT : { MVT::v2i1, MVT::v4i1, MVT::v8i1, MVT::v16i1, MVT::v32i1, MVT::v64i1 }) setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal); @@ -14548,41 +14551,24 @@ // upper bits of a vector. static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget, SelectionDAG &DAG) { - assert(Subtarget.hasAVX() && "EXTRACT_SUBVECTOR requires AVX"); - SDLoc dl(Op); SDValue In = Op.getOperand(0); SDValue Idx = Op.getOperand(1); - unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); MVT ResVT = Op.getSimpleValueType(); // When v1i1 is legal a scalarization of a vselect with a vXi1 Cond // would result with: v1i1 = extract_subvector(vXi1, idx). // Lower these into extract_vector_elt which is already selectable. 
- if (ResVT == MVT::v1i1) { - assert(Subtarget.hasAVX512() && - "Boolean EXTRACT_SUBVECTOR requires AVX512"); - - MVT EltVT = ResVT.getVectorElementType(); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - MVT LegalVT = - (TLI.getTypeToTransformTo(*DAG.getContext(), EltVT)).getSimpleVT(); - SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LegalVT, In, Idx); - return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, ResVT, Res); - } - - assert((In.getSimpleValueType().is256BitVector() || - In.getSimpleValueType().is512BitVector()) && - "Can only extract from 256-bit or 512-bit vectors"); - - // If the input is a buildvector just emit a smaller one. - unsigned ElemsPerChunk = ResVT.getVectorNumElements(); - if (In.getOpcode() == ISD::BUILD_VECTOR) - return DAG.getBuildVector( - ResVT, dl, makeArrayRef(In->op_begin() + IdxVal, ElemsPerChunk)); + assert(ResVT == MVT::v1i1); + assert(Subtarget.hasAVX512() && + "Boolean EXTRACT_SUBVECTOR requires AVX512"); - // Everything else is legal. - return Op; + MVT EltVT = ResVT.getVectorElementType(); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + MVT LegalVT = + (TLI.getTypeToTransformTo(*DAG.getContext(), EltVT)).getSimpleVT(); + SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LegalVT, In, Idx); + return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, ResVT, Res); } // Lower a node with an INSERT_SUBVECTOR opcode. 
This may result in a @@ -35692,16 +35678,23 @@ return SDValue(); MVT OpVT = N->getSimpleValueType(0); + SDValue InVec = N->getOperand(0); + unsigned IdxVal = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); - if (ISD::isBuildVectorAllZeros(N->getOperand(0).getNode())) + if (ISD::isBuildVectorAllZeros(InVec.getNode())) return getZeroVector(OpVT, Subtarget, DAG, SDLoc(N)); - if (ISD::isBuildVectorAllOnes(N->getOperand(0).getNode())) { + if (ISD::isBuildVectorAllOnes(InVec.getNode())) { if (OpVT.getScalarType() == MVT::i1) return DAG.getConstant(1, SDLoc(N), OpVT); return getZeroVector(OpVT, Subtarget, DAG, SDLoc(N)); } + if (InVec.getOpcode() == ISD::BUILD_VECTOR) + return DAG.getBuildVector(OpVT, SDLoc(N), + makeArrayRef(InVec.getNode()->op_begin() + IdxVal, + OpVT.getVectorNumElements())); + return SDValue(); }