Index: lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.h
+++ lib/Target/AArch64/AArch64ISelLowering.h
@@ -497,6 +497,7 @@
   SDValue LowerVectorOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFSINCOS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue BuildSDIVPow2(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
                         std::vector<SDNode *> *Created) const override;
Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -621,6 +621,18 @@
       setOperationAction(ISD::FTRUNC, Ty, Legal);
       setOperationAction(ISD::FROUND, Ty, Legal);
     }
+
+    // We support custom legalization of extended loads that we can load as a
+    // scalar and then extend in-register. This saves us from generating a
+    // separate load and insert for each element.
+    for (MVT VT : MVT::integer_vector_valuetypes()) {
+      setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i8, Custom);
+      setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v4i8, Custom);
+      setLoadExtAction(ISD::SEXTLOAD, VT, MVT::v2i16, Custom);
+      setLoadExtAction(ISD::EXTLOAD, VT, MVT::v2i8, Custom);
+      setLoadExtAction(ISD::EXTLOAD, VT, MVT::v4i8, Custom);
+      setLoadExtAction(ISD::EXTLOAD, VT, MVT::v2i16, Custom);
+    }
   }
 
   // Prefer likely predicted branches to selects on out-of-order cores.
@@ -2288,6 +2300,114 @@
   }
 }
 
+SDValue AArch64TargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
+
+  MVT RegVT = Op.getSimpleValueType();
+  assert(RegVT.isVector() && "We only custom lower vector sext loads");
+  assert(RegVT.isInteger() && "We only custom lower integer vector sext loads");
+
+  LoadSDNode *Ld = cast<LoadSDNode>(Op.getNode());
+  SDLoc dl(Ld);
+  EVT MemVT = Ld->getMemoryVT();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  unsigned RegSz = RegVT.getSizeInBits();
+
+  // The extension type must be any-extend or sign-extend.
+  ISD::LoadExtType Ext = Ld->getExtensionType();
+  assert((Ext == ISD::EXTLOAD || Ext == ISD::SEXTLOAD) &&
+         "Only anyext and sext are currently implemented");
+  assert(MemVT != RegVT && "Cannot extend to the same type");
+  assert(MemVT.isVector() && "Must load a vector from memory");
+
+  // The number of vector elements and their total size.
+  unsigned NumElems = RegVT.getVectorNumElements();
+  unsigned MemSz = MemVT.getSizeInBits();
+  assert(RegSz > MemSz && "Register size must be greater than the mem size");
+
+  // All sizes must be a power of two.
+  assert(isPowerOf2_32(RegSz * MemSz * NumElems) &&
+         "Non-power-of-two elements are not custom lowered");
+
+  // We attempt to load the original value using scalar loads. First, find the
+  // largest scalar type that divides the total loaded size.
+  MVT SclrLoadTy = MVT::i8;
+  for (MVT VT : MVT::integer_vector_valuetypes())
+    if (TLI.isTypeLegal(VT))
+      if (MemSz % VT.getScalarType().getSizeInBits() == 0)
+        SclrLoadTy = VT.getScalarType();
+
+  // Calculate the number of scalar loads that we need to perform in order to
+  // load our vector from memory.
+  unsigned NumLoads = MemSz / SclrLoadTy.getSizeInBits();
+
+  assert((Ext != ISD::SEXTLOAD || NumLoads == 1) &&
+         "Can only lower sext loads with a single scalar load!");
+
+  // We represent our vector as a sequence of elements that are the largest
+  // scalars that we can load.
+  EVT LoadUnitVecVT = EVT::getVectorVT(*DAG.getContext(), SclrLoadTy,
+                                       RegSz / SclrLoadTy.getSizeInBits());
+
+  // We represent the data using the same element type that is stored in
+  // memory.
+  EVT WideVecVT = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(),
+                                   RegSz / MemVT.getScalarSizeInBits());
+
+  assert(WideVecVT.getSizeInBits() == LoadUnitVecVT.getSizeInBits() &&
+         "Invalid vector type");
+
+  // We will perform the extensions using vector shuffles, so the widened
+  // vector type must be legal.
+  assert(TLI.isTypeLegal(WideVecVT) &&
+         "We only lower types that form legal widened vector types");
+
+  SmallVector<SDValue, 8> Chains;
+  SDValue Ptr = Ld->getBasePtr();
+  SDValue Increment = DAG.getConstant(SclrLoadTy.getSizeInBits() / 8, dl,
+                                      TLI.getPointerTy(DAG.getDataLayout()));
+  SDValue Res = DAG.getUNDEF(LoadUnitVecVT);
+
+  // Perform the individual scalar loads.
+  for (unsigned i = 0; i < NumLoads; ++i) {
+    SDValue ScalarLoad =
+        DAG.getLoad(SclrLoadTy, dl, Ld->getChain(), Ptr, Ld->getPointerInfo(),
+                    Ld->isVolatile(), Ld->isNonTemporal(), Ld->isInvariant(),
+                    Ld->getAlignment());
+    Chains.push_back(ScalarLoad.getValue(1));
+    Res = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, LoadUnitVecVT, Res,
+                      ScalarLoad, DAG.getIntPtrConstant(i, dl));
+    Ptr = DAG.getNode(ISD::ADD, dl, Ptr.getValueType(), Ptr, Increment);
+  }
+
+  SDValue TF = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Chains);
+
+  // Bitcast the loaded value to a vector of the original element type, in the
+  // size of the target vector type.
+  SDValue SlicedVec = DAG.getBitcast(WideVecVT, Res);
+  unsigned SizeRatio = RegSz / MemSz;
+
+  // Sign extend the vector. This will be legalized to a shuffle and shifts.
+  if (Ext == ISD::SEXTLOAD) {
+    SDValue Shuff = DAG.getSignExtendVectorInReg(SlicedVec, dl, RegVT);
+    DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), TF);
+    return Shuff;
+  }
+
+  // If we are not sign extending, just shuffle the loaded elements into their
+  // final locations.
+  SmallVector<int, 8> ShuffleVec(NumElems * SizeRatio, -1);
+  for (unsigned i = 0; i != NumElems; ++i)
+    ShuffleVec[i * SizeRatio] = i;
+
+  SDValue Shuff = DAG.getVectorShuffle(WideVecVT, dl, SlicedVec,
+                                       DAG.getUNDEF(WideVecVT), &ShuffleVec[0]);
+
+  // Finally, bitcast to the requested result type.
+  Shuff = DAG.getBitcast(RegVT, Shuff);
+  DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), TF);
+  return Shuff;
+}
+
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
                                               SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -2391,6 +2511,8 @@
     return LowerMUL(Op, DAG);
   case ISD::INTRINSIC_WO_CHAIN:
     return LowerINTRINSIC_WO_CHAIN(Op, DAG);
+  case ISD::LOAD:
+    return LowerLOAD(Op, DAG);
   }
 }
 
Index: test/CodeGen/AArch64/neon-truncStore-extLoad.ll
===================================================================
--- test/CodeGen/AArch64/neon-truncStore-extLoad.ll
+++ test/CodeGen/AArch64/neon-truncStore-extLoad.ll
@@ -29,29 +29,26 @@
   ret void
 }
 
-; A vector LoadExt can not be selected.
-; Test a vector load IR and a sext/zext IR can be selected correctly.
-define <4 x i32> @loadSExt.v4i8(<4 x i8>* %ref) {
-; CHECK-LABEL: loadSExt.v4i8:
-; CHECK: ldrsb
-  %a = load <4 x i8>, <4 x i8>* %ref
-  %conv = sext <4 x i8> %a to <4 x i32>
-  ret <4 x i32> %conv
+define <2 x i8> @loadExt.v2i8(<2 x i8>* %ref) {
+; CHECK-LABEL: loadExt.v2i8:
+; CHECK: ld1 { [[REG:v[0-9]+]].h }[0], [x0]
+; CHECK: ins [[REG]].b[4], [[REG]].b[1]
+  %a = load <2 x i8>, <2 x i8>* %ref
+  ret <2 x i8> %a
 }
 
-define <4 x i32> @loadZExt.v4i8(<4 x i8>* %ref) {
-; CHECK-LABEL: loadZExt.v4i8:
-; CHECK: ldrb
+define <4 x i8> @loadExt.v4i8(<4 x i8>* %ref) {
+; CHECK-LABEL: loadExt.v4i8:
+; CHECK: ld1 { [[REG:v[0-9]+]].s }[0], [x0]
+; CHECK: zip1 {{v[0-9]+}}.8b, [[REG]].8b, {{v[0-9]+}}.8b
   %a = load <4 x i8>, <4 x i8>* %ref
-  %conv = zext <4 x i8> %a to <4 x i32>
-  ret <4 x i32> %conv
+  ret <4 x i8> %a
 }
 
-define i32 @loadExt.i32(<4 x i8>* %ref) {
-; CHECK-LABEL: loadExt.i32:
-; CHECK: ldrb
-  %a = load <4 x i8>, <4 x i8>* %ref
-  %vecext = extractelement <4 x i8> %a, i32 0
-  %conv = zext i8 %vecext to i32
-  ret i32 %conv
+define <2 x i16> @loadExt.v2i16(<2 x i16>* %ref) {
+; CHECK-LABEL: loadExt.v2i16:
+; CHECK: ld1 { [[REG:v[0-9]+]].s }[0], [x0]
+; CHECK: zip1 {{v[0-9]+}}.4h, [[REG]].4h, {{v[0-9]+}}.4h
+  %a = load <2 x i16>, <2 x i16>* %ref
+  ret <2 x i16> %a
 }
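
Note: with the test changes above, only any-extending loads are covered; the SEXTLOAD path in LowerLOAD (the getSignExtendVectorInReg branch) no longer has a test. A sketch of a possible additional case, based on the removed loadSExt.v4i8 function; the CHECK lines are intentionally omitted here, since the expected instruction sequence should be taken from the actual codegen rather than guessed:

define <4 x i32> @loadSExt.v4i8(<4 x i8>* %ref) {
; CHECK-LABEL: loadSExt.v4i8:
  %a = load <4 x i8>, <4 x i8>* %ref
  %conv = sext <4 x i8> %a to <4 x i32>
  ret <4 x i32> %conv
}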