diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1301,6 +1301,34 @@
     return;
   }
 
+  // insert_subvector(Op, SubVec, 0) where SubVec widens to the result type
+  // can be converted to a vselect.
+  if (IdxVal == 0 && VecVT.isScalableVector() &&
+      TLI.getTypeToTransformTo(*DAG.getContext(), SubVecVT) == VecVT) {
+    SDValue WidenedSubVec = GetWidenedVector(SubVec);
+    EVT CmpElementVT = MVT::i32;
+    EVT CmpVT = EVT::getVectorVT(*DAG.getContext(), CmpElementVT,
+                                 VecVT.getVectorElementCount());
+    EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                  VecVT.getVectorElementCount());
+
+    SDLoc dl(N);
+    SDValue Step = DAG.getStepVector(dl, CmpVT);
+    unsigned NumElements = SubVec.getValueType().getVectorMinNumElements();
+    SDValue SplatNumElements = DAG.getSplatVector(
+        CmpVT, dl, DAG.getVScale(dl, CmpElementVT, APInt(32, NumElements)));
+    SDValue Mask =
+        DAG.getSetCC(dl, MaskVT, Step, SplatNumElements, ISD::SETULT);
+    SDValue Select =
+        DAG.getNode(ISD::VSELECT, dl, VecVT, Mask, WidenedSubVec, Vec);
+
+    Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Lo.getValueType(), Select,
+                     DAG.getVectorIdxConstant(0, dl));
+    Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Hi.getValueType(), Select,
+                     DAG.getVectorIdxConstant(LoElems, dl));
+    return;
+  }
+
   // Spill the vector to the stack.
   // In cases where the vector is illegal it will be broken down into parts
   // and stored in parts - we should use the alignment for the smallest part.
@@ -4015,6 +4043,34 @@
 SDValue DAGTypeLegalizer::WidenVecRes_LOAD(SDNode *N) {
   LoadSDNode *LD = cast<LoadSDNode>(N);
   ISD::LoadExtType ExtType = LD->getExtensionType();
+  EVT VT = N->getValueType(0);
+
+  // FIXME: Figure out how to replace constant "2".
+  if (VT.isScalableVector() && VT.getVectorMinNumElements() % 2 != 0) {
+    // Convert load to masked load. Let MLOAD legalization handle widening.
+    // (We assume hardware with scalable vectors supports masked load/store.)
+    SDLoc dl(N);
+    EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+    EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                  VT.getVectorElementCount());
+    SDValue Mask = DAG.getAllOnesConstant(dl, MaskVT);
+    SDValue PassThru = DAG.getUNDEF(VT);
+
+    // Convert load to masked load. Let MLOAD legalization handle widening.
+    SDValue Res = DAG.getMaskedLoad(VT, dl, LD->getChain(), LD->getBasePtr(),
+                                    LD->getOffset(), Mask, PassThru,
+                                    LD->getMemoryVT(), LD->getMemOperand(),
+                                    LD->getAddressingMode(), ExtType);
+
+    // Legalize the chain result - switch anything that used the old chain to
+    // use the new one.
+    ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+
+    // Widen the result.
+    return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WidenVT,
+                       DAG.getUNDEF(WidenVT), Res,
+                       DAG.getVectorIdxConstant(0, dl));
+  }
 
   // A vector must always be stored in memory as-is, i.e. without any padding
   // between the elements, since various code depend on it, e.g. in the
@@ -4054,8 +4110,10 @@
 }
 
 SDValue DAGTypeLegalizer::WidenVecRes_MLOAD(MaskedLoadSDNode *N) {
+  assert(N->getAddressingMode() == ISD::UNINDEXED &&
+         "We shouldn't form indexed loads with illegal types");
 
-  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(),N->getValueType(0));
+  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDValue Mask = N->getMask();
   EVT MaskVT = Mask.getValueType();
   SDValue PassThru = GetWidenedVector(N->getPassThru());
