NOTE(review): this copy of the patch had lost its line breaks and the contents
of its `<...>` template-argument lists. The line breaks are restored below;
indentation of context lines is reconstructed in LLVM style and may need
adjusting to match the tree.

The new populateVectorTransposeLoweringPatterns body now spells out its
pattern type, `patterns.add<TransposeOpLowering>(...)`.
RewritePatternSet::add takes the pattern(s) to register as template
arguments, so the bare `patterns.add(vectorTransformOptions, ...)` registered
nothing.

The VectorTransforms.cpp hunk is still incomplete. Its header declares 13 old
and 19 new lines, but only 10 and 16 survive, so 3 removed and 3 added lines
around the `patterns.add<...>(parameters, patterns.getContext())`
registration are missing. These are presumably the contraction lowerings,
re-registered without TransposeOpLowering. As captured, that hunk deletes the
parameterized registration with no replacement. Restore those lines from the
upstream change before applying.

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -173,7 +173,6 @@
 /// ShapeCastOp2DDownCastRewritePattern,
 /// ShapeCastOp2DUpCastRewritePattern
 /// BroadcastOpLowering,
-/// TransposeOpLowering
 /// OuterproductOpLowering
 /// These transformation express higher level vector ops in terms of more
 /// elementary extraction, insertion, reduction, product, and broadcast ops.
@@ -181,6 +180,12 @@
     RewritePatternSet &patterns,
     VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
 
+/// Collect the TransposeOpLowering pattern, which lowers vector.transpose into
+/// lower-level vector ops using the strategy in `vectorTransformOptions`.
+void populateVectorTransposeLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -64,6 +64,7 @@
     populateVectorToVectorCanonicalizationPatterns(patterns);
     populateVectorSlicesLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns);
+    populateVectorTransposeLoweringPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3823,13 +3823,19 @@
                ShapeCastOp2DDownCastRewritePattern,
                ShapeCastOp2DUpCastRewritePattern,
                ShapeCastOpRewritePattern>(patterns.getContext());
-  patterns.add(parameters, patterns.getContext());
   // clang-format on
 }
 
+void mlir::vector::populateVectorTransposeLoweringPatterns(
+    RewritePatternSet &patterns,
+    VectorTransformsOptions vectorTransformOptions) {
+  patterns.add<TransposeOpLowering>(vectorTransformOptions,
+                                    patterns.getContext());
+}
+
 void mlir::vector::populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns) {
   patterns
diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp
--- a/mlir/test/lib/Transforms/TestConvVectorization.cpp
+++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp
@@ -109,6 +109,8 @@
   RewritePatternSet vectorContractLoweringPatterns(context);
   populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
                                          vectorTransformsOptions);
+  populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
+                                          vectorTransformsOptions);
   (void)applyPatternsAndFoldGreedily(module,
                                      std::move(vectorContractLoweringPatterns));
 
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -140,6 +140,7 @@
       transposeLowering = VectorTransposeLowering::Flat;
     VectorTransformsOptions options{contractLowering, transposeLowering};
     populateVectorContractLoweringPatterns(patterns, options);
+    populateVectorTransposeLoweringPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };