diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -159,23 +159,29 @@ } }; -/// Collect a set of transformation patterns that are related to contracting -/// or expanding vector operations: -/// ContractionOpLowering, -/// ShapeCastOp2DDownCastRewritePattern, -/// ShapeCastOp2DUpCastRewritePattern -/// BroadcastOpLowering, -/// OuterproductOpLowering -/// These transformation express higher level vector ops in terms of more -/// elementary extraction, insertion, reduction, product, and broadcast ops. +/// Collects patterns to progressively lower vector.broadcast ops on high-D +/// vectors to low-D vector ops. +void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns); + +/// Collects patterns to progressively lower vector contraction ops on high-D +/// vectors into low-D reduction and product ops. void populateVectorContractLoweringPatterns( RewritePatternSet &patterns, - VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions()); + VectorTransformsOptions options = VectorTransformsOptions()); + +/// Collects patterns to progressively lower vector mask ops into elementary +/// selection and insertion ops. +void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns); + +/// Collects patterns to progressively lower vector.shape_cast ops on high-D +/// vectors into 1-D/2-D vector ops by generating data movement extract/insert +/// ops. +void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns); /// Insert TransposeLowering patterns into extraction/insertion. void populateVectorTransposeLoweringPatterns( RewritePatternSet &patterns, - VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions()); + VectorTransformsOptions options = VectorTransformsOptions()); /// Returns the integer type required for subscripts in the vector dialect. 
IntegerType getVectorSubscriptType(Builder &builder); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -62,7 +62,10 @@ { RewritePatternSet patterns(&getContext()); populateVectorToVectorCanonicalizationPatterns(patterns); + populateVectorBroadcastLoweringPatterns(patterns); populateVectorContractLoweringPatterns(patterns); + populateVectorMaskOpLoweringPatterns(patterns); + populateVectorShapeCastLoweringPatterns(patterns); populateVectorTransposeLoweringPatterns(patterns); // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -3847,27 +3847,35 @@ BubbleUpBitCastForStridedSliceInsert>(patterns.getContext()); } +void mlir::vector::populateVectorBroadcastLoweringPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +void mlir::vector::populateVectorMaskOpLoweringPatterns( + RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + +void mlir::vector::populateVectorShapeCastLoweringPatterns( + RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + void mlir::vector::populateVectorContractLoweringPatterns( - RewritePatternSet &patterns, VectorTransformsOptions parameters) { - // clang-format off - patterns.add(patterns.getContext()); - patterns.add(parameters, patterns.getContext()); - // clang-format on + RewritePatternSet &patterns, VectorTransformsOptions options) { + patterns.add(patterns.getContext()); + patterns.add(options, + patterns.getContext()); } void 
mlir::vector::populateVectorTransposeLoweringPatterns( - RewritePatternSet &patterns, - VectorTransformsOptions vectorTransformOptions) { - patterns.add(vectorTransformOptions, - patterns.getContext()); + RewritePatternSet &patterns, VectorTransformsOptions options) { + patterns.add(options, patterns.getContext()); } void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp --- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp @@ -112,8 +112,11 @@ // Programmatic controlled lowering of vector.contract only. RewritePatternSet vectorContractLoweringPatterns(context); + populateVectorBroadcastLoweringPatterns(vectorContractLoweringPatterns); populateVectorContractLoweringPatterns(vectorContractLoweringPatterns, vectorTransformOptions); + populateVectorMaskOpLoweringPatterns(vectorContractLoweringPatterns); + populateVectorShapeCastLoweringPatterns(vectorContractLoweringPatterns); populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns, vectorTransformOptions); (void)applyPatternsAndFoldGreedily(module, diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -164,7 +164,10 @@ if (lowerToFlatTranspose) transposeLowering = VectorTransposeLowering::Flat; VectorTransformsOptions options{contractLowering, transposeLowering}; + populateVectorBroadcastLoweringPatterns(patterns); populateVectorContractLoweringPatterns(patterns, options); + populateVectorMaskOpLoweringPatterns(patterns); + populateVectorShapeCastLoweringPatterns(patterns); populateVectorTransposeLoweringPatterns(patterns, options); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); }