diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -4750,6 +4750,31 @@
     MVT EltVT = OVT.getVectorElementType();
     MVT NewEltVT = NVT.getVectorElementType();
 
+    SDValue Vec = Node->getOperand(0);
+    SDValue Idx = Node->getOperand(1);
+    SDLoc SL(Node);
+
+    // Handle promotion to a vector with the same element count but wider
+    // elements (e.g. an SVE predicate promoted to an integer vector): extend
+    // the vector, extract the wider element, then truncate to the result type.
+    if (OVT.getSizeInBits() != NVT.getSizeInBits()) {
+      assert(NVT.isVector() &&
+             OVT.getVectorElementCount() == NVT.getVectorElementCount() &&
+             "Invalid promote type for extract_vector_elt!");
+      assert(NewEltVT.bitsGT(EltVT) && "Cannot promote to a smaller type!");
+
+      // Ensure the extract VT is not smaller than its operand element type.
+      EVT ResVT = Node->getValueType(0);
+      EVT ExtractVT = ResVT.bitsGT(NewEltVT) ? ResVT : NewEltVT;
+
+      // Perform the extract using a promoted input vector.
+      SDValue V = DAG.getNode(ISD::ANY_EXTEND, SL, NVT, Vec);
+      SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ExtractVT, V, Idx);
+
+      Results.push_back(DAG.getNode(ISD::TRUNCATE, SL, ResVT, Elt));
+      break;
+    }
+
     // Handle bitcasts to a different vector type with the same total bit size.
     //
     // e.g. v2i64 = extract_vector_elt x:v2i64, y:i32
@@ -4768,13 +4793,10 @@
     MVT MidVT = getPromotedVectorElementType(TLI, EltVT, NewEltVT);
     unsigned NewEltsPerOldElt = MidVT.getVectorNumElements();
 
-    SDValue Idx = Node->getOperand(1);
     EVT IdxVT = Idx.getValueType();
-    SDLoc SL(Node);
     SDValue Factor = DAG.getConstant(NewEltsPerOldElt, SL, IdxVT);
     SDValue NewBaseIdx = DAG.getNode(ISD::MUL, SL, IdxVT, Idx, Factor);
-
-    SDValue CastVec = DAG.getNode(ISD::BITCAST, SL, NVT, Node->getOperand(0));
+    SDValue CastVec = DAG.getNode(ISD::BITCAST, SL, NVT, Vec);
 
     SmallVector<SDValue, 8> NewOps;
     for (unsigned I = 0; I < NewEltsPerOldElt; ++I) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3327,6 +3327,9 @@
     SDValue InVec = Op.getOperand(0);
     SDValue EltNo = Op.getOperand(1);
     EVT VecVT = InVec.getValueType();
+    // computeKnownBits not yet implemented for scalable vectors.
+    if (VecVT.isScalableVector())
+      break;
     const unsigned EltBitWidth = VecVT.getScalarSizeInBits();
     const unsigned NumSrcElts = VecVT.getVectorNumElements();
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1028,12 +1028,17 @@
       setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
       setOperationAction(ISD::TRUNCATE, VT, Custom);
 
+      // Extract predicate lanes by promoting to the matching integer vector.
+      MVT PromotedVT = getPromotedVTForPredicate(VT);
+      setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Promote);
+      AddPromotedToType(ISD::EXTRACT_VECTOR_ELT, VT, PromotedVT);
+
       // There are no legal MVT::nxv16f## based types.
       if (VT != MVT::nxv16i1) {
         setOperationAction(ISD::SINT_TO_FP, VT, Promote);
-        AddPromotedToType(ISD::SINT_TO_FP, VT, getPromotedVTForPredicate(VT));
+        AddPromotedToType(ISD::SINT_TO_FP, VT, PromotedVT);
         setOperationAction(ISD::UINT_TO_FP, VT, Promote);
-        AddPromotedToType(ISD::UINT_TO_FP, VT, getPromotedVTForPredicate(VT));
+        AddPromotedToType(ISD::UINT_TO_FP, VT, PromotedVT);
       }
     }
 
diff --git a/llvm/test/CodeGen/AArch64/sve-extract-element.ll b/llvm/test/CodeGen/AArch64/sve-extract-element.ll
--- a/llvm/test/CodeGen/AArch64/sve-extract-element.ll
+++ b/llvm/test/CodeGen/AArch64/sve-extract-element.ll
@@ -482,4 +482,53 @@
   ret i64 %c
 }
 
+define i1 @test_lane0_16xi1(<vscale x 16 x i1> %a) #0 {
+; CHECK-LABEL: test_lane0_16xi1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.b, p0/z, #1 // =0x1
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %b = extractelement <vscale x 16 x i1> %a, i32 0
+  ret i1 %b
+}
+
+define i1 @test_lane9_8xi1(<vscale x 8 x i1> %a) #0 {
+; CHECK-LABEL: test_lane9_8xi1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.h, p0/z, #1 // =0x1
+; CHECK-NEXT:    mov z0.h, z0.h[9]
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %b = extractelement <vscale x 8 x i1> %a, i32 9
+  ret i1 %b
+}
+
+define i1 @test_lanex_4xi1(<vscale x 4 x i1> %a, i32 %x) #0 {
+; CHECK-LABEL: test_lanex_4xi1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-NEXT:    sxtw x8, w0
+; CHECK-NEXT:    whilels p1.s, xzr, x8
+; CHECK-NEXT:    mov z0.s, p0/z, #1 // =0x1
+; CHECK-NEXT:    lastb w8, p1, z0.s
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %b = extractelement <vscale x 4 x i1> %a, i32 %x
+  ret i1 %b
+}
+
+define i1 @test_lane4_2xi1(<vscale x 2 x i1> %a) #0 {
+; CHECK-LABEL: test_lane4_2xi1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.d, p0/z, #1 // =0x1
+; CHECK-NEXT:    mov z0.d, z0.d[4]
+; CHECK-NEXT:    fmov x8, d0
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %b = extractelement <vscale x 2 x i1> %a, i32 4
+  ret i1 %b
+}
+
 attributes #0 = { "target-features"="+sve" }