Index: llvm/include/llvm/CodeGen/ValueTypes.h
===================================================================
--- llvm/include/llvm/CodeGen/ValueTypes.h
+++ llvm/include/llvm/CodeGen/ValueTypes.h
@@ -103,6 +103,17 @@
     return VecTy;
   }
 
+  /// Return a VT for a vector type whose attributes match ourselves,
+  /// except for the element type, which is chosen by the caller.
+  EVT changeVectorElementType(EVT EltVT) const {
+    if (!isSimple())
+      return changeExtendedVectorElementType(EltVT);
+    MVT VecTy = MVT::getVectorVT(EltVT.V, getVectorElementCount());
+    assert(VecTy.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE &&
+           "Simple vector VT not representable by simple integer vector VT!");
+    return VecTy;
+  }
+
   /// Return the type converted to an equivalently sized integer or vector
   /// with integer element type. Similar to changeVectorElementTypeToInteger,
   /// but also handles scalars.
@@ -432,6 +443,7 @@
     // These are all out-of-line to prevent users of this header file
     // from having a dependency on Type.h.
     EVT changeExtendedTypeToInteger() const;
+    EVT changeExtendedVectorElementType(EVT EltVT) const;
     EVT changeExtendedVectorElementTypeToInteger() const;
     static EVT getExtendedIntegerVT(LLVMContext &C, unsigned BitWidth);
     static EVT getExtendedVectorVT(LLVMContext &C, EVT VT, unsigned NumElements,
Index: llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -4324,6 +4324,31 @@
   SDLoc dl(N);
 
   SDValue BaseIdx = N->getOperand(1);
 
+  // TODO: We may be able to use this for types other than scalable
+  // vectors and fix those tests that expect BUILD_VECTOR to be used.
+  if (OutVT.isScalableVector()) {
+    SDValue InOp0 = N->getOperand(0);
+    EVT InVT = InOp0.getValueType();
+
+    // Promote the operands and see if this is handled by target lowering;
+    // otherwise, use the BUILD_VECTOR approach below.
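+    // e.g. extracting nxv8i8 from nxv16i8: the result promotes to nxv8i16
+    // and the operand to nxv16i16, so we can extract nxv8i16 from the
+    // promoted operand and any-extend it to the promoted result type,
+    // with no need for BUILD_VECTOR (which cannot build scalable vectors).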
+    if (getTypeAction(InVT) == TargetLowering::TypePromoteInteger) {
+      // Collect the (promoted) operands.
+      SDValue Ops[] = { GetPromotedInteger(InOp0), BaseIdx };
+
+      EVT PromEltVT = Ops[0].getValueType().getVectorElementType();
+      assert(PromEltVT.bitsLE(NOutVTElem) &&
+             "Promoted operand has an element type greater than result");
+
+      EVT ExtVT = NOutVT.changeVectorElementType(PromEltVT);
+      SDValue Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), ExtVT, Ops);
+      return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT, Ext);
+    }
+  }
+
+  if (OutVT.isScalableVector())
+    report_fatal_error("Unable to promote scalable types using BUILD_VECTOR");
+
   SDValue InOp0 = N->getOperand(0);
   if (getTypeAction(InOp0.getValueType()) == TargetLowering::TypePromoteInteger)
     InOp0 = GetPromotedInteger(N->getOperand(0));
Index: llvm/lib/CodeGen/ValueTypes.cpp
===================================================================
--- llvm/lib/CodeGen/ValueTypes.cpp
+++ llvm/lib/CodeGen/ValueTypes.cpp
@@ -26,6 +26,11 @@
                      isScalableVector());
 }
 
+EVT EVT::changeExtendedVectorElementType(EVT EltVT) const {
+  LLVMContext &Context = LLVMTy->getContext();
+  return getVectorVT(Context, EltVT, getVectorElementCount());
+}
+
 EVT EVT::getExtendedIntegerVT(LLVMContext &Context, unsigned BitWidth) {
   EVT VT;
   VT.LLVMTy = IntegerType::get(Context, BitWidth);
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -888,6 +888,9 @@
   void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
                           SelectionDAG &DAG) const override;
+  void ReplaceExtractSubVectorResults(SDNode *N,
+                                      SmallVectorImpl<SDValue> &Results,
+                                      SelectionDAG &DAG) const;
 
   bool shouldNormalizeToSelectSequence(LLVMContext &, EVT) const override;
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -901,6 +901,9 @@
       setOperationAction(ISD::SRA, VT, Custom);
       if (VT.getScalarType() == MVT::i1)
         setOperationAction(ISD::SETCC, VT, Custom);
+    } else {
+      for (auto VT : { MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32 })
+        setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
     }
   }
 
   setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -8559,6 +8562,9 @@
 
 SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op,
                                                       SelectionDAG &DAG) const {
+  assert(!Op.getValueType().isScalableVector() &&
+         "Unexpected scalable type for custom lowering EXTRACT_SUBVECTOR");
+
   EVT VT = Op.getOperand(0).getValueType();
   SDLoc dl(Op);
 
   // Just in case...
@@ -10661,7 +10667,45 @@
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
+  SelectionDAG &DAG = DCI.DAG;
   SDValue Src = N->getOperand(0);
+  unsigned Opc = Src->getOpcode();
+
+  // Zero/any extend of an unsigned unpack
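+  // e.g. and(uunpklo(X: nxv16i8), dup(0xFF)) -> uunpklo(X), since the
+  // unpack already zeroes the high half of each widened lane; any other
+  // mask is truncated and pushed down onto the unpack's operand instead.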
+  if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
+    SDValue UnpkOp = Src->getOperand(0);
+    SDValue Dup = N->getOperand(1);
+
+    if (Dup.getOpcode() != AArch64ISD::DUP)
+      return SDValue();
+
+    SDLoc DL(N);
+    ConstantSDNode *C = dyn_cast<ConstantSDNode>(Dup->getOperand(0));
+    if (!C)
+      return SDValue();
+    uint64_t ExtVal = C->getZExtValue();
+
+    // If the mask is fully covered by the unpack, we don't need to push
+    // a new AND onto the operand.
+    EVT EltTy = UnpkOp->getValueType(0).getVectorElementType();
+    if ((ExtVal == 0xFF && EltTy == MVT::i8) ||
+        (ExtVal == 0xFFFF && EltTy == MVT::i16) ||
+        (ExtVal == 0xFFFFFFFF && EltTy == MVT::i32))
+      return Src;
+
+    // Truncate to prevent a DUP with an over-wide constant.
+    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, EltTy, Dup->getOperand(0));
+
+    // Otherwise, make sure we propagate the AND to the operand
+    // of the unpack.
+    Dup = DAG.getNode(AArch64ISD::DUP, DL,
+                      UnpkOp->getValueType(0),
+                      DAG.getAnyExtOrTrunc(Trunc, DL, MVT::i32));
+
+    SDValue And = DAG.getNode(ISD::AND, DL,
+                              UnpkOp->getValueType(0), UnpkOp, Dup);
+
+    return DAG.getNode(Opc, DL, N->getValueType(0), And);
+  }
+
   SDValue Mask = N->getOperand(1);
 
   if (!Src.hasOneUse())
     return SDValue();
 
@@ -10671,7 +10715,7 @@
 
   // SVE load instructions perform an implicit zero-extend, which makes them
   // perfect candidates for combining.
-  switch (Src->getOpcode()) {
+  switch (Opc) {
   case AArch64ISD::LD1:
   case AArch64ISD::LDNF1:
   case AArch64ISD::LDFF1:
@@ -13252,9 +13296,41 @@
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
+  SDLoc DL(N);
   SDValue Src = N->getOperand(0);
   unsigned Opc = Src->getOpcode();
 
+  // Sign extend of an unsigned unpack -> signed unpack
+  if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
+
+    unsigned SOpc = Opc == AArch64ISD::UUNPKHI ? AArch64ISD::SUNPKHI
+                                               : AArch64ISD::SUNPKLO;
+
+    // Push the sign extend to the operand of the unpack.
+    // This is necessary where, for example, the operand of the unpack
+    // is another unpack:
+    // 4i32 sign_extend_inreg (4i32 uunpklo(8i16 uunpklo(16i8 opnd)), from 4i8)
+    // ->
+    // 4i32 sunpklo(8i16 sign_extend_inreg(8i16 uunpklo(16i8 opnd), from 8i8))
+    // ->
+    // 4i32 sunpklo(8i16 sunpklo(16i8 opnd))
+    SDValue ExtOp = Src->getOperand(0);
+    auto VT = cast<VTSDNode>(N->getOperand(1))->getVT();
+    EVT EltTy = VT.getVectorElementType();
+
+    assert((EltTy == MVT::i8 || EltTy == MVT::i16 || EltTy == MVT::i32) &&
+           "Sign extending from an invalid type");
+
+    EVT ExtVT = EVT::getVectorVT(*DAG.getContext(),
+                                 VT.getVectorElementType(),
+                                 VT.getVectorElementCount() * 2);
+
+    SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, ExtOp.getValueType(),
+                              ExtOp, DAG.getValueType(ExtVT));
+
+    return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
+  }
+
   // SVE load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates
   // for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes.
   unsigned NewOpc;
@@ -13741,6 +13817,40 @@
   return std::make_pair(Lo, Hi);
 }
 
+void AArch64TargetLowering::ReplaceExtractSubVectorResults(
+    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
+  SDValue In = N->getOperand(0);
+  EVT InVT = In.getValueType();
+
+  // Common code will handle these just fine.
+  if (!InVT.isScalableVector() || !InVT.isInteger())
+    return;
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+
+  // The following checks bail if this is not a halving operation.
+
+  ElementCount ResEC = VT.getVectorElementCount();
+
+  if (InVT.getVectorElementCount().Min != (ResEC.Min * 2))
+    return;
+
+  auto *CIndex = dyn_cast<ConstantSDNode>(N->getOperand(1));
+  if (!CIndex)
+    return;
+
+  unsigned Index = CIndex->getZExtValue();
+  if ((Index != 0) && (Index != ResEC.Min))
+    return;
+
+  unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI;
+  EVT ExtendedHalfVT = VT.widenIntegerVectorElementType(*DAG.getContext());
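+  // e.g. for nxv2i32 extract_subvector(nxv4i32 X, 0), UUNPKLO produces the
+  // low half of X widened to nxv2i64; the TRUNCATE back to the illegal
+  // nxv2i32 result is then legalised by promotion and typically folds away.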
+  SDValue Half = DAG.getNode(Opcode, DL, ExtendedHalfVT, N->getOperand(0));
+  Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
+}
+
 // Create an even/odd pair of X registers holding integer value V.
 static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
   SDLoc dl(V.getNode());
@@ -13893,6 +14003,9 @@
     Results.append({Pair, Result.getValue(2) /* Chain */});
     return;
   }
+  case ISD::EXTRACT_SUBVECTOR:
+    ReplaceExtractSubVectorResults(N, Results, DAG);
+    return;
   case ISD::INTRINSIC_WO_CHAIN: {
     EVT VT = N->getValueType(0);
     assert((VT == MVT::i8 || VT == MVT::i16) &&
Index: llvm/test/CodeGen/AArch64/sve-sext-zext.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-sext-zext.ll
+++ llvm/test/CodeGen/AArch64/sve-sext-zext.ll
@@ -186,3 +186,143 @@
   %r = zext <vscale x 2 x i32> %a to <vscale x 2 x i64>
   ret <vscale x 2 x i64> %r
 }
+
+; Extending to illegal types
+
+define <vscale x 16 x i16> @sext_b_to_h(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: sext_b_to_h:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sunpklo z2.h, z0.b
+; CHECK-NEXT:    sunpkhi z1.h, z0.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i16>
+  ret <vscale x 16 x i16> %ext
+}
+
+define <vscale x 8 x i32> @sext_h_to_s(<vscale x 8 x i16> %a) {
+; CHECK-LABEL: sext_h_to_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sunpklo z2.s, z0.h
+; CHECK-NEXT:    sunpkhi z1.s, z0.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %ext = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  ret <vscale x 8 x i32> %ext
+}
+
+define <vscale x 4 x i64> @sext_s_to_d(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: sext_s_to_d:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sunpklo z2.d, z0.s
+; CHECK-NEXT:    sunpkhi z1.d, z0.s
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %ext = sext <vscale x 4 x i32> %a to <vscale x 4 x i64>
+  ret <vscale x 4 x i64> %ext
+}
+
+define <vscale x 16 x i32> @sext_b_to_s(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: sext_b_to_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sunpklo z1.h, z0.b
+; CHECK-NEXT:    sunpkhi z3.h, z0.b
+; CHECK-NEXT:    sunpklo z0.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    sunpklo z2.s, z3.h
+; CHECK-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NEXT:    ret
+  %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  ret <vscale x 16 x i32> %ext
+}
+
+define <vscale x 16 x i64> @sext_b_to_d(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: sext_b_to_d:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sunpklo z1.h, z0.b
+; CHECK-NEXT:    sunpkhi z0.h, z0.b
+; CHECK-NEXT:    sunpklo z2.s, z1.h
+; CHECK-NEXT:    sunpkhi z3.s, z1.h
+; CHECK-NEXT:    sunpklo z5.s, z0.h
+; CHECK-NEXT:    sunpkhi z7.s, z0.h
+; CHECK-NEXT:    sunpklo z0.d, z2.s
+; CHECK-NEXT:    sunpkhi z1.d, z2.s
+; CHECK-NEXT:    sunpklo z2.d, z3.s
+; CHECK-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEXT:    sunpklo z4.d, z5.s
+; CHECK-NEXT:    sunpkhi z5.d, z5.s
+; CHECK-NEXT:    sunpklo z6.d, z7.s
+; CHECK-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NEXT:    ret
+  %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  ret <vscale x 16 x i64> %ext
+}
+
+define <vscale x 16 x i16> @zext_b_to_h(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: zext_b_to_h:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z2.h, z0.b
+; CHECK-NEXT:    uunpkhi z1.h, z0.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
+  ret <vscale x 16 x i16> %ext
+}
+
+define <vscale x 8 x i32> @zext_h_to_s(<vscale x 8 x i16> %a) {
+; CHECK-LABEL: zext_h_to_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z2.s, z0.h
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %ext = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  ret <vscale x 8 x i32> %ext
+}
+
+define <vscale x 4 x i64> @zext_s_to_d(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: zext_s_to_d:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z2.d, z0.s
+; CHECK-NEXT:    uunpkhi z1.d, z0.s
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %ext = zext <vscale x 4 x i32> %a to <vscale x 4 x i64>
+  ret <vscale x 4 x i64> %ext
+}
+
+define <vscale x 16 x i32> @zext_b_to_s(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: zext_b_to_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.h, z0.b
+; CHECK-NEXT:    uunpkhi z3.h, z0.b
+; CHECK-NEXT:    uunpklo z0.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    uunpklo z2.s, z3.h
+; CHECK-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEXT:    ret
+  %ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  ret <vscale x 16 x i32> %ext
+}
+
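+; Extending all the way from nxv16i8 to nxv16i64 (sext_b_to_d above,
+; zext_b_to_d below) takes three rounds of unpacks, with the result
+; split across eight Z registers (z0-z7).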
+define <vscale x 16 x i64> @zext_b_to_d(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: zext_b_to_d:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.h, z0.b
+; CHECK-NEXT:    uunpkhi z0.h, z0.b
+; CHECK-NEXT:    uunpklo z2.s, z1.h
+; CHECK-NEXT:    uunpkhi z3.s, z1.h
+; CHECK-NEXT:    uunpklo z5.s, z0.h
+; CHECK-NEXT:    uunpkhi z7.s, z0.h
+; CHECK-NEXT:    uunpklo z0.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z2.s
+; CHECK-NEXT:    uunpklo z2.d, z3.s
+; CHECK-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEXT:    uunpklo z4.d, z5.s
+; CHECK-NEXT:    uunpkhi z5.d, z5.s
+; CHECK-NEXT:    uunpklo z6.d, z7.s
+; CHECK-NEXT:    uunpkhi z7.d, z7.s
+; CHECK-NEXT:    ret
+  %ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  ret <vscale x 16 x i64> %ext
+}