diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -346,8 +346,9 @@ Location loc = writeOp.getLoc(); VectorType vecType = writeOp.getVectorType(); - // Only vector<1x> is supported at the moment. - if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1) + // Only sink out vector of 1 element for now to not serialize large vector + // store. This can later be controlled by user. + if (vecType.getNumElements() != 1) return failure(); // Do not process warp ops that contain only TransferWriteOps. diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -109,22 +109,25 @@ // ----- // CHECK-D-LABEL: func @warp_extract( -// CHECK-D: %[[WARPOP:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) +// CHECK-D: %[[WARPOP:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1x1xf32>) // CHECK-D: "test.dummy_op" -// CHECK-D: vector.yield %{{.*}} : vector<1xf32> +// CHECK-D: "test.dummy_op" +// CHECK-D: vector.yield %{{.*}}, %{{.*}} : vector<1xf32>, vector<1x1xf32> // CHECK-D: } // CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] { -// CHECK-D: vector.transfer_write %[[WARPOP]], %{{.*}}[%{{.*}}] {{.*}} : vector<1xf32> +// CHECK-D: vector.transfer_write %[[WARPOP]]#1, %{{.*}}[%{{.*}}] {{.*}} : vector<1x1xf32> +// CHECK-D: } +// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] { +// CHECK-D: vector.transfer_write %[[WARPOP]]#0, %{{.*}}[%{{.*}}] {{.*}} : vector<1xf32> // CHECK-D: } -#map2 = affine_map<(d0)[s0] -> (d0 + s0)> - -func.func @warp_extract(%laneid: index, %arg1: memref<1024xf32>, %gid : index) { +func.func @warp_extract(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) { vector.warp_execute_on_lane_0(%laneid)[32] { - %sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map2> %c0 = arith.constant 0 : index %v = "test.dummy_op"() : () -> (vector<1xf32>) - vector.transfer_write %v, %sa[%c0] : vector<1xf32>, memref<128xf32, #map2> + %v1 = "test.dummy_op"() : () -> (vector<1x1xf32>) + vector.transfer_write %v1, %arg1[%c0, %c0] : vector<1x1xf32>, memref<1024x1024xf32> + vector.transfer_write %v, %arg1[%c0, %c0] : vector<1xf32>, memref<1024x1024xf32> } return }