diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1135,12 +1135,13 @@ if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0) return rewriter.notifyMatchFailure( warpOp, "Reduction vector dimension must match was size."); - // Only f32 and i32 element types are supported. + // Only f32, i32, f16, i8 element types are supported. if (!reductionOp.getType().isF32() && - !reductionOp.getType().isSignlessInteger(32)) + !reductionOp.getType().isSignlessInteger(32) && + !reductionOp.getType().isF16() && !reductionOp.getType().isInteger(8)) return rewriter.notifyMatchFailure( - warpOp, - "Reduction distribution currently only supports 32bits types."); + warpOp, "Reduction distribution currently only supports 32bits, f16, " + "and i8 types."); int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize(); // Return vector that will be reduced from the WarpExecuteOnLane0Op. @@ -1157,13 +1158,11 @@ rewriter, warpOp, yieldValues, retTypes, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); + // Obtain data to reduce for a single lane. Value laneValVec = newWarpOp.getResult(newRetIndices[0]); - // First reduce on a single thread. - Value perLaneReduction = rewriter.create( - reductionOp.getLoc(), reductionOp.getKind(), laneValVec); - // Then distribute across threads. + // Distribute and reduce across threads. Value fullReduce = - distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction, + distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec, reductionOp.getKind(), newWarpOp.getWarpSize()); if (reductionOp.getAcc()) { fullReduce = vector::makeArithReduction( diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -686,7 +686,8 @@ static Value warpReduction(Location loc, OpBuilder &builder, Value input, CombiningKind kind, uint32_t size) { - Value laneVal = input; + // First reduce on a single thread to get per lane reduction value. + Value laneVal = builder.create(loc, kind, input); // Parallel reduction using butterfly shuffles. for (uint64_t i = 1; i < size; i <<= 1) { Value shuffled = builder