diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -9196,9 +9196,10 @@
 /// or upper half of the vector elements.
 static bool areExtractShuffleVectors(Value *Op1, Value *Op2) {
   auto areTypesHalfed = [](Value *FullV, Value *HalfV) {
-    auto *FullVT = cast<VectorType>(FullV->getType());
-    auto *HalfVT = cast<VectorType>(HalfV->getType());
-    return FullVT->getBitWidth() == 2 * HalfVT->getBitWidth();
+    auto *FullTy = FullV->getType(); // Tested for 2x HalfTy's bit size.
+    auto *HalfTy = HalfV->getType(); // Tested for half of FullTy's bit size.
+    return FullTy->getPrimitiveSizeInBits().getFixedSize() ==
+           2 * HalfTy->getPrimitiveSizeInBits().getFixedSize();
   };
 
   auto extractHalf = [](Value *FullV, Value *HalfV) {