[CodeGen] Support extended (non-simple) scalable vector EVTs

Drop the "We don't support extended scalable types yet" restrictions from
EVT. getVectorVT (both overloads), changeVectorElementTypeToInteger,
isScalableVector and getVectorElementCount now handle extended scalable
vectors. They rely on the new private helpers isExtendedScalableVector and
getExtendedVectorElementCount. The scalable flag is carried through by using
the ElementCount-based getVectorVT / getExtendedVectorVT /
VectorType::get overloads instead of a bare element count.

SelectionDAGBuilder now derives the CONCAT_VECTORS result type from the
intermediate type's ElementCount, so a scalable intermediate type yields a
scalable result type.

diff --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -76,9 +76,7 @@
       MVT M = MVT::getVectorVT(VT.V, NumElements, IsScalable);
       if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE)
         return M;
-
-      assert(!IsScalable && "We don't support extended scalable types yet");
-      return getExtendedVectorVT(Context, VT, NumElements);
+      return getExtendedVectorVT(Context, VT, NumElements, IsScalable);
     }
 
     /// Returns the EVT that represents a vector EC.Min elements in length,
@@ -87,19 +85,15 @@
       MVT M = MVT::getVectorVT(VT.V, EC);
       if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE)
         return M;
-      assert (!EC.Scalable && "We don't support extended scalable types yet");
-      return getExtendedVectorVT(Context, VT, EC.Min);
+      return getExtendedVectorVT(Context, VT, EC);
     }
 
     /// Return a vector with the same number of elements as this vector, but
     /// with the element type converted to an integer type with the same
     /// bitwidth.
     EVT changeVectorElementTypeToInteger() const {
-      if (!isSimple()) {
-        assert (!isScalableVector() &&
-                "We don't support extended scalable types yet");
+      if (!isSimple())
         return changeExtendedVectorElementTypeToInteger();
-      }
       MVT EltTy = getSimpleVT().getVectorElementType();
       unsigned BitWidth = EltTy.getSizeInBits();
       MVT IntTy = MVT::getIntegerVT(BitWidth);
@@ -156,12 +150,7 @@
     /// Return true if this is a vector type where the runtime
     /// length is machine dependent
     bool isScalableVector() const {
-      // FIXME: We don't support extended scalable types yet, because the
-      // matching IR type doesn't exist. Once it has been added, this can
-      // be changed to call isExtendedScalableVector.
-      if (!isSimple())
-        return false;
-      return V.isScalableVector();
+      return isSimple() ? V.isScalableVector() : isExtendedScalableVector();
     }
 
     bool isFixedLengthVector() const {
@@ -300,9 +289,7 @@
       if (isSimple())
         return V.getVectorElementCount();
 
-      assert(!isScalableVector() &&
-             "We don't support extended scalable types yet");
-      return {getExtendedVectorNumElements(), false};
+      return getExtendedVectorElementCount();
     }
 
     /// Return the size of the specified value type in bits.
@@ -443,8 +430,10 @@
     EVT changeExtendedTypeToInteger() const;
     EVT changeExtendedVectorElementTypeToInteger() const;
     static EVT getExtendedIntegerVT(LLVMContext &C, unsigned BitWidth);
-    static EVT getExtendedVectorVT(LLVMContext &C, EVT VT,
-                                   unsigned NumElements);
+    static EVT getExtendedVectorVT(LLVMContext &C, EVT VT, unsigned NumElements,
+                                   bool IsScalable);
+    static EVT getExtendedVectorVT(LLVMContext &Context, EVT VT,
+                                   ElementCount EC);
     bool isExtendedFloatingPoint() const LLVM_READONLY;
     bool isExtendedInteger() const LLVM_READONLY;
     bool isExtendedScalarInteger() const LLVM_READONLY;
@@ -458,8 +447,10 @@
     bool isExtended1024BitVector() const LLVM_READONLY;
     bool isExtended2048BitVector() const LLVM_READONLY;
     bool isExtendedFixedLengthVector() const LLVM_READONLY;
+    bool isExtendedScalableVector() const LLVM_READONLY;
     EVT getExtendedVectorElementType() const;
     unsigned getExtendedVectorNumElements() const LLVM_READONLY;
+    ElementCount getExtendedVectorElementCount() const LLVM_READONLY;
     TypeSize getExtendedSizeInBits() const LLVM_READONLY;
   };
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -415,10 +415,13 @@
     // Build a vector with BUILD_VECTOR or CONCAT_VECTORS from the
     // intermediate operands.
     EVT BuiltVectorTy =
