diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1179,13 +1179,10 @@ if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0) return rewriter.notifyMatchFailure( warpOp, "Reduction vector dimension must match was size."); - // Only f32, i32, f16, i8 element types are supported. - if (!reductionOp.getType().isF32() && - !reductionOp.getType().isSignlessInteger(32) && - !reductionOp.getType().isF16() && !reductionOp.getType().isInteger(8)) + if (!reductionOp.getType().isIntOrFloat()) return rewriter.notifyMatchFailure( - warpOp, "Reduction distribution currently only supports 32bits, f16, " - "and i8 types."); + warpOp, "Reduction distribution currently only supports floats and " + "integer types."); int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize(); // Return vector that will be reduced from the WarpExecuteOnLane0Op.