diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -3175,27 +3175,31 @@ } }; -struct CastAwayBroadcastLeadingOneDim - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +template +struct CastAwayBroadcastLeadingOneDim : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, + LogicalResult matchAndRewrite(BroadCastType broadcastOp, PatternRewriter &rewriter) const override { - VectorType newDstType = trimLeadingOneDims(broadcastOp.getVectorType()); - if (newDstType == broadcastOp.getVectorType()) + VectorType dstType = + broadcastOp.getResult().getType().template dyn_cast(); + if (!dstType) + return failure(); + VectorType newDstType = trimLeadingOneDims(dstType); + if (newDstType == dstType) return failure(); Location loc = broadcastOp.getLoc(); - VectorType srcVecType = broadcastOp.getSourceType().dyn_cast(); + Value source = broadcastOp->getOperand(0); + VectorType srcVecType = source.getType().template dyn_cast(); if (srcVecType) srcVecType = trimLeadingOneDims(srcVecType); - Value source = broadcastOp.source(); - if (srcVecType && srcVecType != broadcastOp.getSourceType()) { + if (srcVecType && srcVecType != source.getType()) { source = rewriter.create(loc, srcVecType, source); } Value newBroadcastOp = - rewriter.create(loc, newDstType, source); - rewriter.replaceOpWithNewOp( - broadcastOp, broadcastOp.getVectorType(), newBroadcastOp); + rewriter.create(loc, newDstType, source); + rewriter.replaceOpWithNewOp(broadcastOp, dstType, + newBroadcastOp); return success(); } }; @@ -3833,13 +3837,13 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( RewritePatternSet &patterns) { - patterns - .add( - patterns.getContext()); + patterns.add< + CastAwayExtractStridedSliceLeadingOneDim, + CastAwayInsertStridedSliceLeadingOneDim, + CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim, + CastAwayBroadcastLeadingOneDim, + CastAwayBroadcastLeadingOneDim, CastAwayElementwiseLeadingOneDim, + ShapeCastOpFolder>(patterns.getContext()); } void mlir::vector::populateBubbleVectorBitCastOpPatterns( diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -675,7 +675,7 @@ // CHECK-LABEL: func @cast_away_broadcast_leading_one_dims func @cast_away_broadcast_leading_one_dims( %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) -> - (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>) { + (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) { // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<8xf32> // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32> %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32> @@ -686,7 +686,10 @@ // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<3x4xf32> // CHECK: vector.shape_cast %{{.*}} : vector<3x4xf32> to vector<1x3x4xf32> %2 = vector.broadcast %arg2 : vector<1x4xf32> to vector<1x3x4xf32> - return %0, %1, %2: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32> + // CHECK: splat %{{.*}} : vector<4xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32> + %3 = splat %arg1 : vector<1x1x4xf32> + return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32> } // CHECK-LABEL: func @cast_away_elementwise_leading_one_dims