diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -5558,7 +5558,10 @@ NewAddr = Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front()); auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts); - NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy)); + auto *SecondTy = GetElementPtrInst::getIndexedType( + SourceTy, makeArrayRef(Ops).drop_front()); + NewAddr = + Builder.CreateGEP(SecondTy, NewAddr, Constant::getNullValue(IndexTy)); } else { Value *Base = Ops[0]; Value *Index = Ops[FinalIndex]; @@ -5569,10 +5572,12 @@ Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy); Base = Builder.CreateGEP(SourceTy, Base, makeArrayRef(Ops).drop_front()); + SourceTy = GetElementPtrInst::getIndexedType( + SourceTy, makeArrayRef(Ops).drop_front()); } // Now create the GEP with scalar pointer and vector index. - NewAddr = Builder.CreateGEP(Base, Index); + NewAddr = Builder.CreateGEP(SourceTy, Base, Index); } } else if (!isa(Ptr)) { // Not a GEP, maybe its a splat and we can create a GEP to enable @@ -5588,7 +5593,16 @@ // Emit a vector GEP with a scalar pointer and all 0s vector index. Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType()); auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts); - NewAddr = Builder.CreateGEP(V, Constant::getNullValue(IndexTy)); + Type *ScalarTy; + if (cast(MemoryInst)->getIntrinsicID() == + Intrinsic::masked_gather) { + ScalarTy = MemoryInst->getType()->getScalarType(); + } else { + assert(cast(MemoryInst)->getIntrinsicID() == + Intrinsic::masked_scatter); + ScalarTy = MemoryInst->getOperand(0)->getType()->getScalarType(); + } + NewAddr = Builder.CreateGEP(ScalarTy, V, Constant::getNullValue(IndexTy)); } else { // Constant, SelectionDAGBuilder knows to check if its a splat. return false;