Index: include/llvm/Target/TargetLowering.h =================================================================== --- include/llvm/Target/TargetLowering.h +++ include/llvm/Target/TargetLowering.h @@ -242,6 +242,21 @@ return true; } + /// isVectorToScalarLoadStoreWidenBeneficial() - Return true if the vector + /// load or store packing into a larger scalar type is beneficial. + /// Width: Width to load/store. + /// WidenVT: The widen vector type to load to/store from. + /// N: Load or store SDNode. + /// + /// Some architectures like GPUs have 3 element loads and stores, and in case + /// of uneven vector size it can be more efficient to use one 3 element + /// load or store instead of several scalar loads even of a wider type. + /// This callback tells if a particular widening is profitable. + virtual bool isVectorToScalarLoadStoreWidenBeneficial(unsigned Width, + EVT WidenVT, const MemSDNode *N) const { + return true; + } + /// \brief Return if the target supports combining a /// chain like: /// \code Index: lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -2509,11 +2509,12 @@ // TLI: Target lowering used to determine legal types. // Width: Width left need to load/store. // WidenVT: The widen vector type to load to/store from +// N: Load or store SDNode. // Align: If 0, don't allow use of a wider type // WidenEx: If Align is not 0, the amount additional we can load/store from. 
static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI, - unsigned Width, EVT WidenVT, + unsigned Width, EVT WidenVT, const MemSDNode *N, unsigned Align = 0, unsigned WidenEx = 0) { EVT WidenEltVT = WidenVT.getVectorElementType(); unsigned WidenWidth = WidenVT.getSizeInBits(); @@ -2527,18 +2528,21 @@ // See if there is larger legal integer than the element type to load/store unsigned VT; - for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE; - VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) { - EVT MemVT((MVT::SimpleValueType) VT); - unsigned MemVTWidth = MemVT.getSizeInBits(); - if (MemVT.getSizeInBits() <= WidenEltWidth) - break; - if (TLI.isTypeLegal(MemVT) && (WidenWidth % MemVTWidth) == 0 && - isPowerOf2_32(WidenWidth / MemVTWidth) && - (MemVTWidth <= Width || - (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) { - RetVT = MemVT; - break; + if (TLI.isVectorToScalarLoadStoreWidenBeneficial(Width, WidenVT, N)) { + for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE; + VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) { + EVT MemVT((MVT::SimpleValueType) VT); + unsigned MemVTWidth = MemVT.getSizeInBits(); + if (MemVT.getSizeInBits() <= WidenEltWidth) + break; + if (TLI.isTypeLegal(MemVT) && (WidenWidth % MemVTWidth) == 0 && + isPowerOf2_32(WidenWidth / MemVTWidth) && + (MemVTWidth <= Width || + (Align != 0 && MemVTWidth <= AlignInBits && + MemVTWidth <= Width + WidenEx))) { + RetVT = MemVT; + break; + } } } @@ -2621,7 +2625,7 @@ unsigned LdAlign = (isVolatile) ? 0 : Align; // Allow wider loads // Find the vector type that can load from. 
- EVT NewVT = FindMemType(DAG, TLI, LdWidth, WidenVT, LdAlign, WidthDiff); + EVT NewVT = FindMemType(DAG, TLI, LdWidth, WidenVT, LD, LdAlign, WidthDiff); int NewVTWidth = NewVT.getSizeInBits(); SDValue LdOp = DAG.getLoad(NewVT, dl, Chain, BasePtr, LD->getPointerInfo(), isVolatile, isNonTemporal, isInvariant, Align, @@ -2665,7 +2669,7 @@ SDValue L; if (LdWidth < NewVTWidth) { // Our current type we are using is too large, find a better size - NewVT = FindMemType(DAG, TLI, LdWidth, WidenVT, LdAlign, WidthDiff); + NewVT = FindMemType(DAG, TLI, LdWidth, WidenVT, LD, LdAlign, WidthDiff); NewVTWidth = NewVT.getSizeInBits(); L = DAG.getLoad(NewVT, dl, Chain, BasePtr, LD->getPointerInfo().getWithOffset(Offset), isVolatile, @@ -2827,7 +2831,7 @@ unsigned Offset = 0; // offset from base to store while (StWidth != 0) { // Find the largest vector type we can store with - EVT NewVT = FindMemType(DAG, TLI, StWidth, ValVT); + EVT NewVT = FindMemType(DAG, TLI, StWidth, ValVT, ST); unsigned NewVTWidth = NewVT.getSizeInBits(); unsigned Increment = NewVTWidth / 8; if (NewVT.isVector()) {