-        EVT::getVectorVT(*DAG.getContext(), IntermediateVT.getScalarType(),
-                         (IntermediateVT.isVector()
-                              ? IntermediateVT.getVectorNumElements() * NumParts
-                              : NumIntermediates));
+        IntermediateVT.isVector()
+            ? EVT::getVectorVT(
+                  *DAG.getContext(), IntermediateVT.getScalarType(),
+                  IntermediateVT.getVectorElementCount() * NumParts)
+            : EVT::getVectorVT(*DAG.getContext(),
+                               IntermediateVT.getScalarType(),
+                               NumIntermediates);
     Val = DAG.getNode(IntermediateVT.isVector() ? ISD::CONCAT_VECTORS
                                                 : ISD::BUILD_VECTOR,
                       DL, BuiltVectorTy, Ops);
diff --git a/llvm/lib/CodeGen/ValueTypes.cpp b/llvm/lib/CodeGen/ValueTypes.cpp
--- a/llvm/lib/CodeGen/ValueTypes.cpp
+++ b/llvm/lib/CodeGen/ValueTypes.cpp
@@ -22,7 +22,7 @@
 EVT EVT::changeExtendedVectorElementTypeToInteger() const {
   LLVMContext &Context = LLVMTy->getContext();
   EVT IntTy = getIntegerVT(Context, getScalarSizeInBits());
-  return getVectorVT(Context, IntTy, getVectorNumElements());
+  return getVectorVT(Context, IntTy, getVectorElementCount());
 }
 
 EVT EVT::getExtendedIntegerVT(LLVMContext &Context, unsigned BitWidth) {
@@ -32,10 +32,18 @@
   return VT;
 }
 
-EVT EVT::getExtendedVectorVT(LLVMContext &Context, EVT VT,
-                             unsigned NumElements) {
+EVT EVT::getExtendedVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements,
+                             bool IsScalable) {
   EVT ResultVT;
-  ResultVT.LLVMTy = VectorType::get(VT.getTypeForEVT(Context), NumElements);
+  ResultVT.LLVMTy =
+      VectorType::get(VT.getTypeForEVT(Context), NumElements, IsScalable);
+  assert(ResultVT.isExtended() && "Type is not extended!");
+  return ResultVT;
+}
+
+EVT EVT::getExtendedVectorVT(LLVMContext &Context, EVT VT, ElementCount EC) {
+  EVT ResultVT;
+  ResultVT.LLVMTy = VectorType::get(VT.getTypeForEVT(Context), EC);
   assert(ResultVT.isExtended() && "Type is not extended!");
   return ResultVT;
 }
@@ -96,6 +104,10 @@
   return isExtendedVector() && !cast<VectorType>(LLVMTy)->isScalable();
 }
 
+bool EVT::isExtendedScalableVector() const {
+  return isExtendedVector() && cast<VectorType>(LLVMTy)->isScalable();
+}
+
 EVT EVT::getExtendedVectorElementType() const {
   assert(isExtended() && "Type is not extended!");
   return EVT::getEVT(cast<VectorType>(LLVMTy)->getElementType());
@@ -106,6 +118,11 @@
   return cast<VectorType>(LLVMTy)->getNumElements();
 }
 
+ElementCount EVT::getExtendedVectorElementCount() const {
+  assert(isExtended() && "Type is not extended!");
+  return cast<VectorType>(LLVMTy)->getElementCount();
+}
+
 TypeSize EVT::getExtendedSizeInBits() const {
   assert(isExtended() && "Type is not extended!");
   if (IntegerType *ITy = dyn_cast<IntegerType>(LLVMTy))