Index: llvm/include/llvm/CodeGen/TargetLowering.h =================================================================== --- llvm/include/llvm/CodeGen/TargetLowering.h +++ llvm/include/llvm/CodeGen/TargetLowering.h @@ -210,6 +210,13 @@ TypeWidenVector, // This vector should be widened into a larger vector. TypePromoteFloat, // Replace this float with a larger one. TypeSoftPromoteHalf, // Soften half to i16 and use float to do arithmetic. + TypeScalarizeScalableVector, // This action is explicitly left unimplemented; + // the type legalizer reports a fatal error for + // it. While scalable types could in theory be + // legalized with a loop over the vscale * #lanes + // elements, that is non-trivial at SelectionDAG + // level, so such types should be widened or + // promoted instead. }; /// LegalizeKind holds the legalization kind that needs to happen to EVT @@ -412,7 +419,7 @@ virtual TargetLoweringBase::LegalizeTypeAction getPreferredVectorAction(MVT VT) const { // The default action for one element vectors is to scalarize - if (VT.getVectorNumElements() == 1) + if (VT.getVectorElementCount() == 1) return TypeScalarizeVector; // The default action for an odd-width vector is to widen. if (!VT.isPow2VectorType()) Index: llvm/include/llvm/Support/TypeSize.h =================================================================== --- llvm/include/llvm/Support/TypeSize.h +++ llvm/include/llvm/Support/TypeSize.h @@ -15,6 +15,7 @@ #ifndef LLVM_SUPPORT_TYPESIZE_H #define LLVM_SUPPORT_TYPESIZE_H +#include "llvm/Support/MathExtras.h" #include "llvm/Support/WithColor.h" #include @@ -49,6 +50,12 @@ bool operator!=(const ElementCount& RHS) const { return !(*this == RHS); } + bool operator==(unsigned RHS) const { return Min == RHS && !Scalable; } + bool operator!=(unsigned RHS) const { return !(*this == RHS); } + + ElementCount NextPowerOf2() const { + return ElementCount(llvm::NextPowerOf2(Min), Scalable); + } }; // This class is used to represent the size of types. 
If the type is of fixed Index: llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -344,6 +344,8 @@ return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT, BitConvertToInteger(GetScalarizedVector(InOp))); break; + case TargetLowering::TypeScalarizeScalableVector: + report_fatal_error("Scalarization of scalable vectors is not supported."); case TargetLowering::TypeSplitVector: { if (!NOutVT.isVector()) { // For example, i32 = BITCAST v2i16 on alpha. Convert the split Index: llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp +++ llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp @@ -245,6 +245,9 @@ case TargetLowering::TypeLegal: LLVM_DEBUG(dbgs() << "Legal result type\n"); break; + case TargetLowering::TypeScalarizeScalableVector: + report_fatal_error( + "Scalarization of scalable vectors is not supported."); // The following calls must take care of *all* of the node's results, // not just the illegal result they were passed (this includes results // with a legal type). Results can be remapped using ReplaceValueWith, @@ -307,6 +310,9 @@ case TargetLowering::TypeLegal: LLVM_DEBUG(dbgs() << "Legal operand\n"); continue; + case TargetLowering::TypeScalarizeScalableVector: + report_fatal_error( + "Scalarization of scalable vectors is not supported."); // The following calls must either replace all of the node's results // using ReplaceValueWith, and return "false"; or update the node's // operands in place, and return "true". 
Index: llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp +++ llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp @@ -83,6 +83,8 @@ Lo = DAG.getNode(ISD::BITCAST, dl, NOutVT, Lo); Hi = DAG.getNode(ISD::BITCAST, dl, NOutVT, Hi); return; + case TargetLowering::TypeScalarizeScalableVector: + report_fatal_error("Scalarization of scalable vectors is not supported."); case TargetLowering::TypeWidenVector: { assert(!(InVT.getVectorNumElements() & 1) && "Unsupported BITCAST"); InOp = GetWidenedVector(InOp); Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -1063,6 +1063,8 @@ Lo = DAG.getNode(ISD::BITCAST, dl, LoVT, Lo); Hi = DAG.getNode(ISD::BITCAST, dl, HiVT, Hi); return; + case TargetLowering::TypeScalarizeScalableVector: + report_fatal_error("Scalarization of scalable vectors is not supported."); } // In the general case, convert the input to an integer and split it by hand. 
@@ -3465,6 +3467,8 @@ switch (getTypeAction(InVT)) { case TargetLowering::TypeLegal: break; + case TargetLowering::TypeScalarizeScalableVector: + report_fatal_error("Scalarization of scalable vectors is not supported."); case TargetLowering::TypePromoteInteger: { // If the incoming type is a vector that is being promoted, then // we know that the elements are arranged differently and that we Index: llvm/lib/CodeGen/TargetLoweringBase.cpp =================================================================== --- llvm/lib/CodeGen/TargetLoweringBase.cpp +++ llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -823,9 +823,7 @@ "Promote may not follow Expand or Promote"); if (LA == TypeSplitVector) - return LegalizeKind(LA, - EVT::getVectorVT(Context, SVT.getVectorElementType(), - SVT.getVectorNumElements() / 2)); + return LegalizeKind(LA, SVT.getHalfNumVectorElementsVT()); if (LA == TypeScalarizeVector) return LegalizeKind(LA, SVT.getVectorElementType()); return LegalizeKind(LA, NVT); @@ -852,13 +850,16 @@ } // Handle vector types. - unsigned NumElts = VT.getVectorNumElements(); + ElementCount NumElts = VT.getVectorElementCount(); EVT EltVT = VT.getVectorElementType(); // Vectors with only one element are always scalarized. if (NumElts == 1) return LegalizeKind(TypeScalarizeVector, EltVT); + if (VT.getVectorElementCount() == ElementCount(1, true)) + report_fatal_error("Cannot legalize this vector"); // e.g. nxv1f128 + // Try to widen vector elements until the element type is a power of two and // promote it to a legal type later on, for example: // <3 x i8> -> <4 x i8> -> <4 x i32> @@ -866,7 +867,7 @@ // Vectors with a number of elements that is not a power of two are always // widened, for example <3 x i8> -> <4 x i8>. if (!VT.isPow2VectorType()) { - NumElts = (unsigned)NextPowerOf2(NumElts); + NumElts = NumElts.NextPowerOf2(); EVT NVT = EVT::getVectorVT(Context, EltVT, NumElts); return LegalizeKind(TypeWidenVector, NVT); } @@ -915,7 +916,7 @@ // If there is no wider legal type, split the vector. 
while (true) { // Round up to the next power of 2. - NumElts = (unsigned)NextPowerOf2(NumElts); + NumElts = NumElts.NextPowerOf2(); // If there is no simple vector type with this many elements then there // cannot be a larger legal vector type. Note that this assumes that @@ -938,7 +939,7 @@ } // Vectors with illegal element types are expanded. - EVT NVT = EVT::getVectorVT(Context, EltVT, VT.getVectorNumElements() / 2); + EVT NVT = EVT::getVectorVT(Context, EltVT, VT.getVectorElementCount() / 2); return LegalizeKind(TypeSplitVector, NVT); } @@ -1261,7 +1262,7 @@ continue; MVT EltVT = VT.getVectorElementType(); - unsigned NElts = VT.getVectorNumElements(); + ElementCount EC = VT.getVectorElementCount(); bool IsLegalWiderType = false; bool IsScalable = VT.isScalableVector(); LegalizeTypeAction PreferredAction = getPreferredVectorAction(VT); @@ -1278,8 +1279,7 @@ // Promote vectors of integers to vectors with the same number // of elements, with a wider element type. if (SVT.getScalarSizeInBits() > EltVT.getSizeInBits() && - SVT.getVectorNumElements() == NElts && - SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) { + SVT.getVectorElementCount() == EC && isTypeLegal(SVT)) { TransformToType[i] = SVT; RegisterTypeForVT[i] = SVT; NumRegistersForVT[i] = 1; @@ -1294,13 +1294,13 @@ } case TypeWidenVector: - if (isPowerOf2_32(NElts)) { + if (isPowerOf2_32(EC.Min)) { // Try to widen the vector. 
for (unsigned nVT = i + 1; nVT <= MVT::LAST_VECTOR_VALUETYPE; ++nVT) { MVT SVT = (MVT::SimpleValueType) nVT; - if (SVT.getVectorElementType() == EltVT - && SVT.getVectorNumElements() > NElts - && SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) { + if (SVT.getVectorElementType() == EltVT && + SVT.isScalableVector() == IsScalable && + SVT.getVectorElementCount().Min > EC.Min && isTypeLegal(SVT)) { TransformToType[i] = SVT; RegisterTypeForVT[i] = SVT; NumRegistersForVT[i] = 1; @@ -1344,10 +1344,12 @@ ValueTypeActions.setTypeAction(VT, TypeScalarizeVector); else if (PreferredAction == TypeSplitVector) ValueTypeActions.setTypeAction(VT, TypeSplitVector); + else if (EC.Min > 1) + ValueTypeActions.setTypeAction(VT, TypeSplitVector); else - // Set type action according to the number of elements. - ValueTypeActions.setTypeAction(VT, NElts == 1 ? TypeScalarizeVector - : TypeSplitVector); + ValueTypeActions.setTypeAction(VT, EC.Scalable + ? TypeScalarizeScalableVector + : TypeScalarizeVector); } else { TransformToType[i] = NVT; ValueTypeActions.setTypeAction(VT, TypeWidenVector); Index: llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp =================================================================== --- llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp +++ llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp @@ -17,9 +17,7 @@ #include "llvm/Target/TargetMachine.h" #include "gtest/gtest.h" -using namespace llvm; - -namespace { +namespace llvm { class AArch64SelectionDAGTest : public testing::Test { protected: @@ -41,8 +39,8 @@ return; TargetOptions Options; - TM = std::unique_ptr(static_cast( - T->createTargetMachine("AArch64", "", "", Options, None, None, + TM = std::unique_ptr(static_cast( + T->createTargetMachine("AArch64", "", "+sve", Options, None, None, CodeGenOpt::Aggressive))); if (!TM) return; @@ -69,6 +67,14 @@ DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr); } + TargetLoweringBase::LegalizeTypeAction getTypeAction(EVT VT) { + return 
DAG->getTargetLoweringInfo().getTypeAction(Context, VT); + } + + EVT getTypeToTransformTo(EVT VT) { + return DAG->getTargetLoweringInfo().getTypeToTransformTo(Context, VT); + } + LLVMContext Context; std::unique_ptr TM; std::unique_ptr M; @@ -377,4 +383,59 @@ EXPECT_EQ(SplatIdx, 0); } -} // end anonymous namespace +TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableMVT) { + if (!TM) + return; + + MVT VT = MVT::nxv4i64; + EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypeSplitVector); + ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector()); +} + +TEST_F(AArch64SelectionDAGTest, getTypeConversion_PromoteScalableMVT) { + if (!TM) + return; + + MVT VT = MVT::nxv2i32; + EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypePromoteInteger); + ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector()); +} + +TEST_F(AArch64SelectionDAGTest, getTypeConversion_NoScalarizeMVT_nxv1f32) { + if (!TM) + return; + + MVT VT = MVT::nxv1f32; + EXPECT_NE(getTypeAction(VT), TargetLoweringBase::TypeScalarizeVector); + ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector()); +} + +TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableEVT) { + if (!TM) + return; + + EVT VT = EVT::getVectorVT(Context, MVT::i64, 256, true); + EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypeSplitVector); + EXPECT_EQ(getTypeToTransformTo(VT), VT.getHalfNumVectorElementsVT(Context)); +} + +TEST_F(AArch64SelectionDAGTest, getTypeConversion_WidenScalableEVT) { + if (!TM) + return; + + EVT FromVT = EVT::getVectorVT(Context, MVT::i64, 6, true); + EVT ToVT = EVT::getVectorVT(Context, MVT::i64, 8, true); + + EXPECT_EQ(getTypeAction(FromVT), TargetLoweringBase::TypeWidenVector); + EXPECT_EQ(getTypeToTransformTo(FromVT), ToVT); +} + +TEST_F(AArch64SelectionDAGTest, getTypeConversion_NoScalarizeEVT_nxv1f128) { + if (!TM) + return; + + EVT FromVT = EVT::getVectorVT(Context, MVT::f128, 1, true); + EXPECT_DEATH(getTypeAction(FromVT), "Cannot legalize this vector"); +} + +} // end namespace 
llvm