diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -56,11 +56,11 @@
 /// This pass implements a cross-dialect bufferization approach and performs an
 /// analysis to determine which op operands and results may be bufferized in the
-/// same buffers. The analysis is performed on topologically sorted CallOp and
-/// FuncOp within a module. It provides analyses and bufferization across
-/// function boundaries. Within a function boundary, the analysis is performed
-/// on SSA use-def chains starting from function operands that are annotated
-/// with the 'inplaceable' attribute.
+/// same buffers.
+std::unique_ptr<Pass> createLinalgComprehensiveFunctionBufferizePass();
+
+/// This pass is an extension of the comprehensive function bufferization. It
+/// performs a function call analysis and bufferizes an entire module.
 std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
 
 /// Create a pass to convert Linalg operations which work on tensors to use
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -22,17 +22,43 @@
   let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
 }
 
+def LinalgComprehensiveFunctionBufferize :
+    FunctionPass<"linalg-comprehensive-function-bufferize"> {
+  let summary = "Bufferize (tensor into memref) for a function.";
+  let description = [{
+    This pass implements a cross-dialect bufferization approach and performs an
+    analysis to determine which op operands and results may be bufferized in the
+    same buffers. The analysis is performed on SSA use-def chains starting from
+    function operands that are annotated with the 'inplaceable' attribute.
+  }];
+  let options = [
+    Option<"testAnalysisOnly", "test-analysis-only", "bool",
+           /*default=*/"false",
+           "Only runs inplaceability analysis (for testing purposes only)">,
+    Option<"allowReturnMemref", "allow-return-memref", "bool",
+           /*default=*/"false",
+           "Allows the return of memrefs (for testing purposes only)">,
+    Option<"allowUnknownOps", "allow-unknown-ops", "bool",
+           /*default=*/"false",
+           "Allows unknown (not bufferizable) ops in the input IR.">,
+    Option<"useAlloca", "use-alloca", "bool",
+           /*default=*/"false",
+           "Use stack allocations for memrefs (for testing purposes only)">,
+    Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned",
+           /*default=*/"0",
+           "Analyze ops in random order with a given seed (fuzzer)">
+  ];
+  let constructor = "mlir::createLinalgComprehensiveFunctionBufferizePass()";
+}
+
 def LinalgComprehensiveModuleBufferize :
     Pass<"linalg-comprehensive-module-bufferize", "ModuleOp"> {
   let summary = "Bufferize (tensor into memref) for a Module.";
   let description = [{
-    This pass implements a cross-dialect bufferization approach and performs an
-    analysis to determine which op operands and results may be bufferized in the
-    same buffers. The analysis is performed on topologically sorted CallOp and
+    This pass is an extension of linalg-comprehensive-function-bufferize. It
+    performs an inter-function analysis on topologically sorted CallOp and
     FuncOp within a module. It provides analyses and bufferization across
-    function boundaries. Within a function boundary, the analysis is performed
-    on SSA use-def chains starting from function operands that are annotated
-    with the 'inplaceable' attribute.
+    function boundaries.
   }];
   let options = [
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -29,10 +29,71 @@
 using namespace mlir::linalg;
 using namespace mlir::linalg::comprehensive_bufferize;
 
+static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
+                                               MemRefType type,
+                                               ArrayRef<Value> dynShape) {
+  Value allocated = b.create<memref::AllocaOp>(
+      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+  return allocated;
+}
+
 namespace {
+/// A helper struct for FunctionBufferize and ModuleBufferize. Both passes
+/// share most of their implementation.
+template <typename PassBase>
+struct BufferizePassHelper : public PassBase {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<arith::ArithmeticDialect, AffineDialect,
+                bufferization::BufferizationDialect, linalg::LinalgDialect,
+                memref::MemRefDialect, scf::SCFDialect, tensor::TensorDialect,
+                vector::VectorDialect>();
+    affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+  void prepareOptions(BufferizationOptions &options) {
+    if (this->useAlloca) {
+      options.allocationFns->allocationFn = allocationFnUsingAlloca;
+      options.allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
+                                                 Value v) {};
+    }
+    // TODO: Change to memref::CopyOp (default memCpyFn).
+    options.allocationFns->memCpyFn = [](OpBuilder &b, Location loc,
+                                         Value from, Value to) {
+      b.create<linalg::CopyOp>(loc, from, to);
+    };
+
+    // Enable InitTensorOp elimination.
+    options.addPostAnalysisStep<
+        linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
+    // TODO: Find a way to enable this step automatically when bufferizing
+    // tensor dialect ops.
+    options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+    options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+
+    options.allowReturnMemref = this->allowReturnMemref;
+    options.allowUnknownOps = this->allowUnknownOps;
+    options.analysisFuzzerSeed = this->analysisFuzzerSeed;
+    options.testAnalysisOnly = this->testAnalysisOnly;
+  }
+
+  void runCleanupPipeline(StringRef name, Operation *op) {
+    OpPassManager cleanupPipeline(name);
+    cleanupPipeline.addPass(createCanonicalizerPass());
+    cleanupPipeline.addPass(createCSEPass());
+    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
+    (void)this->runPipeline(cleanupPipeline, op);
+  }
+};
+
 struct LinalgComprehensiveModuleBufferize
-    : public LinalgComprehensiveModuleBufferizeBase<
-          LinalgComprehensiveModuleBufferize> {
+    : public BufferizePassHelper<LinalgComprehensiveModuleBufferizeBase<
+          LinalgComprehensiveModuleBufferize>> {
   LinalgComprehensiveModuleBufferize() {}
 
   LinalgComprehensiveModuleBufferize(
@@ -41,62 +102,35 @@
   void runOnOperation() override;
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<arith::ArithmeticDialect, AffineDialect,
-                bufferization::BufferizationDialect, linalg::LinalgDialect,
-                memref::MemRefDialect, scf::SCFDialect, StandardOpsDialect,
-                tensor::TensorDialect, vector::VectorDialect>();
-    affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    BufferizePassHelper<LinalgComprehensiveModuleBufferizeBase<
+        LinalgComprehensiveModuleBufferize>>::getDependentDialects(registry);
+    // Module bufferization requires additional external models for FuncOp etc.
+    registry.insert<StandardOpsDialect>();
     std_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
   }
 };
-} // end namespace
 
-static void applyEnablingTransformations(ModuleOp moduleOp) {
-  RewritePatternSet patterns(moduleOp.getContext());
-  patterns.add<GeneralizePadTensorOpPattern>(moduleOp.getContext());
-  (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
-}
+struct LinalgComprehensiveFunctionBufferize
+    : public BufferizePassHelper<LinalgComprehensiveFunctionBufferizeBase<
+          LinalgComprehensiveFunctionBufferize>> {
+  LinalgComprehensiveFunctionBufferize() {}
 
-static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
-                                               MemRefType type,
-                                               ArrayRef<Value> dynShape) {
-  Value allocated = b.create<memref::AllocaOp>(
-      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
-  return allocated;
+  LinalgComprehensiveFunctionBufferize(
+      const LinalgComprehensiveFunctionBufferize &p) {}
+
+  void runOnFunction() final;
+};
+} // end namespace
+
+static void applyEnablingTransformations(Operation *op) {
+  RewritePatternSet patterns(op->getContext());
+  patterns.add<GeneralizePadTensorOpPattern>(op->getContext());
+  (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
 }
 
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
   BufferizationOptions options;
-  if (useAlloca) {
-    options.allocationFns->allocationFn = allocationFnUsingAlloca;
-    options.allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
-                                               Value v) {};
-  }
-  // TODO: Change to memref::CopyOp (default memCpyFn).
-  options.allocationFns->memCpyFn = [](OpBuilder &b, Location loc, Value from,
-                                       Value to) {
-    b.create<linalg::CopyOp>(loc, from, to);
-  };
-
-  options.allowReturnMemref = allowReturnMemref;
-  options.allowUnknownOps = allowUnknownOps;
-  options.analysisFuzzerSeed = analysisFuzzerSeed;
-  options.testAnalysisOnly = testAnalysisOnly;
-
-  // Enable InitTensorOp elimination.
-  options.addPostAnalysisStep<
-      linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
-  // TODO: Find a way to enable this step automatically when bufferizing tensor
-  // dialect ops.
-  options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
-  options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+  this->prepareOptions(options);
 
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
@@ -109,13 +143,30 @@
   if (options.testAnalysisOnly)
     return;
 
-  OpPassManager cleanupPipeline("builtin.module");
-  cleanupPipeline.addPass(createCanonicalizerPass());
-  cleanupPipeline.addPass(createCSEPass());
-  cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
-  (void)runPipeline(cleanupPipeline, moduleOp);
+  this->runCleanupPipeline("builtin.module", moduleOp);
+}
+
+void LinalgComprehensiveFunctionBufferize::runOnFunction() {
+  BufferizationOptions options;
+  this->prepareOptions(options);
+
+  applyEnablingTransformations(getOperation());
+
+  if (failed(runComprehensiveBufferize(getOperation(), options))) {
+    signalPassFailure();
+    return;
+  }
+
+  if (options.testAnalysisOnly)
+    return;
+
+  this->runCleanupPipeline("builtin.func", getOperation());
 }
 
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
   return std::make_unique<LinalgComprehensiveModuleBufferize>();
 }
+
+std::unique_ptr<Pass> mlir::createLinalgComprehensiveFunctionBufferizePass() {
+  return std::make_unique<LinalgComprehensiveFunctionBufferize>();
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s
+
+// Run fuzzer with different seeds.
+// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
+
+// CHECK-LABEL: func @use_tensor_func_arg(
+//  CHECK-SAME:   %[[A:.*]]: tensor<?xf32>
+func @use_tensor_func_arg(%A : tensor<?xf32>) -> (vector<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[res:.*]] = vector.transfer_read %[[A_memref]]
+  %0 = vector.transfer_read %A[%c0], %f0 : tensor<?xf32>, vector<4xf32>
+
+  // CHECK: return %[[res]]
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @return_tensor(
+//  CHECK-SAME:   %[[A:.*]]: tensor<?xf32>
+func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = tensor.dim %[[A]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK: linalg.copy(%[[A_memref]], %[[alloc]])
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @func_without_tensor_args
+func @func_without_tensor_args(%v : vector<10xf32>) -> () {
+  // CHECK: %[[alloc:.*]] = memref.alloc()
+  %0 = linalg.init_tensor[10] : tensor<10xf32>
+
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  %1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>
+
+  %cst = arith.constant 0.0 : f32
+  // CHECK: vector.transfer_read %[[alloc]]
+  %r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>
+
+  vector.print %r : vector<11xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func private @private_func
+func private @private_func(tensor<?xf32>) -> ()
+
+// CHECK-LABEL: func @empty_func()
+func @empty_func() -> () {
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -979,3 +979,32 @@
   }
   return %1: tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @func_without_tensor_args
+func @func_without_tensor_args(%v : vector<10xf32>) -> () {
+  // CHECK: %[[alloc:.*]] = memref.alloc()
+  %0 = linalg.init_tensor[10] : tensor<10xf32>
+
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  %1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>
+
+  %cst = arith.constant 0.0 : f32
+  // CHECK: vector.transfer_read %[[alloc]]
+  %r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>
+
+  vector.print %r : vector<11xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func private @private_func
+func private @private_func(tensor<?xf32>) -> ()
+
+// CHECK-LABEL: func @empty_func()
+func @empty_func() -> () {
+  return
+}
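
Note: the 'inplaceable' attribute referenced in the pass descriptions is a function-argument attribute, but none of the tests in this patch exercise it directly. A minimal illustrative sketch (not part of the patch; the function name is invented, the attribute spelling follows the existing comprehensive-module-bufferize tests):

// With %A marked inplaceable, the analysis may let the write reuse %A's
// buffer instead of allocating a new one and copying.
func @write_into_arg(%A : tensor<4xf32> {linalg.inplaceable = true},
                     %v : vector<4xf32>) -> tensor<4xf32> {
  %c0 = arith.constant 0 : index
  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<4xf32>
  return %0 : tensor<4xf32>
}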
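
With test-analysis-only set, the pass stops after the inplaceability analysis and only annotates the IR; this is what the fuzzer RUN lines above rely on. As a rough sketch (assuming the __inplace_results_attr__ marker used by the existing analysis tests; exact output may differ), the write in the example above would come back annotated as:

%0 = vector.transfer_write %v, %A[%c0]
    {__inplace_results_attr__ = ["true"]}
    : vector<4xf32>, tensor<4xf32>

where "false" instead indicates that bufferization will insert an allocation and a copy for that result.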
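
When use-alloca is set, prepareOptions swaps in allocationFnUsingAlloca and an empty deallocation function, so temporary buffers become stack allocations with no corresponding free. For instance, under that option the buffer expected in @func_without_tensor_args would appear as (illustrative):

// CHECK: %[[alloc:.*]] = memref.alloca()

with no matching memref.dealloc emitted anywhere in the function.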