diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -161,11 +161,20 @@
 
   /// A function that populates a `RewritePatternSet`.
   using PopulatePatternsFn = std::function<void(RewritePatternSet &)>;
+  /// A function that populates a `RewritePatternSet` with a specified benefit.
+  using PopulatePatternsWithBenefitFn =
+      std::function<void(RewritePatternSet &, PatternBenefit)>;
 
   /// Registers patterns with the specified identifier. The identifier should
   /// be prefixed with the dialect to which the patterns belong.
   void registerPatterns(StringRef identifier, PopulatePatternsFn &&fn);
 
+  /// Registers patterns with the specified identifier. The identifier should
+  /// be prefixed with the dialect to which the patterns belong. `fn` is
+  /// always invoked with a pattern benefit of 1.
+  void registerPatterns(StringRef identifier,
+                        PopulatePatternsWithBenefitFn &&fn);
+
 protected:
   friend class ApplyPatternsOp;
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -223,6 +223,17 @@
   patterns.try_emplace(attr, std::move(fn));
 }
 
+void transform::PatternRegistry::registerPatterns(
+    StringRef identifier, PopulatePatternsWithBenefitFn &&fn) {
+  StringAttr attr = builder.getStringAttr(identifier);
+  assert(!patterns.contains(attr) && "patterns identifier is already in use");
+  // Adapt to PopulatePatternsFn by binding a benefit of 1.
+  patterns.try_emplace(
+      attr, [f = std::move(fn)](RewritePatternSet &patternSet) {
+        f(patternSet, /*benefit=*/1);
+      });
+}
+
 void transform::PatternRegistry::populatePatterns(
     StringAttr identifier, RewritePatternSet &patternSet) const {
   auto it = patterns.find(identifier);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -188,6 +189,38 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
         >();
+
+    // Register the vector lowering pattern sets under string identifiers so
+    // they can be looked up by name through the PatternRegistry. The lambdas
+    // capture nothing: the pattern callbacks are stored and outlive this scope.
+    addDialectDataInitializer<transform::PatternRegistry>(
+        [](transform::PatternRegistry &registry) {
+          registry.registerPatterns("vector.outer_product_lowering",
+                                    populateVectorOuterProductLoweringPatterns);
+          registry.registerPatterns("vector.broadcast_lowering",
+                                    populateVectorBroadcastLoweringPatterns);
+          registry.registerPatterns("vector.mask_op_lowering",
+                                    populateVectorMaskOpLoweringPatterns);
+          registry.registerPatterns("vector.shape_cast_lowering",
+                                    populateVectorShapeCastLoweringPatterns);
+          registry.registerPatterns(
+              "vector.transfer_lowering",
+              // Fix `maxTransferRank` to std::nullopt and forward `benefit`.
+              [](RewritePatternSet &set, PatternBenefit benefit) {
+                populateVectorTransferLoweringPatterns(
+                    set, /*maxTransferRank=*/std::nullopt, benefit);
+              });
+          registry.registerPatterns(
+              "vector.transfer_permutation_map_lowering",
+              populateVectorTransferPermutationMapLoweringPatterns);
+          registry.registerPatterns("vector.scan_lowering",
+                                    populateVectorScanLoweringPatterns);
+          registry.registerPatterns("vector.vector_gather_lowering",
+                                    populateVectorGatherLoweringPatterns);
+          registry.registerPatterns(
+              "vector.mask_lowering_for_side_effecting_ops",
+              populateVectorMaskLoweringPatternsForSideEffectingOps);
+        });
   }
 };
 } // namespace