diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4193,11 +4193,37 @@
   }
 };
 
+// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>), where
+// <scalar> is either a scalar or a vector with exactly one element.
+struct FoldTransposedScalarBroadcast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    // Such a source places the same value in every lane of the broadcast, so
+    // transposing the result equals broadcasting the source directly to the
+    // transposed result type.
+    auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
+    if (!srcVectorType || srcVectorType.getNumElements() == 1) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          transposeOp, transposeOp.getResultType(), bcastOp.getSource());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<TransposeFolder>(context);
+  results.add<FoldTransposedScalarBroadcast, TransposeFolder>(context);
 }
 
 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1304,3 +1304,40 @@
   %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xi32>, vector<3xi32>
   return %shuffle : vector<4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transpose_scalar_broadcast1
+// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
+// CHECK: return %[[V]] : vector<1x8xf32>
+func @transpose_scalar_broadcast1(%value: vector<1xf32>) -> vector<1x8xf32> {
+  %bcast = vector.broadcast %value : vector<1xf32> to vector<8x1xf32>
+  %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
+  return %t : vector<1x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_scalar_broadcast2
+// CHECK-SAME: (%[[ARG:.+]]: f32)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
+// CHECK: return %[[V]] : vector<1x8xf32>
+func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
+  %bcast = vector.broadcast %value : f32 to vector<8x1xf32>
+  %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
+  return %t : vector<1x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_vector_broadcast_not_folded
+// CHECK-SAME: (%[[ARG:.+]]: vector<2xf32>)
+// CHECK: %[[B:.+]] = vector.broadcast %[[ARG]] : vector<2xf32> to vector<8x2xf32>
+// CHECK: %[[T:.+]] = vector.transpose %[[B]], [1, 0] : vector<8x2xf32> to vector<2x8xf32>
+// CHECK: return %[[T]] : vector<2x8xf32>
+func @transpose_vector_broadcast_not_folded(%value: vector<2xf32>) -> vector<2x8xf32> {
+  %bcast = vector.broadcast %value : vector<2xf32> to vector<8x2xf32>
+  %t = vector.transpose %bcast, [1, 0] : vector<8x2xf32> to vector<2x8xf32>
+  return %t : vector<2x8xf32>
+}