diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1135,12 +1135,13 @@
     if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
       return rewriter.notifyMatchFailure(
           warpOp, "Reduction vector dimension must match was size.");
-    // Only f32 and i32 element types are supported.
+    // Only f32, i32, f16, and i8 element types are supported.
     if (!reductionOp.getType().isF32() &&
-        !reductionOp.getType().isSignlessInteger(32))
+        !reductionOp.getType().isSignlessInteger(32) &&
+        !reductionOp.getType().isF16() && !reductionOp.getType().isInteger(8))
       return rewriter.notifyMatchFailure(
-          warpOp,
-          "Reduction distribution currently only supports 32bits types.");
+          warpOp, "Reduction distribution currently only supports 32bits, f16, "
+                  "and i8 types.");
 
     int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
@@ -1157,13 +1158,11 @@
         rewriter, warpOp, yieldValues, retTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
 
+    // Obtain data to reduce for a single lane.
     Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
-    // First reduce on a single thread.
-    Value perLaneReduction = rewriter.create<vector::ReductionOp>(
-        reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
-    // Then distribute across threads.
+    // Distribute and reduce across threads.
     Value fullReduce =
-        distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
+        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                                reductionOp.getKind(), newWarpOp.getWarpSize());
     if (reductionOp.getAcc()) {
       fullReduce = vector::makeArithReduction(
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -686,7 +686,8 @@
 
 static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                            CombiningKind kind, uint32_t size) {
-  Value laneVal = input;
+  // First reduce on a single thread to get the per-lane reduction value.
+  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
   // Parallel reduction using butterfly shuffles.
   for (uint64_t i = 1; i < size; i <<= 1) {
     Value shuffled = builder