NOTE(review): this patch was recovered from a copy whose line breaks and
template argument lists had been stripped. `<Pass>` and `<ShapeToShapeLowering>`
follow from names visible in the patch itself; `<NumElementsOpConverter>` and
`<ShapeDialect>` are presumed -- confirm against the target tree before applying.

diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -20,4 +20,6 @@
 namespace mlir {
+class MLIRContext;
+class OwningRewritePatternList;
 class Pass;
 
 /// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
@@ -25,6 +27,9 @@
 /// transformed to `shape.reduce`, which can be lowered to SCF and Standard.
 std::unique_ptr<Pass> createShapeToShapeLowering();
 
+/// Collects a set of patterns to rewrite ops within the Shape dialect.
+void populateShapeRewritePatterns(MLIRContext *context,
+                                  OwningRewritePatternList &patterns);
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -54,8 +54,10 @@
 } // namespace
 
 void ShapeToShapeLowering::runOnFunction() {
+  MLIRContext &ctx = getContext();
+
   OwningRewritePatternList patterns;
-  patterns.insert<NumElementsOpConverter>(&getContext());
+  populateShapeRewritePatterns(&ctx, patterns);
 
   ConversionTarget target(getContext());
   target.addLegalDialect<ShapeDialect>();
@@ -64,6 +66,11 @@
     signalPassFailure();
 }
 
+void mlir::populateShapeRewritePatterns(MLIRContext *context,
+                                        OwningRewritePatternList &patterns) {
+  patterns.insert<NumElementsOpConverter>(context);
+}
+
 std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
   return std::make_unique<ShapeToShapeLowering>();
 }