Patch summary (text before the first "diff --git" line is ignored by patch tools).

  * LegalizeVectorTypes.cpp
      - When splitting an insert_subvector whose base vector is undef, whose
        index is 0 and whose subvector's transform-to type equals the result
        type, take the widened subvector and extract the Lo/Hi halves from it
        instead of falling through to the stack spill.
      - The conversion-widening code tracks element counts as ElementCount,
        does its divisibility checks on the known minimum values, and calls
        report_fatal_error before the scalar unrolling fallback when the
        widened type is scalable.
  * SelectionDAGBuilder.cpp
      - Copy-from-parts: try the same-size bitcast first; when element counts
        differ, extract a subvector that keeps the part element type, then
        fall through to any-extend/truncate if the types still differ.
      - Copy-to-parts: any-extend the elements when the built vector's element
        type is wider, then attempt widening to the part type.
  * TargetLoweringBase.cpp
      - Take the scalable breakdown path for every scalable vector, report a
        fatal error if the legal part type is not a vector, and compute
        NumIntermediates with divideCeil.
  * sve-breakdown-scalable-vectortype.ll
      - New tests wide_8i63, wide_7i63 and wide_7i31.

NOTE(review): line structure, indentation and the angle-bracket type/template
arguments were reconstructed from a whitespace-collapsed copy of this patch.
Hunk line counts match the hunk headers, but verify context whitespace and the
restored types against the target tree before applying.

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1301,6 +1301,18 @@
     return;
   }
 
+  // insert_subvector(undef, SubVec, 0) is used to represent widening a vector.
+  // Just return the widened vector.
+  if (IdxVal == 0 && Vec->getOpcode() == ISD::UNDEF &&
+      TLI.getTypeToTransformTo(*DAG.getContext(), SubVecVT) == VecVT) {
+    SDValue Widened = GetWidenedVector(SubVec);
+    Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Lo.getValueType(), Widened,
+                     DAG.getVectorIdxConstant(0, dl));
+    Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Hi.getValueType(), Widened,
+                     DAG.getVectorIdxConstant(LoElems, dl));
+    return;
+  }
+
   // Spill the vector to the stack.
   // In cases where the vector is illegal it will be broken down into parts
   // and stored in parts - we should use the alignment for the smallest part.
@@ -3527,7 +3539,7 @@
   SDLoc DL(N);
 
   EVT WidenVT = TLI.getTypeToTransformTo(Ctx, N->getValueType(0));
-  unsigned WidenNumElts = WidenVT.getVectorNumElements();
+  ElementCount WidenNumElts = WidenVT.getVectorElementCount();
 
   EVT InVT = InOp.getValueType();
 
@@ -3548,12 +3560,12 @@
 
   EVT InEltVT = InVT.getVectorElementType();
   EVT InWidenVT = EVT::getVectorVT(Ctx, InEltVT, WidenNumElts);
-  unsigned InVTNumElts = InVT.getVectorNumElements();
+  ElementCount InVTNumElts = InVT.getVectorElementCount();
 
   if (getTypeAction(InVT) == TargetLowering::TypeWidenVector) {
     InOp = GetWidenedVector(N->getOperand(0));
     InVT = InOp.getValueType();
-    InVTNumElts = InVT.getVectorNumElements();
+    InVTNumElts = InVT.getVectorElementCount();
     if (InVTNumElts == WidenNumElts) {
       if (N->getNumOperands() == 1)
         return DAG.getNode(Opcode, DL, WidenVT, InOp);
@@ -3578,9 +3590,10 @@
     // it an illegal type that might lead to repeatedly splitting the input
     // and then widening it. To avoid this, we widen the input only if
     // it results in a legal type.
-    if (WidenNumElts % InVTNumElts == 0) {
+    if (WidenNumElts.getKnownMinValue() % InVTNumElts.getKnownMinValue() == 0) {
       // Widen the input and call convert on the widened input vector.
-      unsigned NumConcat = WidenNumElts/InVTNumElts;
+      unsigned NumConcat =
+          WidenNumElts.getKnownMinValue() / InVTNumElts.getKnownMinValue();
       SmallVector<SDValue, 16> Ops(NumConcat, DAG.getUNDEF(InVT));
       Ops[0] = InOp;
       SDValue InVec = DAG.getNode(ISD::CONCAT_VECTORS, DL, InWidenVT, Ops);
@@ -3589,7 +3602,7 @@
       return DAG.getNode(Opcode, DL, WidenVT, InVec, N->getOperand(1), Flags);
     }
 
-    if (InVTNumElts % WidenNumElts == 0) {
+    if (InVTNumElts.getKnownMinValue() % WidenNumElts.getKnownMinValue() == 0) {
       SDValue InVal = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InWidenVT, InOp,
                                   DAG.getVectorIdxConstant(0, DL));
       // Extract the input and convert the shorten input vector.
@@ -3599,9 +3612,13 @@
     }
   }
 
+  if (WidenNumElts.isScalable())
+    report_fatal_error("Don't know how to legalize scalable vector conversion");
+
   // Otherwise unroll into some nasty scalar code and rebuild the vector.
   EVT EltVT = WidenVT.getVectorElementType();
-  SmallVector<SDValue, 16> Ops(WidenNumElts, DAG.getUNDEF(EltVT));
+  SmallVector<SDValue, 16> Ops(WidenNumElts.getFixedValue(),
+                               DAG.getUNDEF(EltVT));
   // Use the original element count so we don't do more scalar opts than
   // necessary.
   unsigned MinElts = N->getValueType(0).getVectorNumElements();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -399,29 +399,31 @@
     return Val;
 
   if (PartEVT.isVector()) {
+    // Vector/Vector bitcast.
+    if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
+      return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
+
     // If the element type of the source/dest vectors are the same, but the
     // parts vector has more elements than the value vector, then we have a
     // vector widening case (e.g. <2 x float> -> <4 x float>). Extract the
     // elements we want.
-    if (PartEVT.getVectorElementType() == ValueVT.getVectorElementType()) {
+    if (PartEVT.getVectorElementCount() != ValueVT.getVectorElementCount()) {
       assert((PartEVT.getVectorElementCount().getKnownMinValue() >
               ValueVT.getVectorElementCount().getKnownMinValue()) &&
              (PartEVT.getVectorElementCount().isScalable() ==
               ValueVT.getVectorElementCount().isScalable()) &&
              "Cannot narrow, it would be a lossy transformation");
-      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ValueVT, Val,
-                         DAG.getVectorIdxConstant(0, DL));
+      PartEVT =
+          EVT::getVectorVT(*DAG.getContext(), PartEVT.getVectorElementType(),
+                           ValueVT.getVectorElementCount());
+      Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartEVT, Val,
+                        DAG.getVectorIdxConstant(0, DL));
+      if (PartEVT == ValueVT)
+        return Val;
     }
 
-    // Vector/Vector bitcast.
-    if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
-      return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
-
-    assert(PartEVT.getVectorElementCount() == ValueVT.getVectorElementCount() &&
-      "Cannot handle this kind of promotion");
     // Promoted vector extract
     return DAG.getAnyExtOrTrunc(Val, DL, ValueVT);
-
   }
 
   // Trivial bitcast if the types are the same size and the destination
@@ -726,15 +728,19 @@
   } else if (ValueVT.getSizeInBits() == BuiltVectorTy.getSizeInBits()) {
     // Bitconvert vector->vector case.
     Val = DAG.getNode(ISD::BITCAST, DL, BuiltVectorTy, Val);
-  } else if (SDValue Widened =
-                 widenVectorToPartType(DAG, Val, DL, BuiltVectorTy)) {
-    Val = Widened;
-  } else if (BuiltVectorTy.getVectorElementType().bitsGE(
-                 ValueVT.getVectorElementType()) &&
-             BuiltVectorTy.getVectorElementCount() ==
-                 ValueVT.getVectorElementCount()) {
-    // Promoted vector extract
-    Val = DAG.getAnyExtOrTrunc(Val, DL, BuiltVectorTy);
+  } else {
+    if (BuiltVectorTy.getVectorElementType().bitsGT(
+            ValueVT.getVectorElementType())) {
+      // Integer promotion.
+      ValueVT = EVT::getVectorVT(*DAG.getContext(),
+                                 BuiltVectorTy.getVectorElementType(),
+                                 ValueVT.getVectorElementCount());
+      Val = DAG.getNode(ISD::ANY_EXTEND, DL, ValueVT, Val);
+    }
+
+    if (SDValue Widened = widenVectorToPartType(DAG, Val, DL, BuiltVectorTy)) {
+      Val = Widened;
+    }
   }
 
   assert(Val.getValueType() == BuiltVectorTy && "Unexpected vector value type");
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -1556,7 +1556,7 @@
 
   // Scalable vectors cannot be scalarized, so handle the legalisation of the
   // types like done elsewhere in SelectionDAG.
-  if (VT.isScalableVector() && !isPowerOf2_32(EltCnt.getKnownMinValue())) {
+  if (EltCnt.isScalable()) {
     LegalizeKind LK;
     EVT PartVT = VT;
     do {
@@ -1565,16 +1565,14 @@
       PartVT = LK.second;
     } while (LK.first != TypeLegal);
 
-    NumIntermediates = VT.getVectorElementCount().getKnownMinValue() /
-      PartVT.getVectorElementCount().getKnownMinValue();
+    if (!PartVT.isVector()) {
+      report_fatal_error(
+          "Don't know how to legalize this scalable vector type");
+    }
 
-    // FIXME: This code needs to be extended to handle more complex vector
-    // breakdowns, like nxv7i64 -> nxv8i64 -> 4 x nxv2i64. Currently the only
-    // supported cases are vectors that are broken down into equal parts
-    // such as nxv6i64 -> 3 x nxv2i64.
-    assert((PartVT.getVectorElementCount() * NumIntermediates) ==
-               VT.getVectorElementCount() &&
-           "Expected an integer multiple of PartVT");
+    NumIntermediates =
+        divideCeil(VT.getVectorElementCount().getKnownMinValue(),
+                   PartVT.getVectorElementCount().getKnownMinValue());
     IntermediateVT = PartVT;
     RegisterVT = getRegisterType(Context, IntermediateVT);
     return NumIntermediates;
diff --git a/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll b/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll
--- a/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll
+++ b/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll
@@ -689,3 +689,64 @@
 L2:
   ret <vscale x 8 x double> %illegal
 }
+
+define <vscale x 8 x i63> @wide_8i63(i1 %b, <vscale x 16 x i8> %legal, <vscale x 8 x i63> %illegal) nounwind {
+; CHECK-LABEL: wide_8i63:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    tbnz w0, #0, .LBB21_2
+; CHECK-NEXT:  // %bb.1: // %L2
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    mov z1.d, z2.d
+; CHECK-NEXT:    mov z2.d, z3.d
+; CHECK-NEXT:    mov z3.d, z4.d
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB21_2: // %L1
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl bar
+  br i1 %b, label %L1, label %L2
+L1:
+  call aarch64_sve_vector_pcs void @bar()
+  unreachable
+L2:
+  ret <vscale x 8 x i63> %illegal
+}
+
+define <vscale x 7 x i63> @wide_7i63(i1 %b, <vscale x 16 x i8> %legal, <vscale x 7 x i63> %illegal) nounwind {
+; CHECK-LABEL: wide_7i63:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    tbnz w0, #0, .LBB22_2
+; CHECK-NEXT:  // %bb.1: // %L2
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    mov z1.d, z2.d
+; CHECK-NEXT:    mov z2.d, z3.d
+; CHECK-NEXT:    mov z3.d, z4.d
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB22_2: // %L1
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl bar
+  br i1 %b, label %L1, label %L2
+L1:
+  call aarch64_sve_vector_pcs void @bar()
+  unreachable
+L2:
+  ret <vscale x 7 x i63> %illegal
+}
+
+define <vscale x 7 x i31> @wide_7i31(i1 %b, <vscale x 16 x i8> %legal, <vscale x 7 x i31> %illegal) nounwind {
+; CHECK-LABEL: wide_7i31:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    tbnz w0, #0, .LBB23_2
+; CHECK-NEXT:  // %bb.1: // %L2
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    mov z1.d, z2.d
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB23_2: // %L1
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl bar
+  br i1 %b, label %L1, label %L2
+L1:
+  call aarch64_sve_vector_pcs void @bar()
+  unreachable
+L2:
+  ret <vscale x 7 x i31> %illegal
+}