@@ -4063,15 +4121,23 @@
   SDLoc dl(N);
 
   // The mask should be widened as well
-  EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
-                                    MaskVT.getVectorElementType(),
-                                    WidenVT.getVectorNumElements());
+  EVT WideMaskVT =
+      EVT::getVectorVT(*DAG.getContext(), MaskVT.getVectorElementType(),
+                       WidenVT.getVectorElementCount());
   Mask = ModifyToType(Mask, WideMaskVT, true);
 
-  SDValue Res = DAG.getMaskedLoad(
-      WidenVT, dl, N->getChain(), N->getBasePtr(), N->getOffset(), Mask,
-      PassThru, N->getMemoryVT(), N->getMemOperand(), N->getAddressingMode(),
-      ExtType, N->isExpandingLoad());
+  EVT MemVT = N->getMemoryVT();
+  EVT WideMemVT =
+      EVT::getVectorVT(*DAG.getContext(), MemVT.getVectorElementType(),
+                       WidenVT.getVectorElementCount());
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachineMemOperand *MemOp = MF.getMachineMemOperand(
+      N->getMemOperand(), 0, MemoryLocation::UnknownSize);
+
+  SDValue Res =
+      DAG.getMaskedLoad(WidenVT, dl, N->getChain(), N->getBasePtr(),
+                        N->getOffset(), Mask, PassThru, WideMemVT, MemOp,
+                        ISD::UNINDEXED, ExtType, N->isExpandingLoad());
   // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
@@ -4864,6 +4930,26 @@
       N->getConstantOperandVal(2) == 0)
     return SubVec;
 
+  if (SubVec.getValueType() == InVec.getValueType() &&
+      InVec.getValueType().isScalableVector() &&
+      N->getConstantOperandVal(2) == 0) {
+    EVT VT = InVec.getValueType();
+    EVT CmpElementVT = MVT::i32;
+    EVT CmpVT = EVT::getVectorVT(*DAG.getContext(), CmpElementVT,
+                                 VT.getVectorElementCount());
+    EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                  VT.getVectorElementCount());
+
+    SDLoc dl(N);
+    SDValue Step = DAG.getStepVector(dl, CmpVT);
+    unsigned NumElements = SubVec.getValueType().getVectorMinNumElements();
+    SDValue SplatNumElements = DAG.getSplatVector(
+        CmpVT, dl, DAG.getVScale(dl, CmpElementVT, APInt(32, NumElements)));
+    SDValue Mask =
+        DAG.getSetCC(dl, MaskVT, Step, SplatNumElements, ISD::SETULT);
+    return DAG.getNode(ISD::VSELECT, dl, VT, Mask, SubVec, InVec);
+  }
+
   report_fatal_error("Don't know how to widen the operands for "
                      "INSERT_SUBVECTOR");
 }
@@ -5542,8 +5628,8 @@
   if (InVT == NVT)
     return InOp;
 
-  unsigned InNumElts = InVT.getVectorNumElements();
-  unsigned WidenNumElts = NVT.getVectorNumElements();
+  unsigned InNumElts = InVT.getVectorMinNumElements();
+  unsigned WidenNumElts = NVT.getVectorMinNumElements();
   if (WidenNumElts > InNumElts && WidenNumElts % InNumElts == 0) {
     unsigned NumConcat = WidenNumElts / InNumElts;
     SmallVector<SDValue, 16> Ops(NumConcat);
@@ -5560,6 +5646,17 @@
     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NVT, InOp,
                        DAG.getVectorIdxConstant(0, dl));
 
+  if (NVT.isScalableVector()) {
+    if (WidenNumElts < InNumElts)
+      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NVT, InOp,
+                         DAG.getVectorIdxConstant(0, dl));
+
+    SDValue FillVal =
+        FillWithZeroes ? DAG.getConstant(0, dl, NVT) : DAG.getUNDEF(NVT);
+    return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, NVT, FillVal, InOp,
+                       DAG.getVectorIdxConstant(0, dl));
+  }
+
   // Fall back to extract and build.
   SmallVector<SDValue, 16> Ops(WidenNumElts);
   EVT EltVT = NVT.getVectorElementType();
diff --git a/llvm/test/CodeGen/AArch64/sve-split-load.ll b/llvm/test/CodeGen/AArch64/sve-split-load.ll
--- a/llvm/test/CodeGen/AArch64/sve-split-load.ll
+++ b/llvm/test/CodeGen/AArch64/sve-split-load.ll
@@ -36,6 +36,41 @@
   ret %load
 }
 
+define <vscale x 4 x i32> @load_widen_3i32(<vscale x 3 x i32>* %a) {
+; CHECK-LABEL: load_widen_3i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cntw x8
+; CHECK-NEXT:    index z0.s, #0, #1
+; CHECK-NEXT:    mov z1.s, w8
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z1.s, z0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 3 x i32>, <vscale x 3 x i32>* %a
+  %r = call <vscale x 4 x i32> @llvm.experimental.vector.insert.nxv4i32.nxv3i32(<vscale x 4 x i32> undef, <vscale x 3 x i32> %load, i64 0)
+  ret <vscale x 4 x i32> %r
+}
+
+define <vscale x 8 x i32> @load_widen_7i32(<vscale x 7 x i32>* %a) {
+; CHECK-LABEL: load_widen_7i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cntw x9
+; CHECK-NEXT:    cnth x8
+; CHECK-NEXT:    index z0.s, #0, #1
+; CHECK-NEXT:    mov z2.s, w9
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    mov z1.s, w8
+; CHECK-NEXT:    add z2.s, z0.s, z2.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z1.s, z0.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z1.s, z2.s
+; CHECK-NEXT:    ld1w { z0.s }, p1/z, [x0]
+; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x0, #1, mul vl]
+; CHECK-NEXT:    ret
+  %load = load <vscale x 7 x i32>, <vscale x 7 x i32>* %a
+  %r = call <vscale x 8 x i32> @llvm.experimental.vector.insert.nxv8i32.nxv7i32(<vscale x 8 x i32> undef, <vscale x 7 x i32> %load, i64 0)
+  ret <vscale x 8 x i32> %r
+}
+
 define <vscale x 32 x i16> @load_split_32i16(<vscale x 32 x i16>* %a) {
 ; CHECK-LABEL: load_split_32i16:
 ; CHECK:       // %bb.0:
@@ -136,11 +171,60 @@
   ret %load
 }
 
