diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -19,6 +19,7 @@
 
 namespace mlir {
 
+class GlobalCreator;
 class RewritePatternSet;
 using OwningRewritePatternList = RewritePatternSet;
 
@@ -31,6 +32,12 @@
 /// Creates an instance of func bufferization pass.
 std::unique_ptr<Pass> createFuncBufferizePass();
 
+/// Adds the tensor constant bufferization pattern to `patterns`, constructed
+/// from `globalCreator` and `typeConverter`.
+void populateTensorConstantBufferizePatterns(
+    GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
+    RewritePatternSet &patterns);
+
 /// Creates an instance of tensor constant bufferization pass.
 std::unique_ptr<Pass> createTensorConstantBufferizePass();
 
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -81,6 +81,13 @@
 };
 } // namespace
 
+void mlir::populateTensorConstantBufferizePatterns(
+    GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  patterns.add<BufferizeTensorConstantOp>(globalCreator, typeConverter,
+                                          patterns.getContext());
+}
+
 namespace {
 struct TensorConstantBufferizePass
     : public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
@@ -94,7 +101,7 @@
     ConversionTarget target(*context);
 
     target.addLegalDialect<memref::MemRefDialect>();
-    patterns.add<BufferizeTensorConstantOp>(globals, typeConverter, context);
+    populateTensorConstantBufferizePatterns(globals, typeConverter, patterns);
     target.addDynamicallyLegalOp<ConstantOp>(
         [&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
     if (failed(applyPartialConversion(module, target, std::move(patterns))))