Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3175,6 +3175,49 @@
   }
 };
 
+// Rewrites a vector.broadcast whose result type has leading unit dimensions
+// into a broadcast to the type returned by trimLeadingOneDims, followed by a
+// vector.shape_cast back to the original result type. A vector-typed source
+// is passed through trimLeadingOneDims as well and shape_cast when its type
+// changes; a non-vector source is broadcast as is. For example:
+//   %0 = vector.broadcast %a : vector<1x4xf32> to vector<1x3x4xf32>
+// becomes
+//   %s = vector.shape_cast %a : vector<1x4xf32> to vector<4xf32>
+//   %b = vector.broadcast %s : vector<4xf32> to vector<3x4xf32>
+//   %0 = vector.shape_cast %b : vector<3x4xf32> to vector<1x3x4xf32>
+// NOTE(review): "Brodcast" is a misspelling of "Broadcast"; the name is kept
+// because the pattern is registered under it in this file.
+struct CastAwayBrodcastLeadingOneDim
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType oldDstType = broadcastOp.getVectorType();
+    VectorType newDstType = trimLeadingOneDims(oldDstType);
+    // Trimming left the result type unchanged, so there is nothing to rewrite.
+    if (newDstType == oldDstType)
+      return failure();
+
+    Location loc = broadcastOp.getLoc();
+    Value source = broadcastOp.source();
+    // dyn_cast yields a null VectorType when the source is not a vector; such
+    // a source is fed to the new broadcast unchanged.
+    if (auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>()) {
+      VectorType newSrcType = trimLeadingOneDims(srcVecType);
+      if (newSrcType != srcVecType)
+        source = rewriter.create<vector::ShapeCastOp>(loc, newSrcType, source);
+    }
+
+    Value newBroadcastOp =
+        rewriter.create<vector::BroadcastOp>(loc, newDstType, source);
+    // Cast back to the original result type so existing users stay valid.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcastOp, oldDstType,
+                                                     newBroadcastOp);
+    return success();
+  }
+};
+
 // Returns the values in `arrayAttr` as an integer vector.
static SmallVector getIntValueVector(ArrayAttr arrayAttr) { return llvm::to_vector<4>( @@ -3771,7 +3796,8 @@ patterns.add( + CastAwayTransferWriteLeadingOneDim, + CastAwayBrodcastLeadingOneDim, ShapeCastOpFolder>( patterns.getContext()); } Index: mlir/test/Dialect/Vector/vector-transforms.mlir =================================================================== --- mlir/test/Dialect/Vector/vector-transforms.mlir +++ mlir/test/Dialect/Vector/vector-transforms.mlir @@ -672,6 +672,23 @@ return } +// CHECK-LABEL: func @cast_away_broadcast_leading_one_dims +func @cast_away_broadcast_leading_one_dims( + %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) -> + (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>) { + // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<8xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32> + %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32> + // CHECK: vector.broadcast %{{.*}} : f32 to vector<4xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32> + %1 = vector.broadcast %arg1 : f32 to vector<1x1x4xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> + // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<3x4xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<3x4xf32> to vector<1x3x4xf32> + %2 = vector.broadcast %arg2 : vector<1x4xf32> to vector<1x3x4xf32> + return %0, %1, %2: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32> +} + // CHECK-LABEL: func @bubble_down_bitcast_in_extract // CHECK-SAME: %[[SRC:.+]]: vector<4xf32> func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {