Index: include/llvm/Target/TargetLowering.h
===================================================================
--- include/llvm/Target/TargetLowering.h
+++ include/llvm/Target/TargetLowering.h
@@ -242,6 +242,21 @@
     return true;
   }
 
+  /// isVectorToScalarLoadStoreWidenBeneficial() - Return true if widening a
+  /// vector load or store into a larger scalar type is beneficial.
+  /// Width:     Width to load/store.
+  /// WidenVT:   The widen vector type to load to/store from.
+  /// N:         Load or store SDNode.
+  ///
+  /// Some architectures, such as GPUs, have 3-element loads and stores, and
+  /// for an uneven vector size it can be more efficient to use one 3-element
+  /// load or store instead of several scalar loads, even of a wider type.
+  /// This callback indicates whether a particular widening is profitable.
+  virtual bool isVectorToScalarLoadStoreWidenBeneficial(unsigned Width,
+                 EVT WidenVT, const MemSDNode *N) const {
+    return true;
+  }
+
   /// \brief Return if the target supports combining a
   /// chain like:
   /// \code
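For context, a target that wants to keep odd-sized vector accesses intact could override the new hook along these lines. This is a minimal sketch; the class name GPUTargetLowering and the 96-bit heuristic are illustrative assumptions, not part of this patch.

// Hypothetical override in a GPU-like target. Returning false makes
// FindMemType skip the scalar-integer widening loop, so its vector-type
// search can still pick a 3-element type for the access.
bool GPUTargetLowering::isVectorToScalarLoadStoreWidenBeneficial(
    unsigned Width, EVT WidenVT, const MemSDNode *N) const {
  // Assumption for illustration: prefer to keep 96-bit accesses (e.g. three
  // 32-bit elements) as a single vector operation rather than i64 + i32.
  if (Width == 96 && WidenVT.getScalarSizeInBits() == 32)
    return false;
  return true;
}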
Index: lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2509,11 +2509,12 @@
 //  TLI:       Target lowering used to determine legal types.
 //  Width:     Width left need to load/store.
 //  WidenVT:   The widen vector type to load to/store from
+//  N:         Load or store SDNode.
 //  Align:     If 0, don't allow use of a wider type
 //  WidenEx:   If Align is not 0, the amount additional we can load/store from.
 
 static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
-                       unsigned Width, EVT WidenVT,
+                       unsigned Width, EVT WidenVT, const MemSDNode *N,
                        unsigned Align = 0, unsigned WidenEx = 0) {
   EVT WidenEltVT = WidenVT.getVectorElementType();
   unsigned WidenWidth = WidenVT.getSizeInBits();
@@ -2527,18 +2528,21 @@
 
   // See if there is larger legal integer than the element type to load/store
   unsigned VT;
-  for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE;
-       VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) {
-    EVT MemVT((MVT::SimpleValueType) VT);
-    unsigned MemVTWidth = MemVT.getSizeInBits();
-    if (MemVT.getSizeInBits() <= WidenEltWidth)
-      break;
-    if (TLI.isTypeLegal(MemVT) && (WidenWidth % MemVTWidth) == 0 &&
-        isPowerOf2_32(WidenWidth / MemVTWidth) &&
-        (MemVTWidth <= Width ||
-         (Align!=0 && MemVTWidth<=AlignInBits && MemVTWidth<=Width+WidenEx))) {
-      RetVT = MemVT;
-      break;
+  if (TLI.isVectorToScalarLoadStoreWidenBeneficial(Width, WidenVT, N)) {
+    for (VT = (unsigned)MVT::LAST_INTEGER_VALUETYPE;
+         VT >= (unsigned)MVT::FIRST_INTEGER_VALUETYPE; --VT) {
+      EVT MemVT((MVT::SimpleValueType) VT);
+      unsigned MemVTWidth = MemVT.getSizeInBits();
+      if (MemVT.getSizeInBits() <= WidenEltWidth)
+        break;
+      if (TLI.isTypeLegal(MemVT) && (WidenWidth % MemVTWidth) == 0 &&
+          isPowerOf2_32(WidenWidth / MemVTWidth) &&
+          (MemVTWidth <= Width ||
+           (Align != 0 && MemVTWidth <= AlignInBits &&
+            MemVTWidth <= Width + WidenEx))) {
+        RetVT = MemVT;
+        break;
+      }
     }
   }
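To see what the new guard changes, consider widening a 96-bit access (a <3 x i32> load or store) when the widened vector type is 128 bits wide. The standalone sketch below re-evaluates the same condition the scalar-integer loop uses; the concrete widths are illustrative assumptions, not values taken from the patch.

#include <cassert>

// Same power-of-two test the loop performs via isPowerOf2_32().
static bool isPow2(unsigned X) { return X != 0 && (X & (X - 1)) == 0; }

int main() {
  const unsigned Width = 96;         // bits still left to load/store (<3 x i32>)
  const unsigned WidenWidth = 128;   // widened vector type, e.g. v4i32
  const unsigned WidenEltWidth = 32; // element width of the widened type

  // With the hook returning true (the default), and assuming i64 is legal,
  // the loop tries i64 first: 128 % 64 == 0, 128 / 64 is a power of two,
  // and 64 <= 96, so the access is broken into an i64 piece followed by an
  // i32 piece on a later pass.
  const unsigned MemVTWidth = 64;
  bool ScalarChosen = MemVTWidth > WidenEltWidth &&
                      (WidenWidth % MemVTWidth) == 0 &&
                      isPow2(WidenWidth / MemVTWidth) && MemVTWidth <= Width;
  assert(ScalarChosen && "i64 satisfies the scalar widening condition");

  // With the hook returning false, the loop is skipped and FindMemType falls
  // through to its vector-type search, which can return a legal 3-element
  // type (e.g. v3i32) and keep the access as a single operation.
  (void)ScalarChosen;
  return 0;
}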
 
@@ -2621,7 +2625,7 @@
   unsigned LdAlign = (isVolatile) ? 0 : Align; // Allow wider loads
 
   // Find the vector type that can load from.
-  EVT NewVT = FindMemType(DAG, TLI, LdWidth, WidenVT, LdAlign, WidthDiff);
+  EVT NewVT = FindMemType(DAG, TLI, LdWidth, WidenVT, LD, LdAlign, WidthDiff);
   int NewVTWidth = NewVT.getSizeInBits();
   SDValue LdOp = DAG.getLoad(NewVT, dl, Chain, BasePtr, LD->getPointerInfo(),
                              isVolatile, isNonTemporal, isInvariant, Align,
@@ -2665,7 +2669,7 @@
     SDValue L;
     if (LdWidth < NewVTWidth) {
       // Our current type we are using is too large, find a better size
-      NewVT = FindMemType(DAG, TLI, LdWidth, WidenVT, LdAlign, WidthDiff);
+      NewVT = FindMemType(DAG, TLI, LdWidth, WidenVT, LD, LdAlign, WidthDiff);
       NewVTWidth = NewVT.getSizeInBits();
       L = DAG.getLoad(NewVT, dl, Chain, BasePtr,
                       LD->getPointerInfo().getWithOffset(Offset), isVolatile,
@@ -2827,7 +2831,7 @@
   unsigned Offset = 0;  // offset from base to store
   while (StWidth != 0) {
     // Find the largest vector type we can store with
-    EVT NewVT = FindMemType(DAG, TLI, StWidth, ValVT);
+    EVT NewVT = FindMemType(DAG, TLI, StWidth, ValVT, ST);
     unsigned NewVTWidth = NewVT.getSizeInBits();
     unsigned Increment = NewVTWidth / 8;
     if (NewVT.isVector()) {
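As a rough illustration of the effect on the store loop above, the toy simulation below counts how many pieces a 96-bit store is split into, with and without the scalar widening. The pickWidth() helper and the assumption of a legal 96-bit 3-element type are hypothetical and only stand in for FindMemType on such a target.

#include <cstdio>

// Toy stand-in for FindMemType: returns the width (in bits) of the memory
// type chosen for the remaining width. 'AllowScalarWiden' mimics the new
// TLI hook; 'Has96BitVector' mimics a target with a legal 3-element type.
static unsigned pickWidth(unsigned Remaining, bool AllowScalarWiden,
                          bool Has96BitVector) {
  if (AllowScalarWiden && Remaining >= 64)
    return 64;                       // i64 scalar piece
  if (Has96BitVector && Remaining >= 96)
    return 96;                       // v3i32-style piece
  if (Remaining >= 64)
    return 64;                       // v2i32-style piece
  return 32;                         // i32 / single-element piece
}

int main() {
  for (bool AllowScalarWiden : {true, false}) {
    unsigned StWidth = 96;           // storing a <3 x i32> value
    unsigned Stores = 0;
    while (StWidth != 0) {
      unsigned W = pickWidth(StWidth, AllowScalarWiden, /*Has96BitVector=*/true);
      StWidth -= W;
      ++Stores;
    }
    std::printf("hook returns %s -> %u store(s)\n",
                AllowScalarWiden ? "true " : "false", Stores);
  }
  return 0;
}

With the hook returning true the 96-bit value is split into an i64 store plus an i32 store; with it returning false, a target with a legal 3-element type can emit a single store.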