diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -712,6 +712,9 @@ if (!operand) return failure(); auto read = operand->get().getDefiningOp(); + // Don't duplicate transfer_read ops when distributing. + if (!read.getResult().hasOneUse()) + return failure(); unsigned operandIndex = operand->getOperandNumber(); Value distributedVal = warpOp.getResult(operandIndex); diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -650,3 +650,22 @@ vector.transfer_write %r, %dest[%c0, %laneid] : vector<1x1xf32>, memref<1x1024xf32> return } + +// ----- + +// CHECK-PROP: func @dont_duplicate_read +func.func @dont_duplicate_read( + %laneid: index, %src: memref<1024xf32>) -> vector<1xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 +// CHECK-PROP: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) { +// CHECK-PROP-NEXT: vector.transfer_read +// CHECK-PROP-NEXT: "blocking_use" +// CHECK-PROP-NEXT: vector.yield + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { + %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<32xf32> + "blocking_use"(%2) : (vector<32xf32>) -> () + vector.yield %2 : vector<32xf32> + } + return %r : vector<1xf32> +}