Handle undistributed (uniform) shapes in VectorDistribute lane ID delinearization

delinearizeLaneId now returns true with an empty delinearizedIds vector when
originalShape == distributedShape, i.e. when the warp does not split the
vector at all. This early return happens before the per-dimension
divisibility check (large % small) on the shapes.

The transfer_read propagation pattern that calls it now asserts that an empty
lane ID vector only occurs together with a distribution map that has no
results. It also walks the indexMap/map results with llvm::zip_equal instead
of llvm::zip. zip_equal requires equal-length ranges, whereas zip stops at
the shorter one, so a length mismatch is no longer silently truncated.

The new test @warp_propagate_uniform_transfer_read yields a vector<1xf32>
read unchanged out of a warp_execute_on_lane_0 region with warp size 64. It
checks (CHECK-PROP) that the transfer_read is emitted on the function's
%src/%index arguments and that its result is returned directly.

NOTE(review): this copy of the patch appears to have been flattened during
extraction. Newlines, the leading indentation of context lines and template
arguments are missing (e.g. ArrayRef, SmallVector, SmallVectorImpl and
dyn_cast appear without <...>), so it will not apply as-is. Regenerate it
from the original commit rather than hand-repairing it.

NOTE(review): the last-but-one CHECK-PROP line matches the padding operand by
the literal SSA name %cst. Capturing that operand with a FileCheck pattern
would make the test less sensitive to SSA value naming.

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -710,6 +710,14 @@ ArrayRef originalShape, ArrayRef distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl &delinearizedIds) { + // If the original shape and the distributed shape are the same, we don't + // distribute at all--every thread handles the whole vector. In that case we + // should not rely on lane IDs later, so just return an empty lane ID vector. + if (originalShape == distributedShape) { + delinearizedIds.clear(); + return true; + } + SmallVector sizes; for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) { if (large % small != 0) @@ -794,8 +802,9 @@ warpOp.getLaneid(), delinearizedIds)) return rewriter.notifyMatchFailure( read, "cannot delinearize lane ID for distribution"); + assert(!delinearizedIds.empty() || map.getNumResults() == 0); - for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) { + for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) { AffineExpr d0, d1; bindDims(read.getContext(), d0, d1); auto indexExpr = std::get<0>(it).dyn_cast(); diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1236,3 +1236,18 @@ // CHECK-PROP: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32> // CHECK-PROP: return %[[CAST]] : vector<4xf32> +// ----- + +func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> { + %f0 = arith.constant 0.000000e+00 : f32 + %r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) { + %1 = vector.transfer_read %src[%index], %f0 
{in_bounds = [true]} : memref<4096xf32>, vector<1xf32> + vector.yield %1 : vector<1xf32> + } + return %r : vector<1xf32> +} + +// CHECK-PROP-LABEL: func.func @warp_propagate_uniform_transfer_read +// CHECK-PROP-SAME: (%{{.+}}: index, %[[SRC:.+]]: memref<4096xf32>, %[[INDEX:.+]]: index) +// CHECK-PROP: %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[INDEX]]], %cst {in_bounds = [true]} : memref<4096xf32>, vector<1xf32> +// CHECK-PROP: return %[[READ]] : vector<1xf32>