diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1122,7 +1122,8 @@
   if (Subtarget.hasVInstructions())
     setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
                          ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
-                         ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR});
+                         ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR,
+                         ISD::CONCAT_VECTORS});
   if (Subtarget.hasVendorXTHeadMemPair())
     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
   if (Subtarget.useRVVForFixedLengthVectors())
@@ -11018,6 +11019,136 @@
   return tryFoldSelectIntoOp(N, DAG, FalseVal, TrueVal, /*Swapped*/true);
 }

+// If we're concatenating a series of vector loads like
+//   concat_vectors (load v4i8, p+0), (load v4i8, p+n), (load v4i8, p+n*2) ...
+// then we can turn it into a strided load by widening the vector elements:
+//   vlse32 p, stride=n
+static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
+                                            const RISCVSubtarget &Subtarget,
+                                            const RISCVTargetLowering &TLI) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+
+  // Only perform this combine on legal MVTs.
+  if (!TLI.isTypeLegal(VT))
+    return SDValue();
+
+  // TODO: Potentially extend this to scalable vectors
+  if (VT.isScalableVector())
+    return SDValue();
+
+  auto *BaseLd = dyn_cast<LoadSDNode>(N->getOperand(0));
+  if (!BaseLd || !BaseLd->isSimple() || !ISD::isNormalLoad(BaseLd) ||
+      !SDValue(BaseLd, 0).hasOneUse())
+    return SDValue();
+
+  EVT BaseLdVT = BaseLd->getValueType(0);
+  SDValue BasePtr = BaseLd->getBasePtr();
+
+  // Go through the loads and check that they're strided
+  SDValue CurPtr = BasePtr;
+  SDValue Stride;
+  Align Align = BaseLd->getAlign();
+
+  for (SDValue Op : N->ops().drop_front()) {
+    auto *Ld = dyn_cast<LoadSDNode>(Op);
+    if (!Ld || !Ld->isSimple() || !Op.hasOneUse() ||
+        Ld->getChain() != BaseLd->getChain() || !ISD::isNormalLoad(Ld) ||
+        Ld->getValueType(0) != BaseLdVT)
+      return SDValue();
+
+    SDValue Ptr = Ld->getBasePtr();
+    // Check that each load's pointer is (add CurPtr, Stride)
+    if (Ptr.getOpcode() != ISD::ADD || Ptr.getOperand(0) != CurPtr)
+      return SDValue();
+    SDValue Offset = Ptr.getOperand(1);
+    if (!Stride)
+      Stride = Offset;
+    else if (Offset != Stride)
+      return SDValue();
+
+    // The common alignment is the most restrictive (smallest) of all the loads
+    Align = std::min(Align, Ld->getAlign());
+
+    CurPtr = Ptr;
+  }
+
+  // A special case is if the stride is exactly the width of one of the loads,
+  // in which case it's contiguous and can be combined into a regular vle
+  // without changing the element size
+  if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
+      ConstStride &&
+      ConstStride->getZExtValue() == BaseLdVT.getFixedSizeInBits() / 8) {
+    MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
+        BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(),
+        VT.getStoreSize(), Align);
+    // Can't do the combine if the load isn't naturally aligned with the element
+    // type
+    if (!TLI.allowsMemoryAccessForAlignment(*DAG.getContext(),
+                                            DAG.getDataLayout(), VT, *MMO))
+      return SDValue();
+
+    SDValue WideLoad = DAG.getLoad(VT, DL, BaseLd->getChain(), BasePtr, MMO);
+    for (SDValue Ld : N->ops())
+      DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), WideLoad);
+    return WideLoad;
+  }
+
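+  // Otherwise, reinterpret each chunk as one wide scalar element so the whole
+  // concat becomes a strided load of wide elements; e.g. four v4i16 loads with
+  // stride %s become a single v4i64 vlse64 with the same stride.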
+  // Get the widened scalar type, e.g. v4i8 -> i32
+  unsigned WideScalarBitWidth =
+      BaseLdVT.getScalarSizeInBits() * BaseLdVT.getVectorNumElements();
+  MVT WideScalarVT = MVT::getIntegerVT(WideScalarBitWidth);
+
+  // Get the vector type for the strided load, e.g. 4 x v4i8 -> v4i32
+  MVT WideVecVT = MVT::getVectorVT(WideScalarVT, N->getNumOperands());
+  if (!TLI.isTypeLegal(WideVecVT))
+    return SDValue();
+
+  MVT ContainerVT = TLI.getContainerForFixedLengthVector(WideVecVT);
+  SDValue VL =
+      getDefaultVLOps(WideVecVT, ContainerVT, DL, DAG, Subtarget).second;
+  SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
+  SDValue IntID =
+      DAG.getTargetConstant(Intrinsic::riscv_vlse, DL, Subtarget.getXLenVT());
+  SDValue Ops[] = {BaseLd->getChain(),
+                   IntID,
+                   DAG.getUNDEF(ContainerVT),
+                   BasePtr,
+                   Stride,
+                   VL};
+
+  uint64_t MemSize;
+  if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride))
+    // total size (in bytes) = (elsize * n) + (stride - elsize) * (n-1)
+    //                       = elsize + stride * (n-1)
+    MemSize = WideScalarVT.getStoreSize().getFixedValue() +
+              ConstStride->getSExtValue() * (N->getNumOperands() - 1);
+  else
+    // If Stride isn't constant, then we can't know how much it will load
+    MemSize = MemoryLocation::UnknownSize;
+
+  MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
+      BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(), MemSize,
+      Align);
+
+  // Can't do the combine if the common alignment isn't naturally aligned with
+  // the new element type
+  if (!TLI.allowsMemoryAccessForAlignment(*DAG.getContext(),
+                                          DAG.getDataLayout(), WideVecVT,
+                                          *MMO))
+    return SDValue();
+
+  SDValue StridedLoad = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs,
+                                                Ops, WideVecVT, MMO);
+  for (SDValue Ld : N->ops())
+    DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), StridedLoad);
+
+  // Note: Perform the bitcast before the convertFromScalableVector so we have
+  // balanced pairs of convertFromScalable/convertToScalable
+  SDValue Res = DAG.getBitcast(
+      TLI.getContainerForFixedLengthVector(VT.getSimpleVT()), StridedLoad);
+  return convertFromScalableVector(VT, Res, DAG, Subtarget);
+}
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -11525,6 +11656,10 @@
       return Gather;
     break;
   }
+  case ISD::CONCAT_VECTORS:
+    if (SDValue V = performCONCAT_VECTORSCombine(N, DAG, Subtarget, *this))
+      return V;
+    break;
  case RISCVISD::VMV_V_X_VL: {
    // Tail agnostic VMV.V.X only demands the vector element bitwidth from the
    // scalar input.
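For reference, the shape of IR this combine targets and the lowering the updated tests below check for is roughly the following. This is illustrative only and not part of the patch; the function name is made up, and it is written in the style of the existing tests, with the expected assembly mirroring the strided_runtime CHECK lines below.

define void @illustrative_2xv4i16(ptr %x, ptr %z, i64 %s) {
  %a = load <4 x i16>, ptr %x
  %b.gep = getelementptr i8, ptr %x, i64 %s
  %b = load <4 x i16>, ptr %b.gep
  %c = shufflevector <4 x i16> %a, <4 x i16> %b, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
  store <8 x i16> %c, ptr %z
  ret void
}
; The two v4i16 chunks are widened to i64 elements, so this is expected to
; lower to a single strided load:
;   vsetivli zero, 2, e64, m1, ta, ma
;   vlse64.v v8, (a0), a2
;   vse64.v v8, (a1)
;   ret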
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
@@ -6,12 +6,8 @@
 define void @widen_2xv4i16(ptr %x, ptr %z) {
 ; CHECK-LABEL: widen_2xv4i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    addi a0, a0, 8
-; CHECK-NEXT:    vle16.v v9, (a0)
 ; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 4
+; CHECK-NEXT:    vle16.v v8, (a0)
 ; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, ptr %x
@@ -74,20 +70,8 @@
 define void @widen_4xv4i16(ptr %x, ptr %z) {
 ; CHECK-LABEL: widen_4xv4i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    addi a2, a0, 8
-; CHECK-NEXT:    vle16.v v10, (a2)
-; CHECK-NEXT:    addi a2, a0, 16
-; CHECK-NEXT:    vle16.v v12, (a2)
-; CHECK-NEXT:    addi a0, a0, 24
-; CHECK-NEXT:    vle16.v v14, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e16, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 12, e16, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v12, 8
 ; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v14, 12
+; CHECK-NEXT:    vle16.v v8, (a0)
 ; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, ptr %x
@@ -108,13 +92,10 @@
 define void @strided_constant(ptr %x, ptr %z) {
 ; CHECK-LABEL: strided_constant:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    addi a0, a0, 16
-; CHECK-NEXT:    vle16.v v9, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 4
-; CHECK-NEXT:    vse16.v v8, (a1)
+; CHECK-NEXT:    li a2, 16
+; CHECK-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, ptr %x
   %b.gep = getelementptr i8, ptr %x, i64 16
@@ -128,13 +109,10 @@
 define void @strided_constant_64(ptr %x, ptr %z) {
 ; CHECK-LABEL: strided_constant_64:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    addi a0, a0, 64
-; CHECK-NEXT:    vle16.v v9, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 4
-; CHECK-NEXT:    vse16.v v8, (a1)
+; CHECK-NEXT:    li a2, 64
+; CHECK-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, ptr %x
   %b.gep = getelementptr i8, ptr %x, i64 64
@@ -219,13 +197,9 @@
 define void @strided_runtime(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: strided_runtime:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle16.v v9, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 4
-; CHECK-NEXT:    vse16.v v8, (a1)
+; CHECK-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, ptr %x
   %b.gep = getelementptr i8, ptr %x, i64 %s
@@ -238,21 +212,9 @@
 define void @strided_runtime_4xv4i16(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: strided_runtime_4xv4i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle16.v v10, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle16.v v12, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle16.v v14, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e16, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 12, e16, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v12, 8
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v14, 12
-; CHECK-NEXT:    vse16.v v8, (a1)
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, ptr %x
   %b.gep = getelementptr i8, ptr %x, i64 %s
@@ -324,21 +286,9 @@
 define void @strided_runtime_4xv4f16(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: strided_runtime_4xv4f16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle16.v v10, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle16.v v12, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle16.v v14, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e16, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 12, e16, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v12, 8
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v14, 12
-; CHECK-NEXT:    vse16.v v8, (a1)
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %a = load <4 x half>, ptr %x
   %b.gep = getelementptr i8, ptr %x, i64 %s
@@ -357,21 +307,9 @@
 define void @strided_runtime_4xv2f32(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: strided_runtime_4xv2f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT:    vle32.v v8, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle32.v v10, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle32.v v12, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle32.v v14, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 2
-; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v12, 4
-; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v14, 6
-; CHECK-NEXT:    vse32.v v8, (a1)
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %a = load <2 x float>, ptr %x
   %b.gep = getelementptr i8, ptr %x, i64 %s
@@ -406,17 +344,13 @@
   ret void
 }

-; Shouldn't be combined because the loads have different alignments
+; Should use the most restrictive common alignment
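+; The combine takes the smallest (most restrictive) alignment of the loads and
+; only bails out if that common alignment isn't allowed for the widened access.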
 define void @strided_mismatched_alignments(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: strided_mismatched_alignments:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle16.v v9, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 4
-; CHECK-NEXT:    vse16.v v8, (a1)
+; CHECK-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, ptr %x, align 8
   %b.gep = getelementptr i8, ptr %x, i64 %s
@@ -429,13 +363,9 @@
 define void @strided_ok_alignments_8(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: strided_ok_alignments_8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle16.v v9, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 4
-; CHECK-NEXT:    vse16.v v8, (a1)
+; CHECK-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, ptr %x, align 8
   %b.gep = getelementptr i8, ptr %x, i64 %s
@@ -448,13 +378,9 @@
 define void @strided_ok_alignments_16(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: strided_ok_alignments_16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
-; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    add a0, a0, a2
-; CHECK-NEXT:    vle16.v v9, (a0)
-; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 4
-; CHECK-NEXT:    vse16.v v8, (a1)
+; CHECK-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %a = load <4 x i16>, ptr %x, align 16
   %b.gep = getelementptr i8, ptr %x, i64 %s