diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4757,9 +4757,10 @@
 };
 
 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
-/// This only applies when the shape of the broadcast source is a suffix of the
-/// shape of the result (i.e. when broadcast without reshape is expressive
-/// enough to capture the result in a single op).
+/// This only applies when the shape of the broadcast source
+/// 1. is a suffix of the shape of the result (i.e. when broadcast without
+///    reshape is expressive enough to capture the result in a single op), or
+/// 2. has the same element count as the shape cast result.
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -4771,23 +4772,35 @@
     if (!broadcastOp)
       return failure();
 
-    auto broadcastSourceVectorType =
-        llvm::dyn_cast<VectorType>(broadcastOp.getSourceType());
-    auto broadcastSourceShape = broadcastSourceVectorType
-                                    ? broadcastSourceVectorType.getShape()
-                                    : ArrayRef<int64_t>{};
-    auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();
-
-    // Bail if `broadcastSourceShape` is not a suffix of the result.
-    bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
-                                                 broadcastSourceShape.size()));
-    if (!isSuffix)
-      return failure();
+    ArrayRef<int64_t> broadcastSourceShape;
+    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()))
+      broadcastSourceShape = srcType.getShape();
+    ArrayRef<int64_t> shapeCastTargetShape =
+        shapeCastOp.getResultVectorType().getShape();
 
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        shapeCastOp, shapeCastOp.getResultVectorType(),
-        broadcastOp.getSource());
-    return success();
+    // If `broadcastSourceShape` is a suffix of the result, we can just replace
+    // with a broadcast to the final shape.
+    if (broadcastSourceShape ==
+        shapeCastTargetShape.take_back(broadcastSourceShape.size())) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          shapeCastOp, shapeCastOp.getResultVectorType(),
+          broadcastOp.getSource());
+      return success();
+    }
+
+    // Otherwise, if the final result has the same element count, we can replace
+    // with a shape cast.
+    if (auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
+      if (srcType.getNumElements() ==
+          shapeCastOp.getResultVectorType().getNumElements()) {
+        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+            shapeCastOp, shapeCastOp.getResultVectorType(),
+            broadcastOp.getSource());
+        return success();
+      }
+    }
+
+    return failure();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -714,10 +714,10 @@
 
 // -----
 
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast
 //       CHECK: vector.broadcast
 //   CHECK-NOT: vector.shape_cast
-func.func @canonicalize_broadcast_shapecast(%arg0: vector<3xf32>) -> vector<8x3xf32> {
+func.func @canonicalize_broadcast_shapecast_to_broadcast(%arg0: vector<3xf32>) -> vector<8x3xf32> {
   %0 = vector.broadcast %arg0 : vector<3xf32> to vector<2x4x3xf32>
   %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
   return %1 : vector<8x3xf32>
@@ -725,6 +725,17 @@
 
 // -----
 
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast
+//   CHECK-NOT: vector.broadcast
+//       CHECK: vector.shape_cast {{.+}} : vector<3x4xf32> to vector<1x12xf32>
+func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>) -> vector<1x12xf32> {
+  %0 = vector.broadcast %arg0 : vector<3x4xf32> to vector<1x1x3x4xf32>
+  %1 = vector.shape_cast %0 : vector<1x1x3x4xf32> to vector<1x12xf32>
+  return %1 : vector<1x12xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = arith.constant 0 : index