diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -399,29 +399,31 @@ return Val; if (PartEVT.isVector()) { + // Vector/Vector bitcast. + if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits()) + return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val); + // If the element type of the source/dest vectors are the same, but the // parts vector has more elements than the value vector, then we have a // vector widening case (e.g. <2 x float> -> <4 x float>). Extract the // elements we want. - if (PartEVT.getVectorElementType() == ValueVT.getVectorElementType()) { + if (PartEVT.getVectorElementCount() != ValueVT.getVectorElementCount()) { assert((PartEVT.getVectorElementCount().getKnownMinValue() > ValueVT.getVectorElementCount().getKnownMinValue()) && (PartEVT.getVectorElementCount().isScalable() == ValueVT.getVectorElementCount().isScalable()) && "Cannot narrow, it would be a lossy transformation"); - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ValueVT, Val, - DAG.getVectorIdxConstant(0, DL)); + PartEVT = + EVT::getVectorVT(*DAG.getContext(), PartEVT.getVectorElementType(), + ValueVT.getVectorElementCount()); + Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartEVT, Val, + DAG.getVectorIdxConstant(0, DL)); + if (PartEVT == ValueVT) + return Val; } - // Vector/Vector bitcast. - if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits()) - return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val); - - assert(PartEVT.getVectorElementCount() == ValueVT.getVectorElementCount() && - "Cannot handle this kind of promotion"); // Promoted vector extract return DAG.getAnyExtOrTrunc(Val, DL, ValueVT); - } // Trivial bitcast if the types are the same size and the destination @@ -726,15 +728,19 @@ } else if (ValueVT.getSizeInBits() == BuiltVectorTy.getSizeInBits()) { // Bitconvert vector->vector case. Val = DAG.getNode(ISD::BITCAST, DL, BuiltVectorTy, Val); - } else if (SDValue Widened = - widenVectorToPartType(DAG, Val, DL, BuiltVectorTy)) { - Val = Widened; - } else if (BuiltVectorTy.getVectorElementType().bitsGE( - ValueVT.getVectorElementType()) && - BuiltVectorTy.getVectorElementCount() == - ValueVT.getVectorElementCount()) { - // Promoted vector extract - Val = DAG.getAnyExtOrTrunc(Val, DL, BuiltVectorTy); + } else { + if (BuiltVectorTy.getVectorElementType().bitsGT( + ValueVT.getVectorElementType())) { + // Integer promotion. + ValueVT = EVT::getVectorVT(*DAG.getContext(), + BuiltVectorTy.getVectorElementType(), + ValueVT.getVectorElementCount()); + Val = DAG.getNode(ISD::ANY_EXTEND, DL, ValueVT, Val); + } + + if (SDValue Widened = widenVectorToPartType(DAG, Val, DL, BuiltVectorTy)) { + Val = Widened; + } } assert(Val.getValueType() == BuiltVectorTy && "Unexpected vector value type"); diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp --- a/llvm/lib/CodeGen/TargetLoweringBase.cpp +++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -1556,7 +1556,7 @@ // Scalable vectors cannot be scalarized, so handle the legalisation of the // types like done elsewhere in SelectionDAG. - if (VT.isScalableVector() && !isPowerOf2_32(EltCnt.getKnownMinValue())) { + if (EltCnt.isScalable()) { LegalizeKind LK; EVT PartVT = VT; do { @@ -1565,16 +1565,14 @@ PartVT = LK.second; } while (LK.first != TypeLegal); - NumIntermediates = VT.getVectorElementCount().getKnownMinValue() / - PartVT.getVectorElementCount().getKnownMinValue(); + if (!PartVT.isVector()) { + report_fatal_error( + "Don't know how to legalize this scalable vector type"); + } - // FIXME: This code needs to be extended to handle more complex vector - // breakdowns, like nxv7i64 -> nxv8i64 -> 4 x nxv2i64. Currently the only - // supported cases are vectors that are broken down into equal parts - // such as nxv6i64 -> 3 x nxv2i64. - assert((PartVT.getVectorElementCount() * NumIntermediates) == - VT.getVectorElementCount() && - "Expected an integer multiple of PartVT"); + NumIntermediates = + divideCeil(VT.getVectorElementCount().getKnownMinValue(), + PartVT.getVectorElementCount().getKnownMinValue()); IntermediateVT = PartVT; RegisterVT = getRegisterType(Context, IntermediateVT); return NumIntermediates; diff --git a/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll b/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll --- a/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll +++ b/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll @@ -689,3 +689,64 @@ L2: ret %illegal } + +define @wide_8i63(i1 %b, %legal, %illegal) nounwind { +; CHECK-LABEL: wide_8i63: +; CHECK: // %bb.0: +; CHECK-NEXT: tbnz w0, #0, .LBB21_2 +; CHECK-NEXT: // %bb.1: // %L2 +; CHECK-NEXT: mov z0.d, z1.d +; CHECK-NEXT: mov z1.d, z2.d +; CHECK-NEXT: mov z2.d, z3.d +; CHECK-NEXT: mov z3.d, z4.d +; CHECK-NEXT: ret +; CHECK-NEXT: .LBB21_2: // %L1 +; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill +; CHECK-NEXT: bl bar + br i1 %b, label %L1, label %L2 +L1: + call aarch64_sve_vector_pcs void @bar() + unreachable +L2: + ret %illegal +} + +define @wide_7i63(i1 %b, %legal, %illegal) nounwind { +; CHECK-LABEL: wide_7i63: +; CHECK: // %bb.0: +; CHECK-NEXT: tbnz w0, #0, .LBB22_2 +; CHECK-NEXT: // %bb.1: // %L2 +; CHECK-NEXT: mov z0.d, z1.d +; CHECK-NEXT: mov z1.d, z2.d +; CHECK-NEXT: mov z2.d, z3.d +; CHECK-NEXT: mov z3.d, z4.d +; CHECK-NEXT: ret +; CHECK-NEXT: .LBB22_2: // %L1 +; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill +; CHECK-NEXT: bl bar + br i1 %b, label %L1, label %L2 +L1: + call aarch64_sve_vector_pcs void @bar() + unreachable +L2: + ret %illegal +} + +define @wide_7i31(i1 %b, %legal, %illegal) nounwind { +; CHECK-LABEL: wide_7i31: +; CHECK: // %bb.0: +; CHECK-NEXT: tbnz w0, #0, .LBB23_2 +; CHECK-NEXT: // %bb.1: // %L2 +; CHECK-NEXT: mov z0.d, z1.d +; CHECK-NEXT: mov z1.d, z2.d +; CHECK-NEXT: ret +; CHECK-NEXT: .LBB23_2: // %L1 +; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill +; CHECK-NEXT: bl bar + br i1 %b, label %L1, label %L2 +L1: + call aarch64_sve_vector_pcs void @bar() + unreachable +L2: + ret %illegal +}