diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -346,11 +346,6 @@
   if (Ops[0]->getType()->isVectorTy())
     return std::make_pair(nullptr, nullptr);
 
-  // Make sure we're in a loop and that has a pre-header and a single latch.
-  Loop *L = LI->getLoopFor(GEP->getParent());
-  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
-    return std::make_pair(nullptr, nullptr);
-
   Optional<unsigned> VecOperand;
   unsigned TypeScale = 0;
 
@@ -385,7 +380,37 @@
   if (VecIndex->getType() != VecIntPtrTy)
     return std::make_pair(nullptr, nullptr);
 
-  Value *Stride;
+  // Handle the non-recursive case. This is what we see if the vectorizer
+  // decides to use a scalar IV + vid on demand instead of a vector IV.
+  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
+  if (Start) {
+    assert(Stride);
+    Builder.SetInsertPoint(GEP);
+
+    // Replace the vector index with the scalar start and build a scalar GEP.
+    Ops[*VecOperand] = Start;
+    Type *SourceTy = GEP->getSourceElementType();
+    Value *BasePtr =
+        Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front());
+
+    // Convert stride to pointer size if needed.
+    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
+    assert(Stride->getType() == IntPtrTy && "Unexpected type");
+
+    // Scale the stride by the size of the indexed type.
+    if (TypeScale != 1)
+      Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));
+
+    auto P = std::make_pair(BasePtr, Stride);
+    StridedAddrs[GEP] = P;
+    return P;
+  }
+
+  // Make sure we're in a loop and that has a pre-header and a single latch.
+  Loop *L = LI->getLoopFor(GEP->getParent());
+  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
+    return std::make_pair(nullptr, nullptr);
+
   BinaryOperator *Inc;
   PHINode *BasePhi;
   if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
diff --git a/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll b/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
--- a/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
@@ -50,13 +50,11 @@
 ;
 ; RV64-LABEL: strided_store_zero_start:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a0, zero, e64, m1, ta, mu
-; RV64-NEXT:    vid.v v8
-; RV64-NEXT:    li a0, 56
-; RV64-NEXT:    vmul.vx v8, v8, a0
 ; RV64-NEXT:    addi a0, a1, 36
-; RV64-NEXT:    vmv.v.i v9, 0
-; RV64-NEXT:    vsoxei64.v v9, (a0), v8
+; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, mu
+; RV64-NEXT:    vmv.v.i v8, 0
+; RV64-NEXT:    li a1, 56
+; RV64-NEXT:    vsse64.v v8, (a0), a1
 ; RV64-NEXT:    ret
   %step = tail call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
   %gep = getelementptr inbounds %struct, ptr %p, <vscale x 1 x i64> %step, i32 6
@@ -89,14 +87,13 @@
 ;
 ; RV64-LABEL: strided_store_offset_start:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a2, zero, e64, m1, ta, mu
-; RV64-NEXT:    vid.v v8
-; RV64-NEXT:    vadd.vx v8, v8, a0
-; RV64-NEXT:    li a0, 56
-; RV64-NEXT:    vmul.vx v8, v8, a0
-; RV64-NEXT:    addi a0, a1, 36
-; RV64-NEXT:    vmv.v.i v9, 0
-; RV64-NEXT:    vsoxei64.v v9, (a0), v8
+; RV64-NEXT:    li a2, 56
+; RV64-NEXT:    mul a0, a0, a2
+; RV64-NEXT:    add a0, a1, a0
+; RV64-NEXT:    addi a0, a0, 36
+; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, mu
+; RV64-NEXT:    vmv.v.i v8, 0
+; RV64-NEXT:    vsse64.v v8, (a0), a2
 ; RV64-NEXT:    ret
   %step = tail call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
   %.splatinsert = insertelement <vscale x 1 x i64> poison, i64 %n, i64 0
@@ -123,10 +120,9 @@
 ; RV64-LABEL: stride_one_store:
 ; RV64:       # %bb.0:
 ; RV64-NEXT:    vsetvli a0, zero, e64, m1, ta, mu
-; RV64-NEXT:    vid.v v8
-; RV64-NEXT:    vsll.vi v8, v8, 3
-; RV64-NEXT:    vmv.v.i v9, 0
-; RV64-NEXT:    vsoxei64.v v9, (a1), v8
+; RV64-NEXT:    vmv.v.i v8, 0
+; RV64-NEXT:    li a0, 8
+; RV64-NEXT:    vsse64.v v8, (a1), a0
 ; RV64-NEXT:    ret
   %step = tail call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
   %gep = getelementptr inbounds i64, ptr %p, <vscale x 1 x i64> %step