diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -63,6 +63,12 @@
 /// performs a function call analysis and bufferizes an entire module.
 std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
 
+/// This is a pass that bufferizes only tensor dialect ops.
+std::unique_ptr<Pass> createLinalgComprehensiveTensorBufferizePass();
+
+/// This is a pass that bufferizes only vector dialect ops.
+std::unique_ptr<Pass> createLinalgComprehensiveVectorBufferizePass();
+
 /// Create a pass to convert Linalg operations which work on tensors to use
 /// buffers instead.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -80,6 +80,26 @@
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }
 
+def LinalgComprehensiveTensorBufferize :
+    FunctionPass<"linalg-comprehensive-tensor-bufferize"> {
+  let summary = "Bufferize (tensor into memref) tensor dialect ops.";
+  let description = [{
+    This pass bufferizes only tensor dialect ops.
+    TODO: This will be moved to the bufferization dialect in the future.
+  }];
+  let constructor = "mlir::createLinalgComprehensiveTensorBufferizePass()";
+}
+
+def LinalgComprehensiveVectorBufferize :
+    FunctionPass<"linalg-comprehensive-vector-bufferize"> {
+  let summary = "Bufferize (tensor into memref) vector dialect ops.";
+  let description = [{
+    This pass bufferizes only vector dialect ops.
+    TODO: This will be moved to the bufferization dialect in the future.
+  }];
+  let constructor = "mlir::createLinalgComprehensiveVectorBufferizePass()";
+}
+
 def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -120,6 +120,46 @@
 
   void runOnFunction() final;
 };
+
+/// Function pass that runs comprehensive bufferization with only the
+/// bufferization and tensor dialect BufferizableOpInterface models registered.
+struct LinalgComprehensiveTensorBufferize
+    : public LinalgComprehensiveTensorBufferizeBase<
+          LinalgComprehensiveTensorBufferize> {
+  LinalgComprehensiveTensorBufferize() {}
+
+  LinalgComprehensiveTensorBufferize(
+      const LinalgComprehensiveTensorBufferize &p) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    tensor::TensorDialect>();
+    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+  void runOnFunction() override;
+};
+
+/// Function pass that runs comprehensive bufferization with only the
+/// bufferization and vector dialect BufferizableOpInterface models registered.
+struct LinalgComprehensiveVectorBufferize
+    : public LinalgComprehensiveVectorBufferizeBase<
+          LinalgComprehensiveVectorBufferize> {
+  LinalgComprehensiveVectorBufferize() {}
+
+  LinalgComprehensiveVectorBufferize(
+      const LinalgComprehensiveVectorBufferize &p) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    vector::VectorDialect>();
+    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+  void runOnFunction() override;
+};
 } // end namespace
 
 static void applyEnablingTransformations(Operation *op) {
@@ -163,6 +203,27 @@
   this->runCleanupPipeline("builtin.func", getOperation());
 }
 
+void LinalgComprehensiveTensorBufferize::runOnFunction() {
+  BufferizationOptions options;
+  options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+  options.allowReturnMemref = true;
+  // Only bufferization and tensor op models are registered; allow all others.
+  options.allowUnknownOps = true;
+
+  if (failed(runComprehensiveBufferize(getOperation(), options)))
+    signalPassFailure();
+}
+
+void LinalgComprehensiveVectorBufferize::runOnFunction() {
+  BufferizationOptions options;
+  options.allowReturnMemref = true;
+  // Only bufferization and vector op models are registered; allow all others.
+  options.allowUnknownOps = true;
+
+  if (failed(runComprehensiveBufferize(getOperation(), options)))
+    signalPassFailure();
+}
+
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
   return std::make_unique<LinalgComprehensiveModuleBufferize>();
 }
@@ -170,3 +231,11 @@
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveFunctionBufferizePass() {
   return std::make_unique<LinalgComprehensiveFunctionBufferize>();
 }
+
+std::unique_ptr<Pass> mlir::createLinalgComprehensiveTensorBufferizePass() {
+  return std::make_unique<LinalgComprehensiveTensorBufferize>();
+}
+
+std::unique_ptr<Pass> mlir::createLinalgComprehensiveVectorBufferizePass() {
+  return std::make_unique<LinalgComprehensiveVectorBufferize>();
+}