Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -14024,9 +14024,12 @@ unsigned AArch64TargetLowering::getNumInterleavedAccesses( VectorType *VecTy, const DataLayout &DL, bool UseScalable) const { unsigned VecSize = 128; + unsigned ElSize = DL.getTypeSizeInBits(VecTy->getElementType()); + auto EC = VecTy->getElementCount(); if (UseScalable) VecSize = std::max(Subtarget->getMinSVEVectorSizeInBits(), 128u); - return std::max<unsigned>(1, (DL.getTypeSizeInBits(VecTy) + 127) / VecSize); + return std::max<unsigned>(1, + (EC.getKnownMinValue() * ElSize + 127) / VecSize); } MachineMemOperand::Flags @@ -14040,29 +14043,35 @@ bool AArch64TargetLowering::isLegalInterleavedAccessType( VectorType *VecTy, const DataLayout &DL, bool &UseScalable) const { - unsigned VecSize = DL.getTypeSizeInBits(VecTy); unsigned ElSize = DL.getTypeSizeInBits(VecTy->getElementType()); - unsigned NumElements = cast<FixedVectorType>(VecTy)->getNumElements(); - + auto EC = VecTy->getElementCount(); UseScalable = false; // Ensure that the predicate for this number of elements is available. - if (Subtarget->hasSVE() && !getSVEPredPatternFromNumElements(NumElements)) + if (Subtarget->hasSVE() && + !getSVEPredPatternFromNumElements(EC.getKnownMinValue())) return false; // Ensure the number of vector elements is greater than 1. - if (NumElements < 2) + if (EC.getKnownMinValue() < 2) return false; // Ensure the element type is legal.
if (ElSize != 8 && ElSize != 16 && ElSize != 32 && ElSize != 64) return false; + if (EC.isScalable()) { + if (EC.getKnownMinValue() * ElSize == 128) + return true; + return false; + } + + unsigned VecSize = DL.getTypeSizeInBits(VecTy); if (Subtarget->forceStreamingCompatibleSVE() || (Subtarget->useSVEForFixedLengthVectors() && (VecSize % Subtarget->getMinSVEVectorSizeInBits() == 0 || (VecSize < Subtarget->getMinSVEVectorSizeInBits() && - isPowerOf2_32(NumElements) && VecSize > 128)))) { + isPowerOf2_32(EC.getKnownMinValue()) && VecSize > 128)))) { UseScalable = true; return true; } Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -2469,21 +2469,29 @@ Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind, bool UseMaskForCond, bool UseMaskForGaps) { assert(Factor >= 2 && "Invalid interleave factor"); - auto *VecVTy = cast<FixedVectorType>(VecTy); + auto *VecVTy = cast<VectorType>(VecTy); if (!UseMaskForCond && !UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) { - unsigned NumElts = VecVTy->getNumElements(); + unsigned NumElts = VecVTy->getElementCount().getKnownMinValue(); auto *SubVecTy = - FixedVectorType::get(VecTy->getScalarType(), NumElts / Factor); + VectorType::get(VecVTy->getElementType(), + VecVTy->getElementCount().divideCoefficientBy(Factor)); // ldN/stN only support legal vector types of size 64 or 128 in bits. // Accesses having vector types that are a multiple of 128 bits can be // matched to more than one ldN/stN instruction.
bool UseScalable; if (NumElts % Factor == 0 && - TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable)) - return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable); + TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable)) { + unsigned Cost = + Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable); + // Deliberately do not move it to the beginning of this function + // as we want to execute as much code as possible for scalable vectors. + if (isa<ScalableVectorType>(VecTy)) + return InstructionCost::getInvalid(); + return Cost; + } } return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices, Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -6516,11 +6516,6 @@ InstructionCost LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, ElementCount VF) { - // TODO: Once we have support for interleaving with scalable vectors - // we can calculate the cost properly here. - if (VF.isScalable()) - return InstructionCost::getInvalid(); - Type *ValTy = getLoadStoreType(I); auto *VectorTy = cast<VectorType>(ToVectorTy(ValTy, VF)); unsigned AS = getLoadStoreAddressSpace(I);