Index: llvm/lib/Transforms/Vectorize/VectorCombine.cpp =================================================================== --- llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -1134,14 +1134,14 @@ if (!match(&I, m_Load(m_Value(Ptr)))) return false; - auto *FixedVT = cast(I.getType()); + auto *VecTy = cast(I.getType()); auto *LI = cast(&I); const DataLayout &DL = I.getModule()->getDataLayout(); - if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT)) + if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy)) return false; InstructionCost OriginalCost = - TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(), + TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(), LI->getPointerAddressSpace()); InstructionCost ScalarizedCost = 0; @@ -1172,7 +1172,7 @@ LastCheckedInst = UI; } - auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT); + auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT); if (!ScalarIdx.isSafe()) { // TODO: Freeze index if it is safe to do so. ScalarIdx.discard(); @@ -1182,12 +1182,12 @@ auto *Index = dyn_cast(UI->getOperand(1)); TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; OriginalCost += - TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind, + TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind, Index ? Index->getZExtValue() : -1); ScalarizedCost += - TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(), + TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(), Align(1), LI->getPointerAddressSpace()); - ScalarizedCost += TTI.getAddressComputationCost(FixedVT->getElementType()); + ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType()); } if (ScalarizedCost >= OriginalCost) @@ -1200,12 +1200,12 @@ Value *Idx = EI->getOperand(1); Value *GEP = - Builder.CreateInBoundsGEP(FixedVT, Ptr, {Builder.getInt32(0), Idx}); + Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx}); auto *NewLoad = cast(Builder.CreateLoad( - FixedVT->getElementType(), GEP, EI->getName() + ".scalar")); + VecTy->getElementType(), GEP, EI->getName() + ".scalar")); Align ScalarOpAlignment = computeAlignmentAfterScalarization( - LI->getAlign(), FixedVT->getElementType(), Idx, DL); + LI->getAlign(), VecTy->getElementType(), Idx, DL); NewLoad->setAlignment(ScalarOpAlignment); replaceValue(*EI, *NewLoad); @@ -1727,9 +1727,6 @@ case Instruction::ShuffleVector: MadeChange |= widenSubvectorLoad(I); break; - case Instruction::Load: - MadeChange |= scalarizeLoadExtract(I); - break; default: break; } @@ -1743,6 +1740,8 @@ if (Opcode == Instruction::Store) MadeChange |= foldSingleElementStore(I); + if (isa(I.getType()) && Opcode == Instruction::Load) + MadeChange |= scalarizeLoadExtract(I); // If this is an early pipeline invocation of this pass, we are done. if (TryEarlyFoldsOnly) Index: llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll =================================================================== --- llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll +++ llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll @@ -15,8 +15,8 @@ define i32 @vscale_load_extract_idx_0(ptr %x) { ; CHECK-LABEL: @vscale_load_extract_idx_0( -; CHECK-NEXT: [[LV:%.*]] = load , ptr [[X:%.*]], align 16 -; CHECK-NEXT: [[R:%.*]] = extractelement [[LV]], i32 0 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds , ptr [[X:%.*]], i32 0, i32 0 +; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP1]], align 16 ; CHECK-NEXT: ret i32 [[R]] ; %lv = load , ptr %x @@ -61,8 +61,8 @@ define i32 @vscale_load_extract_idx_2(ptr %x) { ; CHECK-LABEL: @vscale_load_extract_idx_2( -; CHECK-NEXT: [[LV:%.*]] = load , ptr [[X:%.*]], align 16 -; CHECK-NEXT: [[R:%.*]] = extractelement [[LV]], i32 2 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds , ptr [[X:%.*]], i32 0, i32 2 +; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP1]], align 8 ; CHECK-NEXT: ret i32 [[R]] ; %lv = load , ptr %x @@ -142,9 +142,9 @@ ; CHECK-NEXT: entry: ; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4 ; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) -; CHECK-NEXT: [[LV:%.*]] = load , ptr [[X:%.*]], align 16 ; CHECK-NEXT: call void @maythrow() -; CHECK-NEXT: [[R:%.*]] = extractelement [[LV]], i64 [[IDX]] +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds , ptr [[X:%.*]], i32 0, i64 [[IDX]] +; CHECK-NEXT: [[R:%.*]] = load i32, ptr [[TMP0]], align 4 ; CHECK-NEXT: ret i32 [[R]] ; entry: