diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12894,10 +12894,9 @@
   SDValue BasePtr = BaseLd->getBasePtr();
 
-  // Go through the loads and check that they're strided
-  SDValue CurPtr = BasePtr;
-  SDValue Stride;
+  // Go through the loads and collect their pointers to check for a stride
+  SmallVector<SDValue> Ptrs;
+  Ptrs.push_back(BasePtr);
   Align Align = BaseLd->getAlign();
-
   for (SDValue Op : N->ops().drop_front()) {
     auto *Ld = dyn_cast<LoadSDNode>(Op);
     if (!Ld || !Ld->isSimple() || !Op.hasOneUse() ||
@@ -12905,27 +12904,54 @@
         Ld->getValueType(0) != BaseLdVT)
       return SDValue();
 
-    SDValue Ptr = Ld->getBasePtr();
-    // Check that each load's pointer is (add CurPtr, Stride)
-    if (Ptr.getOpcode() != ISD::ADD || Ptr.getOperand(0) != CurPtr)
-      return SDValue();
-    SDValue Offset = Ptr.getOperand(1);
-    if (!Stride)
-      Stride = Offset;
-    else if (Offset != Stride)
-      return SDValue();
+    Ptrs.push_back(Ld->getBasePtr());
 
     // The common alignment is the most restrictive (smallest) of all the loads
     Align = std::min(Align, Ld->getAlign());
+  }
 
-    CurPtr = Ptr;
+  // Returns the common offset if each pointer after the first is
+  // (add <previous pointer>, Offset), otherwise an empty SDValue.
+  auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
+    SDValue Stride;
+    for (unsigned I = 1, E = Ptrs.size(); I != E; ++I) {
+      SDValue Ptr = Ptrs[I];
+      // Check that each load's pointer is (add PrevPtr, Stride)
+      if (Ptr.getOpcode() != ISD::ADD || Ptr.getOperand(0) != Ptrs[I - 1])
+        return SDValue();
+      SDValue Offset = Ptr.getOperand(1);
+      if (!Stride)
+        Stride = Offset;
+      else if (Offset != Stride)
+        return SDValue();
+    }
+    return Stride;
+  };
+
+  SDValue Stride = matchForwardStrided(Ptrs);
+  if (!Stride) {
+    // The pointers may instead form a chain where each one is
+    // (add NextPtr, Stride). Reversing the list turns that into the forward
+    // shape, so the same matcher applies.
+    std::reverse(Ptrs.begin(), Ptrs.end());
+    Stride = matchForwardStrided(Ptrs);
+    // TODO: At this point, we've successfully matched a generalized gather
+    // load.  Maybe we should emit that, and then move the specialized
+    // strided matching into a DAG combine?
+    if (!Stride)
+      return SDValue();
+    // The vector is still loaded starting from BasePtr (the first load's
+    // address), so consecutive elements are -Stride bytes apart. Negate now
+    // so the contiguous special case and the MemSize computation below see
+    // the stride actually used by the access.
+    Stride = DAG.getNegative(Stride, DL, Stride.getValueType());
   }
 
   // A special case is if the stride is exactly the width of one of the loads,
   // in which case it's contiguous and can be combined into a regular vle
   // without changing the element size
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
       ConstStride &&
       ConstStride->getZExtValue() == BaseLdVT.getFixedSizeInBits() / 8) {
     MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
         BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(),
@@ -12970,7 +12996,8 @@
                    VL};
 
   uint64_t MemSize;
-  if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride))
+  if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
+      ConstStride && ConstStride->getSExtValue() >= 0)
     // total size = (elsize * n) + (stride - elsize) * (n-1)
     //            = elsize + stride * (n-1)
     MemSize = WideScalarVT.getSizeInBits() +
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
@@ -494,25 +494,15 @@
   ret void
 }
 
-; TODO: This is a strided load with a negative stride
+; This is a strided load with a negative stride
 define void @reverse_strided_constant_pos_4xv2f32(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: reverse_strided_constant_pos_4xv2f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, 64
-; CHECK-NEXT:    addi a3, a0, 128
-; CHECK-NEXT:    addi a4, a0, 192
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT:    vle32.v v8, (a4)
-; CHECK-NEXT:    vle32.v v10, (a3)
-; CHECK-NEXT:    vle32.v v12, (a2)
-; CHECK-NEXT:    vle32.v v14, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 2
-; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v12, 4
-; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v14, 6
-; CHECK-NEXT:    vse32.v v8, (a1)
+; CHECK-NEXT:    addi a0, a0, 192
+; CHECK-NEXT:    li a2, -64
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %x.1 = getelementptr i8, ptr %x, i64 64
   %x.2 = getelementptr i8, ptr %x.1, i64 64
@@ -531,21 +521,11 @@
 define void @reverse_strided_constant_neg_4xv2f32(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: reverse_strided_constant_neg_4xv2f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, -64
-; CHECK-NEXT:    addi a3, a0, -128
-; CHECK-NEXT:    addi a4, a0, -192
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT:    vle32.v v8, (a4)
-; CHECK-NEXT:    vle32.v v10, (a3)
-; CHECK-NEXT:    vle32.v v12, (a2)
-; CHECK-NEXT:    vle32.v v14, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 2
-; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v12, 4
-; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v14, 6
-; CHECK-NEXT:    vse32.v v8, (a1)
+; CHECK-NEXT:    addi a0, a0, -192
+; CHECK-NEXT:    li a2, 64
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %x.1 = getelementptr i8, ptr %x, i64 -64
   %x.2 = getelementptr i8, ptr %x.1, i64 -64
@@ -561,25 +541,17 @@
   ret void
 }
 
-; TODO: This is a strided load with a negative stride
+; This is a strided load with a negative stride
 define void @reverse_strided_runtime_4xv2f32(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: reverse_strided_runtime_4xv2f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    add a3, a0, a2
-; CHECK-NEXT:    add a4, a3, a2
-; CHECK-NEXT:    add a2, a4, a2
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT:    vle32.v v8, (a2)
-; CHECK-NEXT:    vle32.v v10, (a4)
-; CHECK-NEXT:    vle32.v v12, (a3)
-; CHECK-NEXT:    vle32.v v14, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 2
-; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v12, 4
-; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v14, 6
-; CHECK-NEXT:    vse32.v v8, (a1)
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a3, a2, a2
+; CHECK-NEXT:    add a0, a0, a3
+; CHECK-NEXT:    neg a2, a2
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %x.1 = getelementptr i8, ptr %x, i64 %s
   %x.2 = getelementptr i8, ptr %x.1, i64 %s