diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -896,10 +896,19 @@ Location loc = broadcastOp.getLoc(); auto destVecType = cast(warpOp->getResultTypes()[operandNumber]); + Value broadcastSrc = broadcastOp.getSource(); + Type broadcastSrcType = broadcastSrc.getType(); + + // Check that the broadcast actually spans a set of values uniformly across + // all threads. In other words, check that each thread can reconstruct + // their own broadcast. + // For that we simply check that the broadcast we want to build makes sense. + if (vector::isBroadcastableTo(broadcastSrcType, destVecType) != + vector::BroadcastableToResult::Success) + return failure(); SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {broadcastOp.getSource()}, - {broadcastOp.getSource().getType()}, newRetIndices); + rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); Value broadcasted = rewriter.create( loc, destVecType, newWarpOp->getResult(newRetIndices[0])); diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1152,3 +1152,24 @@ } return %18 : vector<2xf32> } + +// ----- + +// Check that we don't fold vector.broadcast when each thread doesn't get the +// same value. 
+
+// Negative test for moving a vector.broadcast out of the warp region.
+// The region yields vector<1x64xf32> while the warp op result is
+// vector<1x2xf32>. The broadcast source, vector<64xf32>, is not broadcastable
+// to vector<1x2xf32>, so the isBroadcastableTo guard in the distribution
+// pattern returns failure and the IR is expected to stay as written: the
+// broadcast remains inside the region and is still what gets yielded.
+// CHECK-PROP-LABEL: func @dont_fold_vector_broadcast(
+// CHECK-PROP: %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>)
+// CHECK-PROP: %[[some_def:.*]] = "some_def"
+// CHECK-PROP: %[[broadcast:.*]] = vector.broadcast %[[some_def]] : vector<64xf32> to vector<1x64xf32>
+// CHECK-PROP: vector.yield %[[broadcast]] : vector<1x64xf32>
+// CHECK-PROP: vector.print %[[r]] : vector<1x2xf32>
+func.func @dont_fold_vector_broadcast(%laneid: index) {
+  // Warp size is 32; the result type is vector<1x2xf32> although the region
+  // yields vector<1x64xf32>.
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2xf32>) {
+    // Generic-form op with no operands that produces the 64-element source.
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    // Only prepends a unit dimension; the 64 elements are kept as-is.
+    %1 = vector.broadcast %0 : vector<64xf32> to vector<1x64xf32>
+    vector.yield %1 : vector<1x64xf32>
+  }
+  // NOTE(review): presumably here to give %r a user so the warp op result is
+  // not removed as dead -- confirm.
+  vector.print %r : vector<1x2xf32>
+  return
+}