-declare <vscale x 32 x i8> @llvm.masked.load.nxv32i8(<vscale x 32 x i8>*, i32, <vscale x 32 x i1>, <vscale x 32 x i8>)
+define <vscale x 4 x i32> @masked_load_widen_3i32(<vscale x 3 x i32>* %a, <vscale x 4 x i1> %pg.wide) {
+; CHECK-LABEL: masked_load_widen_3i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cntw x8
+; CHECK-NEXT:    index z0.s, #0, #1
+; CHECK-NEXT:    mov z1.s, w8
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    cmphi p2.s, p1/z, z1.s, z0.s
+; CHECK-NEXT:    and p0.b, p1/z, p2.b, p0.b
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    ret
+  %pg = call <vscale x 3 x i1> @llvm.experimental.vector.extract.nxv3i1.nxv4i1(<vscale x 4 x i1> %pg.wide, i64 0)
+  %load = call <vscale x 3 x i32> @llvm.masked.load.nxv3i32(<vscale x 3 x i32> *%a, i32 1, <vscale x 3 x i1> %pg, <vscale x 3 x i32> undef)
+  %r = call <vscale x 4 x i32> @llvm.experimental.vector.insert.nxv4i32.nxv3i32(<vscale x 4 x i32> undef, <vscale x 3 x i32> %load, i64 0)
+  ret <vscale x 4 x i32> %r
+}
+
+define <vscale x 8 x i32> @masked_load_widen_7i32(<vscale x 7 x i32>* %a, <vscale x 8 x i1> %pg.wide) {
+; CHECK-LABEL: masked_load_widen_7i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cnth x8
+; CHECK-NEXT:    mov z1.s, w8
+; CHECK-NEXT:    cntw x8
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    index z0.s, #0, #1
+; CHECK-NEXT:    mov z2.s, w8
+; CHECK-NEXT:    cmphi p2.s, p1/z, z1.s, z0.s
+; CHECK-NEXT:    add z0.s, z0.s, z2.s
+; CHECK-NEXT:    cmphi p1.s, p1/z, z1.s, z0.s
+; CHECK-NEXT:    uzp1 p1.h, p2.h, p1.h
+; CHECK-NEXT:    ptrue p2.h
+; CHECK-NEXT:    and p0.b, p2/z, p1.b, p0.b
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    zip1 p2.h, p0.h, p1.h
+; CHECK-NEXT:    zip2 p0.h, p0.h, p1.h
+; CHECK-NEXT:    ld1w { z0.s }, p2/z, [x0]
+; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x0, #1, mul vl]
+; CHECK-NEXT:    ret
+  %pg = call <vscale x 7 x i1> @llvm.experimental.vector.extract.nxv7i1.nxv8i1(<vscale x 8 x i1> %pg.wide, i64 0)
+  %load = call <vscale x 7 x i32> @llvm.masked.load.nxv7i32(<vscale x 7 x i32> *%a, i32 1, <vscale x 7 x i1> %pg, <vscale x 7 x i32> undef)
+  %r = call <vscale x 8 x i32> @llvm.experimental.vector.insert.nxv8i32.nxv7i32(<vscale x 8 x i32> undef, <vscale x 7 x i32> %load, i64 0)
+  ret <vscale x 8 x i32> %r
+}
+
-declare <vscale x 32 x i16> @llvm.masked.load.nxv32i16(<vscale x 32 x i16>*, i32, <vscale x 32 x i1>, <vscale x 32 x i16>)
+declare <vscale x 32 x i8> @llvm.masked.load.nxv32i8(<vscale x 32 x i8>*, i32, <vscale x 32 x i1>, <vscale x 32 x i8>)
+declare <vscale x 32 x i16> @llvm.masked.load.nxv32i16(<vscale x 32 x i16>*, i32, <vscale x 32 x i1>, <vscale x 32 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
+declare <vscale x 3 x i32> @llvm.masked.load.nxv3i32(<vscale x 3 x i32>*, i32, <vscale x 3 x i1>, <vscale x 3 x i32>)
+declare <vscale x 7 x i32> @llvm.masked.load.nxv7i32(<vscale x 7 x i32>*, i32, <vscale x 7 x i1>, <vscale x 7 x i32>)
 declare <vscale x 8 x i32> @llvm.masked.load.nxv8i32(<vscale x 8 x i32>*, i32, <vscale x 8 x i1>, <vscale x 8 x i32>)
-
 declare <vscale x 8 x i64> @llvm.masked.load.nxv8i64(<vscale x 8 x i64>*, i32, <vscale x 8 x i1>, <vscale x 8 x i64>)
+declare <vscale x 4 x i32> @llvm.experimental.vector.insert.nxv4i32.nxv3i32(<vscale x 4 x i32>, <vscale x 3 x i32>, i64)
+declare <vscale x 8 x i32> @llvm.experimental.vector.insert.nxv8i32.nxv7i32(<vscale x 8 x i32>, <vscale x 7 x i32>, i64)
+declare <vscale x 3 x i1> @llvm.experimental.vector.extract.nxv3i1.nxv4i1(<vscale x 4 x i1>, i64)
+declare <vscale x 7 x i1> @llvm.experimental.vector.extract.nxv7i1.nxv8i1(<vscale x 8 x i1>, i64)