diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3308,6 +3308,9 @@
     SDValue InVec = Op.getOperand(0);
     SDValue EltNo = Op.getOperand(1);
     EVT VecVT = InVec.getValueType();
+    // computeKnownBits not yet implemented for scalable vectors.
+    if (VecVT.isScalableVector())
+      break;
     const unsigned EltBitWidth = VecVT.getScalarSizeInBits();
     const unsigned NumSrcElts = VecVT.getVectorNumElements();
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -960,6 +960,7 @@
       setOperationAction(ISD::SETCC, VT, Custom);
       setOperationAction(ISD::TRUNCATE, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Legal);
+      setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
     }
   }
 }
@@ -9030,14 +9031,36 @@
 AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
                                                SelectionDAG &DAG) const {
   assert(Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT && "Unknown opcode!");
+  EVT VT = Op.getOperand(0).getValueType();
+
+  if (VT.getScalarType() == MVT::i1) {
+    // We can't directly extract from an SVE predicate; extend it first.
+    // (This isn't the only possible lowering, but it's straightforward.)
+    MVT ScalarVT;
+    auto Count = VT.getVectorElementCount();
+    if (Count.getKnownMinValue() == 2)
+      ScalarVT = MVT::i64;
+    else if (Count.getKnownMinValue() == 4)
+      ScalarVT = MVT::i32;
+    else if (Count.getKnownMinValue() == 8)
+      ScalarVT = MVT::i16;
+    else
+      ScalarVT = MVT::i8;
+    EVT VectorVT = EVT::getVectorVT(*DAG.getContext(), ScalarVT, Count);
+    SDLoc DL(Op);
+    SDValue Extend = DAG.getNode(ISD::SIGN_EXTEND, DL, VectorVT,
+                                 Op.getOperand(0));
+    MVT ExtractTy = ScalarVT == MVT::i64 ? MVT::i64 : MVT::i32;
+    SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractTy,
+                                  Extend, Op.getOperand(1));
+    return DAG.getSExtOrTrunc(Extract, DL, Op.getValueType());
+  }
 
   // Check for non-constant or out of range lane.
-  EVT VT = Op.getOperand(0).getValueType();
   ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(1));
   if (!CI || CI->getZExtValue() >= VT.getVectorNumElements())
     return SDValue();
 
-  // Insertion/extraction are legal for V128 types.
   if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 ||
       VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2153,6 +2153,22 @@
                  (DUP_ZR_D $index)),
              $src)>;
 
+  // Extract element from vector with zero index
+  def : Pat<(i32 (vector_extract (nxv16i8 ZPR:$vec), 0)),
+            (EXTRACT_SUBREG ZPR:$vec, ssub)>;
+  def : Pat<(i32 (vector_extract (nxv8i16 ZPR:$vec), 0)),
+            (EXTRACT_SUBREG ZPR:$vec, ssub)>;
+  def : Pat<(i32 (vector_extract (nxv4i32 ZPR:$vec), 0)),
+            (EXTRACT_SUBREG ZPR:$vec, ssub)>;
+  def : Pat<(i64 (vector_extract (nxv2i64 ZPR:$vec), 0)),
+            (EXTRACT_SUBREG ZPR:$vec, dsub)>;
+  def : Pat<(f16 (vector_extract (nxv8f16 ZPR:$vec), 0)),
+            (EXTRACT_SUBREG ZPR:$vec, hsub)>;
+  def : Pat<(f32 (vector_extract (nxv4f32 ZPR:$vec), 0)),
+            (EXTRACT_SUBREG ZPR:$vec, ssub)>;
+  def : Pat<(f64 (vector_extract (nxv2f64 ZPR:$vec), 0)),
+            (EXTRACT_SUBREG ZPR:$vec, dsub)>;
+
   // Extract element from vector with immediate index
   def : Pat<(i32 (vector_extract (nxv16i8 ZPR:$vec), sve_elm_idx_extdup_b:$index)),
             (EXTRACT_SUBREG (DUP_ZZI_B ZPR:$vec, sve_elm_idx_extdup_b:$index), ssub)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-extract-element.ll b/llvm/test/CodeGen/AArch64/sve-extract-element.ll
--- a/llvm/test/CodeGen/AArch64/sve-extract-element.ll
+++ b/llvm/test/CodeGen/AArch64/sve-extract-element.ll
@@ -8,7 +8,6 @@
 define i8 @test_lane0_16xi8(<vscale x 16 x i8> %a) {
 ; CHECK-LABEL: test_lane0_16xi8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z0.b, b0
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
   %b = extractelement <vscale x 16 x i8> %a, i32 0
@@ -18,7 +17,6 @@
 define i16 @test_lane0_8xi16(<vscale x 8 x i16> %a) {
 ; CHECK-LABEL: test_lane0_8xi16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z0.h, h0
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
   %b = extractelement <vscale x 8 x i16> %a, i32 0
@@ -28,7 +26,6 @@
 define i32 @test_lane0_4xi32(<vscale x 4 x i32> %a) {
 ; CHECK-LABEL: test_lane0_4xi32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z0.s, s0
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
   %b = extractelement <vscale x 4 x i32> %a, i32 0
@@ -38,7 +35,6 @@
 define i64 @test_lane0_2xi64(<vscale x 2 x i64> %a) {
 ; CHECK-LABEL: test_lane0_2xi64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z0.d, d0
 ; CHECK-NEXT:    fmov x0, d0
 ; CHECK-NEXT:    ret
   %b = extractelement <vscale x 2 x i64> %a, i32 0
@@ -183,7 +179,6 @@
 define i32 @test_lane64_4xi32(<vscale x 4 x i32> %a) {
 ; CHECK-LABEL: test_lane64_4xi32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z0.s, s0
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
   %b = extractelement <vscale x 4 x i32> %a, i32 undef
@@ -249,3 +244,52 @@
   %c = extractelement <vscale x 2 x i64> %b, i32 %y
   ret i64 %c
 }
+
+define i1 @test_lane0_16xi1(<vscale x 16 x i1> %a) {
+; CHECK-LABEL: test_lane0_16xi1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %b = extractelement <vscale x 16 x i1> %a, i32 0
+  ret i1 %b
+}
+
+define i1 @test_lane9_8xi1(<vscale x 8 x i1> %a) {
+; CHECK-LABEL: test_lane9_8xi1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z0.h, z0.h[9]
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %b = extractelement <vscale x 8 x i1> %a, i32 9
+  ret i1 %b
+}
+
+define i1 @test_lanex_4xi1(<vscale x 4 x i1> %a, i32 %x) {
+; CHECK-LABEL: test_lanex_4xi1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-NEXT:    sxtw x8, w0
+; CHECK-NEXT:    whilels p1.s, xzr, x8
+; CHECK-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    lastb w8, p1, z0.s
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %b = extractelement <vscale x 4 x i1> %a, i32 %x
+  ret i1 %b
+}
+
+define i1 @test_lane4_2xi1(<vscale x 2 x i1> %a) {
+; CHECK-LABEL: test_lane4_2xi1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.d, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z0.d, z0.d[4]
+; CHECK-NEXT:    fmov x8, d0
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %b = extractelement <vscale x 2 x i1> %a, i32 4
+  ret i1 %b
+}
diff --git a/llvm/test/CodeGen/AArch64/sve-insert-element.ll b/llvm/test/CodeGen/AArch64/sve-insert-element.ll
--- a/llvm/test/CodeGen/AArch64/sve-insert-element.ll
+++ b/llvm/test/CodeGen/AArch64/sve-insert-element.ll
@@ -182,9 +182,8 @@
 define <vscale x 16 x i8> @test_insert0_of_extract0_16xi8(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: test_insert0_of_extract0_16xi8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov z1.b, b1
-; CHECK-NEXT:    ptrue p0.b, vl1
 ; CHECK-NEXT:    fmov w8, s1
+; CHECK-NEXT:    ptrue p0.b, vl1
 ; CHECK-NEXT:    mov z0.b, p0/m, w8
 ; CHECK-NEXT:    ret
   %c = extractelement <vscale x 16 x i8> %b, i32 0
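
Note: the predicate handling added to LowerEXTRACT_VECTOR_ELT above never reads an i1 lane directly; it sign-extends the predicate to an integer vector whose element width matches the element count, extracts that lane as an integer, and truncates back to i1. A minimal IR-level sketch of the same idea is below; the function names are made up for illustration, and the real rewrite happens on SelectionDAG nodes rather than in IR.

; Source pattern: extract lane %x from a <vscale x 4 x i1> predicate.
define i1 @extract_pred_lane(<vscale x 4 x i1> %p, i32 %x) {
  %b = extractelement <vscale x 4 x i1> %p, i32 %x
  ret i1 %b
}

; Conceptual expansion performed by the new lowering: widen, extract, narrow.
define i1 @extract_pred_lane_expanded(<vscale x 4 x i1> %p, i32 %x) {
  %v = sext <vscale x 4 x i1> %p to <vscale x 4 x i32>   ; nxv4i1 -> nxv4i32 (all-ones per active lane)
  %e = extractelement <vscale x 4 x i32> %v, i32 %x      ; pick the requested lane
  %t = trunc i32 %e to i1                                ; narrow back to i1
  ret i1 %t
}

This matches the ISD::SIGN_EXTEND in the C++ above and explains the "mov z?.?, p0/z, #-1" followed by "and w0, w8, #0x1" sequences in the new sve-extract-element.ll tests.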