diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -67,6 +67,13 @@
                                 ArrayRef<Value>{laneId});
   }
 
+  /// Create a store during the process of distributing the
+  /// `vector.warp_execute_on_lane_0` op.
+  /// Vector distribution assumes the following convention regarding the
+  /// temporary buffers that are created to transition values. This **must**
+  /// be properly specified in the `options.warpAllocationFn`:
+  ///   1. scalars of type T transit through a memref<1xT>.
+  ///   2. vectors of type V transit through a memref<V>.
   Operation *buildStore(RewriterBase &b, Location loc, Value val,
                         Value buffer) {
     assert((val == distributedVal || val == sequentialVal) &&
@@ -75,7 +82,12 @@
     // Vector case must use vector::TransferWriteOp which will later lower to
     // vector.store of memref.store depending on further lowerings.
     if (val.getType().isa<VectorType>()) {
-      SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
+      int64_t rank = sequentialVectorType.getRank();
+      if (rank == 0) {
+        return b.create<vector::TransferWriteOp>(loc, val, buffer,
+                                                 ValueRange{}, ArrayRef<bool>{});
+      }
+      SmallVector<Value> indices(rank, zero);
       auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
       assert(maybeDistributedDim && "must be able to deduce distributed dim");
       if (val == distributedVal)
@@ -90,17 +102,41 @@
     return b.create<memref::StoreOp>(loc, val, buffer, zero);
   }
 
+  /// Create a load during the process of distributing the
+  /// `vector.warp_execute_on_lane_0` op.
+  /// Vector distribution assumes the following convention regarding the
+  /// temporary buffers that are created to transition values. This **must**
+  /// be properly specified in the `options.warpAllocationFn`:
+  ///   1. scalars of type T transit through a memref<1xT>.
+  ///   2. vectors of type V transit through a memref<V>.
+  ///
+  /// When broadcastMode is true, the load is not distributed to account for
+  /// the broadcast semantics of the `vector.warp_execute_on_lane_0` op.
+  ///
+  /// Example:
+  ///
+  /// ```
+  /// %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+  ///   vector.yield %cst : f32
+  /// }
+  /// // Both types are f32. The constant %cst is broadcast to all lanes.
+  /// ```
+  /// This behavior is described in more detail in the documentation of the op.
   Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer,
                   bool broadcastMode = false) {
-    // When broadcastMode is true, this is a broadcast.
-    // E.g.:
-    // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
-    //   vector.yield %cst : f32
-    // }
-    // Both types are f32. The constant %cst is broadcasted to all lanes.
-    // This is described in more detail in the documentation of the op.
-    if (broadcastMode)
+    if (broadcastMode) {
+      // Broadcast mode may occur for either scalar or vector operands.
+      auto vectorType = type.dyn_cast<VectorType>();
+      auto shape = buffer.getType().cast<MemRefType>();
+      if (vectorType) {
+        SmallVector<bool> inBounds(shape.getRank(), true);
+        return b.create<vector::TransferReadOp>(
+            loc, vectorType, buffer,
+            /*indices=*/SmallVector<Value>(shape.getRank(), zero),
+            ArrayRef<bool>(inBounds.begin(), inBounds.end()));
+      }
       return b.create<memref::LoadOp>(loc, buffer, zero);
+    }
 
     // Other cases must be vector atm.
     // Vector case must use vector::TransferReadOp which will later lower to
@@ -328,8 +364,10 @@
       helper.buildStore(rewriter, loc, distributedVal, buffer);
       // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
-      bbArgReplacements.push_back(
-          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
+      bool broadcastMode =
+          (sequentialVal.getType() == distributedVal.getType());
+      bbArgReplacements.push_back(helper.buildLoad(
+          rewriter, loc, sequentialVal.getType(), buffer, broadcastMode));
     }
 
     // Step 3. Insert sync after all the stores and before all the loads.
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -691,3 +691,46 @@
 // CHECK-PROP: return %[[SINGLE_RES]], %[[SINGLE_RES]] : vector<1xf32>, vector<1xf32>
   return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
 }
+
+// -----
+
+// CHECK-SCF-IF: func @warp_execute_has_broadcast_semantics
+func.func @warp_execute_has_broadcast_semantics(%laneid: index, %s0: f32, %v0: vector<f32>, %v1: vector<1xf32>, %v2: vector<1x1xf32>)
+    -> (f32, vector<f32>, vector<1xf32>, vector<1x1xf32>) {
+  // CHECK-SCF-IF-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+  // CHECK-SCF-IF: scf.if{{.*}}{
+  %r:4 = vector.warp_execute_on_lane_0(%laneid)[32]
+      args(%s0, %v0, %v1, %v2 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>) -> (f32, vector<f32>, vector<1xf32>, vector<1x1xf32>) {
+    ^bb0(%bs0: f32, %bv0: vector<f32>, %bv1: vector<1xf32>, %bv2: vector<1x1xf32>):
+
+    // CHECK-SCF-IF: vector.transfer_read {{.*}}[%[[C0]], %[[C0]]]{{.*}} {in_bounds = [true, true]} : memref<1x1xf32, 3>, vector<1x1xf32>
+    // CHECK-SCF-IF: vector.transfer_read {{.*}}[%[[C0]]]{{.*}} {in_bounds = [true]} : memref<1xf32, 3>, vector<1xf32>
+    // CHECK-SCF-IF: vector.transfer_read {{.*}}[]{{.*}} : memref<f32, 3>, vector<f32>
+    // CHECK-SCF-IF: memref.load {{.*}}[%[[C0]]] : memref<1xf32, 3>
+    // CHECK-SCF-IF: "some_def_0"(%{{.*}}) : (f32) -> f32
+    // CHECK-SCF-IF: "some_def_1"(%{{.*}}) : (vector<f32>) -> vector<f32>
+    // CHECK-SCF-IF: "some_def_1"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+    // CHECK-SCF-IF: "some_def_1"(%{{.*}}) : (vector<1x1xf32>) -> vector<1x1xf32>
+    // CHECK-SCF-IF: memref.store {{.*}}[%[[C0]]] : memref<1xf32, 3>
+    // CHECK-SCF-IF: vector.transfer_write {{.*}}[] : vector<f32>, memref<f32, 3>
+    // CHECK-SCF-IF: vector.transfer_write {{.*}}[%[[C0]]] {in_bounds = [true]} : vector<1xf32>, memref<1xf32, 3>
+    // CHECK-SCF-IF: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x1xf32>, memref<1x1xf32, 3>
+
+    %rs0 = "some_def_0"(%bs0) : (f32) -> f32
+    %rv0 = "some_def_1"(%bv0) : (vector<f32>) -> vector<f32>
+    %rv1 = "some_def_1"(%bv1) : (vector<1xf32>) -> vector<1xf32>
+    %rv2 = "some_def_1"(%bv2) : (vector<1x1xf32>) -> vector<1x1xf32>
+
+    // CHECK-SCF-IF-NOT: vector.yield
+    vector.yield %rs0, %rv0, %rv1, %rv2 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
+  }
+
+  // CHECK-SCF-IF: gpu.barrier
+  // CHECK-SCF-IF: %[[RV2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]]{{.*}} {in_bounds = [true, true]} : memref<1x1xf32, 3>, vector<1x1xf32>
+  // CHECK-SCF-IF: %[[RV1:.*]] = vector.transfer_read {{.*}}[%[[C0]]]{{.*}} {in_bounds = [true]} : memref<1xf32, 3>, vector<1xf32>
+  // CHECK-SCF-IF: %[[RV0:.*]] = vector.transfer_read {{.*}}[]{{.*}} : memref<f32, 3>, vector<f32>
+  // CHECK-SCF-IF: %[[RS0:.*]] = memref.load {{.*}}[%[[C0]]] : memref<1xf32, 3>
+  // CHECK-SCF-IF: return %[[RS0]], %[[RV0]], %[[RV1]], %[[RV2]] : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
+  return %r#0, %r#1, %r#2, %r#3 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
+}
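
For context, the scalar broadcast path exercised by the new test can be pictured as follows. This is a hand-written sketch, not output of the pass: it assumes `options.warpAllocationFn` allocates workgroup memory (address space 3), and the names `@before`, `@after_sketch`, `%buf`, and `%is_lane_0` are illustrative only.

```mlir
// Input: the scalar result has broadcast semantics (both types are f32).
func.func @before(%laneid: index) -> f32 {
  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
    %cst = arith.constant 1.0 : f32
    vector.yield %cst : f32
  }
  return %r : f32
}

// After rewriteWarpOpToScfFor (sketch). Per the convention documented on
// buildStore/buildLoad, the scalar transits through a memref<1xf32>.
func.func @after_sketch(%laneid: index) -> f32 {
  %c0 = arith.constant 0 : index
  %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
  %buf = memref.alloc() : memref<1xf32, 3>
  scf.if %is_lane_0 {
    // Lane 0 executes the sequential region and stores the result.
    %cst = arith.constant 1.0 : f32
    memref.store %cst, %buf[%c0] : memref<1xf32, 3>
  }
  gpu.barrier
  // broadcastMode = true: the load is not indexed by lane id, so every
  // lane reads the same value.
  %r = memref.load %buf[%c0] : memref<1xf32, 3>
  return %r : f32
}
```

Because the sequential and distributed types are equal here, the rewrite calls `buildLoad` with `broadcastMode = true` and emits a uniform load rather than a lane-indexed one; the 0-d vector case is analogous but transits via `vector.transfer_write`/`vector.transfer_read` on a `memref<f32, 3>`, as checked in the test above.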