Index: llvm/include/llvm/CodeGen/ValueTypes.h
===================================================================
--- llvm/include/llvm/CodeGen/ValueTypes.h
+++ llvm/include/llvm/CodeGen/ValueTypes.h
@@ -75,9 +75,7 @@
       MVT M = MVT::getVectorVT(VT.V, NumElements, IsScalable);
       if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE)
         return M;
-
-      assert(!IsScalable && "We don't support extended scalable types yet");
-      return getExtendedVectorVT(Context, VT, NumElements);
+      return getExtendedVectorVT(Context, VT, NumElements, IsScalable);
     }
 
     /// Returns the EVT that represents a vector EC.Min elements in length,
@@ -86,19 +84,15 @@
       MVT M = MVT::getVectorVT(VT.V, EC);
       if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE)
         return M;
-      assert (!EC.Scalable && "We don't support extended scalable types yet");
-      return getExtendedVectorVT(Context, VT, EC.Min);
+      return getExtendedVectorVT(Context, VT, EC);
     }
 
     /// Return a vector with the same number of elements as this vector, but
     /// with the element type converted to an integer type with the same
     /// bitwidth.
     EVT changeVectorElementTypeToInteger() const {
-      if (!isSimple()) {
-        assert (!isScalableVector() &&
-                "We don't support extended scalable types yet");
+      if (!isSimple())
         return changeExtendedVectorElementTypeToInteger();
-      }
       MVT EltTy = getSimpleVT().getVectorElementType();
       unsigned BitWidth = EltTy.getSizeInBits();
       MVT IntTy = MVT::getIntegerVT(BitWidth);
@@ -155,12 +149,7 @@
     /// Return true if this is a vector type where the runtime
     /// length is machine dependent
     bool isScalableVector() const {
-      // FIXME: We don't support extended scalable types yet, because the
-      // matching IR type doesn't exist. Once it has been added, this can
-      // be changed to call isExtendedScalableVector.
-      if (!isSimple())
-        return false;
-      return V.isScalableVector();
+      return isSimple() ? V.isScalableVector() : isExtendedScalableVector();
     }
 
     /// Return true if this is a 16-bit vector type.
@@ -285,9 +274,7 @@
       if (isSimple())
         return V.getVectorElementCount();
 
-      assert(!isScalableVector() &&
-             "We don't support extended scalable types yet");
-      return {getExtendedVectorNumElements(), false};
+      return getExtendedVectorElementCount();
     }
 
     /// Return the size of the specified value type in bits.
@@ -428,8 +415,9 @@
     EVT changeExtendedTypeToInteger() const;
     EVT changeExtendedVectorElementTypeToInteger() const;
     static EVT getExtendedIntegerVT(LLVMContext &C, unsigned BitWidth);
-    static EVT getExtendedVectorVT(LLVMContext &C, EVT VT,
-                                   unsigned NumElements);
+    static EVT getExtendedVectorVT(LLVMContext &C, EVT VT, unsigned NumElements,
+                                   bool IsScalable);
+    static EVT getExtendedVectorVT(LLVMContext &C, EVT VT, ElementCount EC);
     bool isExtendedFloatingPoint() const LLVM_READONLY;
     bool isExtendedInteger() const LLVM_READONLY;
     bool isExtendedScalarInteger() const LLVM_READONLY;
@@ -442,8 +430,10 @@
     bool isExtended512BitVector() const LLVM_READONLY;
     bool isExtended1024BitVector() const LLVM_READONLY;
     bool isExtended2048BitVector() const LLVM_READONLY;
+    bool isExtendedScalableVector() const LLVM_READONLY;
     EVT getExtendedVectorElementType() const;
     unsigned getExtendedVectorNumElements() const LLVM_READONLY;
+    ElementCount getExtendedVectorElementCount() const LLVM_READONLY;
     TypeSize getExtendedSizeInBits() const LLVM_READONLY;
   };
 
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -415,10 +415,13 @@
   // Build a vector with BUILD_VECTOR or CONCAT_VECTORS from the
   // intermediate operands.
   EVT BuiltVectorTy =
-      EVT::getVectorVT(*DAG.getContext(), IntermediateVT.getScalarType(),
-                       (IntermediateVT.isVector()
-                            ? IntermediateVT.getVectorNumElements() * NumParts
-                            : NumIntermediates));
+      IntermediateVT.isVector()
+          ? EVT::getVectorVT(
+                *DAG.getContext(), IntermediateVT.getScalarType(),
+                IntermediateVT.getVectorElementCount() * NumParts)
+          : EVT::getVectorVT(*DAG.getContext(),
+                             IntermediateVT.getScalarType(),
+                             NumIntermediates);
   Val = DAG.getNode(IntermediateVT.isVector() ? ISD::CONCAT_VECTORS
                                               : ISD::BUILD_VECTOR,
                     DL, BuiltVectorTy, Ops);
Index: llvm/lib/CodeGen/ValueTypes.cpp
===================================================================
--- llvm/lib/CodeGen/ValueTypes.cpp
+++ llvm/lib/CodeGen/ValueTypes.cpp
@@ -22,7 +22,7 @@
 EVT EVT::changeExtendedVectorElementTypeToInteger() const {
   LLVMContext &Context = LLVMTy->getContext();
   EVT IntTy = getIntegerVT(Context, getScalarSizeInBits());
-  return getVectorVT(Context, IntTy, getVectorNumElements());
+  return getVectorVT(Context, IntTy, getVectorElementCount());
 }
 
 EVT EVT::getExtendedIntegerVT(LLVMContext &Context, unsigned BitWidth) {
@@ -32,10 +32,18 @@
   return VT;
 }
 
-EVT EVT::getExtendedVectorVT(LLVMContext &Context, EVT VT,
-                             unsigned NumElements) {
+EVT EVT::getExtendedVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements,
+                             bool IsScalable) {
   EVT ResultVT;
-  ResultVT.LLVMTy = VectorType::get(VT.getTypeForEVT(Context), NumElements);
+  ResultVT.LLVMTy =
+      VectorType::get(VT.getTypeForEVT(Context), NumElements, IsScalable);
+  assert(ResultVT.isExtended() && "Type is not extended!");
+  return ResultVT;
+}
+
+EVT EVT::getExtendedVectorVT(LLVMContext &Context, EVT VT, ElementCount EC) {
+  EVT ResultVT;
+  ResultVT.LLVMTy = VectorType::get(VT.getTypeForEVT(Context), EC);
   assert(ResultVT.isExtended() && "Type is not extended!");
   return ResultVT;
 }
@@ -92,6 +100,10 @@
   return isExtendedVector() && getExtendedSizeInBits() == 2048;
 }
 
+bool EVT::isExtendedScalableVector() const {
+  return isExtendedVector() && cast<VectorType>(LLVMTy)->isScalable();
+}
+
 EVT EVT::getExtendedVectorElementType() const {
   assert(isExtended() && "Type is not extended!");
   return EVT::getEVT(cast<VectorType>(LLVMTy)->getElementType());
@@ -102,6 +114,11 @@
   return cast<VectorType>(LLVMTy)->getNumElements();
 }
 
+ElementCount EVT::getExtendedVectorElementCount() const {
+  assert(isExtended() && "Type is not extended!");
+  return cast<VectorType>(LLVMTy)->getElementCount();
+}
+
 TypeSize EVT::getExtendedSizeInBits() const {
   assert(isExtended() && "Type is not extended!");
   if (IntegerType *ITy = dyn_cast<IntegerType>(LLVMTy))