diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -67,7 +67,7 @@
   /// Run the post analysis step. This function may modify the IR, but must keep
   /// `aliasInfo` (inside `state`) consistent. Newly created operations and
   /// operations that should be re-analyzed must be stored in `newOps`.
-  virtual LogicalResult run(FuncOp funcOp, BufferizationState &state,
+  virtual LogicalResult run(Operation *op, BufferizationState &state,
                             SmallVector<Operation *> &newOps) = 0;
 };
 
@@ -296,8 +296,8 @@
 /// BufferizationState keeps track of bufferization state and provides access to
 /// the results of the analysis.
 struct BufferizationState {
-  BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options)
-      : aliasInfo(moduleOp), options(options) {}
+  BufferizationState(Operation *op, const BufferizationOptions &options)
+      : aliasInfo(op), options(options) {}
 
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -19,10 +19,13 @@
 struct BufferizationOptions;
 struct BufferizationState;
 
+/// Bufferize the given op.
+LogicalResult runComprehensiveBufferize(Operation *op,
+                                        const BufferizationOptions &options);
+
 /// Bufferize the given function. Does not bufferize the function boundary.
-// TODO: This function is meant to be called from ModuleBufferize and not can
-// not yet be called standalone.
-LogicalResult runComprehensiveBufferize(FuncOp funcOp,
+/// Reuses an existing BufferizationState object.
+LogicalResult runComprehensiveBufferize(Operation *op,
                                         const BufferizationOptions &options,
                                         BufferizationState &state);
 
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -23,7 +23,7 @@
 namespace linalg_ext {
 
 struct InitTensorEliminationStep : public PostAnalysisStep {
-  /// Try to eliminate InitTensorOps inside `funcOp`.
+  /// Try to eliminate InitTensorOps inside `op`.
   ///
   /// * `rewriteFunc` generates the replacement for the InitTensorOp.
   /// * Only InitTensorOps that are anchored on a matching OpOperand as per
@@ -34,18 +34,18 @@
   /// * The result of `rewriteFunc` must usually be analyzed for inplacability.
   ///   This analysis can be skipped with `skipAnalysis`.
   LogicalResult eliminateInitTensors(
-      FuncOp funcOp, BufferizationState &state,
+      Operation *op, BufferizationState &state,
       std::function<bool(OpOperand &)> anchorMatchFunc,
       std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
       SmallVector<Operation *> &newOps);
 };
 
-/// Try to eliminate InitTensorOps inside funcOp that are anchored on an
+/// Try to eliminate InitTensorOps inside `op` that are anchored on an
 /// InsertSliceOp, i.e., if it is eventually inserted into another tensor
 /// (and some other conditions are met).
 struct InsertSliceAnchoredInitTensorEliminationStep
     : public InitTensorEliminationStep {
-  LogicalResult run(FuncOp funcOp, BufferizationState &state,
+  LogicalResult run(Operation *op, BufferizationState &state,
                     SmallVector<Operation *> &newOps) override;
 };
 
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
@@ -20,7 +20,7 @@
 namespace tensor_ext {
 
 struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep {
-  LogicalResult run(FuncOp funcOp, BufferizationState &state,
+  LogicalResult run(Operation *op, BufferizationState &state,
                     SmallVector<Operation *> &newOps) override;
 };
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -56,11 +56,11 @@
 
 /// This pass implements a cross-dialect bufferization approach and performs an
 /// analysis to determine which op operands and results may be bufferized in the
-/// same buffers. The analysis is performed on topologically sorted CallOp and
-/// FuncOp within a module. It provides analyses and bufferization across
-/// function boundaries. Within a function boundary, the analysis is performed
-/// on SSA use-def chains starting from function operands that are annotated
-/// with the 'inplaceable' attribute.
+/// same buffers.
+std::unique_ptr<Pass> createLinalgComprehensiveFunctionBufferizePass();
+
+/// This pass is an extension of the comprehensive function bufferization. It
+/// performs a function call analysis and bufferizes an entire module.
 std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
 
 /// Create a pass to convert Linalg operations which work on tensors to use
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -22,17 +22,43 @@
   let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
 }
 
+def LinalgComprehensiveFunctionBufferize :
+    FunctionPass<"linalg-comprehensive-function-bufferize"> {
+  let summary = "Bufferize (tensor into memref) for a function.";
+  let description = [{
+    This pass implements a cross-dialect bufferization approach and performs an
+    analysis to determine which op operands and results may be bufferized in the
+    same buffers. The analysis is performed on SSA use-def chains starting from
+    function operands that are annotated with the 'inplaceable' attribute.
+  }];
+  let options = [
+    Option<"testAnalysisOnly", "test-analysis-only", "bool",
+           /*default=*/"false",
+           "Only runs inplaceability analysis (for testing purposes only)">,
+    Option<"allowReturnMemref", "allow-return-memref", "bool",
+           /*default=*/"false",
+           "Allows the return of memrefs (for testing purposes only)">,
+    Option<"allowUnknownOps", "allow-unknown-ops", "bool",
+           /*default=*/"false",
+           "Allows unknown (not bufferizable) ops in the input IR.">,
+    Option<"useAlloca", "use-alloca", "bool",
+           /*default=*/"false",
+           "Use stack allocations for memrefs (for testing purposes only)">,
+    Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned",
+           /*default=*/"0",
+           "Analyze ops in random order with a given seed (fuzzer)">
+  ];
+  let constructor = "mlir::createLinalgComprehensiveFunctionBufferizePass()";
+}
+
 def LinalgComprehensiveModuleBufferize :
     Pass<"linalg-comprehensive-module-bufferize", "ModuleOp"> {
   let summary = "Bufferize (tensor into memref) for a Module.";
   let description = [{
-    This pass implements a cross-dialect bufferization approach and performs an
-    analysis to determine which op operands and results may be bufferized in the
-    same buffers. The analysis is performed on topologically sorted CallOp and
+    This pass is an extension of linalg-comprehensive-function-bufferize. It
+    performs an inter-function analysis on topologically sorted CallOp and
     FuncOp within a module. It provides analyses and bufferization across
-    function boundaries. Within a function boundary, the analysis is performed
-    on SSA use-def chains starting from function operands that are annotated
-    with the 'inplaceable' attribute.
+    function boundaries.
   }];
   let options = [
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -465,9 +465,11 @@
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
+  bool isFuncOp = isa<FuncOp>(op);
 
   // No tensor results or operands: Simply bufferize all nested ops.
-  if (!hasTensorResult && !hasTensorOperand) {
+  // Always pass through FuncOps.
+  if (!hasTensorResult && !hasTensorOperand && !isFuncOp) {
     for (Region &region : op->getRegions())
       if (failed(bufferize(&region, state)))
         return failure();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -624,10 +624,10 @@
 /// Assert that the current bufferization decisions are consistent.
 static LogicalResult
-checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo,
+checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
                           const BufferizationAliasInfo &aliasInfo) {
   Operation *inconsistentOp = nullptr;
-  WalkResult walkResult = funcOp.walk([&](Operation *op) {
+  WalkResult walkResult = op->walk([&](Operation *op) {
     if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
       for (OpOperand &opOperand : op->getOpOperands())
         if (opOperand.get().getType().isa<TensorType>()) {
@@ -669,21 +669,24 @@
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
-    FuncOp funcOp, const BufferizationOptions &options,
+    Operation *op, const BufferizationOptions &options) {
+  BufferizationState state(op, options);
+  return runComprehensiveBufferize(op, options, state);
+}
+
+LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
+    Operation *op, const BufferizationOptions &options,
     BufferizationState &state) {
-  DominanceInfo domInfo(funcOp);
+  DominanceInfo domInfo(op);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
-  if (funcOp.body().empty())
-    return success();
-
-  if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo)))
+  if (failed(checkAliasInfoConsistency(op, domInfo, aliasInfo)))
     return failure();
 
   // Collect ops so we can build our own reverse traversal.
   SmallVector<Operation *> ops;
-  funcOp.walk([&](Operation *op) {
+  op->walk([&](Operation *op) {
     // No tensors => no buffers.
     if (none_of(op->getOperandTypes(), isaTensor) &&
         none_of(op->getResultTypes(), isaTensor))
@@ -699,7 +702,7 @@
   for (const std::unique_ptr<PostAnalysisStep> &step :
        options.postAnalysisSteps) {
     SmallVector<Operation *> newOps;
-    if (failed(step->run(funcOp, state, newOps)))
+    if (failed(step->run(op, state, newOps)))
       return failure();
     // Analyze ops that were created by the PostAnalysisStep.
     if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
@@ -708,16 +711,12 @@
 
   // Annotate operations if we only want to report the analysis.
   if (options.testAnalysisOnly) {
-    annotateOpsWithBufferizationMarkers(funcOp, aliasInfo);
+    annotateOpsWithBufferizationMarkers(op, aliasInfo);
     return success();
   }
 
-  // Bufferize all ops in funcOp.
-  OpBuilder b(funcOp.getContext());
-  auto bufferizableOp =
-      dyn_cast<BufferizableOpInterface>(funcOp.getOperation());
-  assert(bufferizableOp && "must use ModuleBufferization");
-  if (failed(bufferizableOp.bufferize(b, state)))
+  // Bufferize op and all nested ops.
+  if (failed(bufferize(op, state)))
     return failure();
 
   // Erase all obsolete ops.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -407,14 +407,14 @@
 /// OpOperand, that eventually ends at a single InitTensorOp.
 LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
     InitTensorEliminationStep::eliminateInitTensors(
-        FuncOp funcOp, BufferizationState &state,
+        Operation *op, BufferizationState &state,
         std::function<bool(OpOperand &)> anchorMatchFunc,
         std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
         SmallVector<Operation *> &newOps) {
-  OpBuilder b(funcOp->getContext());
+  OpBuilder b(op->getContext());
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
-  WalkResult status = funcOp->walk([&](Operation *op) {
+  WalkResult status = op->walk([&](Operation *op) {
     for (OpOperand &operand : op->getOpOperands()) {
       // Is this a matching OpOperand?
       if (!anchorMatchFunc(operand))
@@ -501,10 +501,10 @@
 /// out-of-place due to RaW conflicts.
 LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
     InsertSliceAnchoredInitTensorEliminationStep::run(
-        FuncOp funcOp, BufferizationState &state,
+        Operation *op, BufferizationState &state,
         SmallVector<Operation *> &newOps) {
   return eliminateInitTensors(
-      funcOp, state,
+      op, state,
       [&](OpOperand &operand) {
         auto insertSliceOp =
             dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -435,10 +435,10 @@
 } // namespace mlir
 
 LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext::
-    InplaceInsertSliceOpAnalysis::run(FuncOp funcOp, BufferizationState &state,
+    InplaceInsertSliceOpAnalysis::run(Operation *op, BufferizationState &state,
                                       SmallVector<Operation *> &newOps) {
   auto &tensorState = getTensorBufferizationState(state);
-  funcOp.walk([&](InsertSliceOp insertSliceOp) {
+  op->walk([&](InsertSliceOp insertSliceOp) {
     // A copy of the source buffer is needed if either:
     // - The producer of `source` is not inplace. This is the case where a
     //   slice is computed out of place into the inplace full tensor.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -29,10 +29,70 @@
 using namespace mlir::linalg;
 using namespace mlir::linalg::comprehensive_bufferize;
 
+static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
+                                               MemRefType type,
+                                               ArrayRef<Value> dynShape) {
+  Value allocated = b.create<memref::AllocaOp>(
+      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+  return allocated;
+}
+
 namespace {
+/// A helper struct for FunctionBufferize and ModuleBufferize. Both passes are
+/// mostly identical.
+template <typename PassBase>
+struct BufferizePassHelper : public PassBase {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, arith::ArithmeticDialect,
+                bufferization::BufferizationDialect, linalg::LinalgDialect,
+                memref::MemRefDialect, scf::SCFDialect, tensor::TensorDialect,
+                vector::VectorDialect>();
+    affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+  void prepareOptions(BufferizationOptions &options) {
+    if (this->useAlloca) {
+      options.allocationFns->allocationFn = allocationFnUsingAlloca;
+      options.allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
+                                                 Value v) {};
+    }
+    // TODO: Change to memref::CopyOp (default memCpyFn).
+    options.allocationFns->memCpyFn = [](OpBuilder &b, Location loc, Value from,
+                                         Value to) {
+      b.create<linalg::CopyOp>(loc, from, to);
+    };
+
+    // Enable InitTensorOp elimination.
+    options.addPostAnalysisStep<
+        linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
+    // TODO: Find a way to enable this step automatically when bufferizing
+    // tensor dialect ops.
+    options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+
+    options.allowReturnMemref = this->allowReturnMemref;
+    options.allowUnknownOps = this->allowUnknownOps;
+    options.analysisFuzzerSeed = this->analysisFuzzerSeed;
+    options.testAnalysisOnly = this->testAnalysisOnly;
+  }
+
+  void runCleanupPipeline(StringRef name, Operation *op) {
+    OpPassManager cleanupPipeline(name);
+    cleanupPipeline.addPass(createCanonicalizerPass());
+    cleanupPipeline.addPass(createCSEPass());
+    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
+    (void)this->runPipeline(cleanupPipeline, op);
+  }
+};
+
 struct LinalgComprehensiveModuleBufferize
-    : public LinalgComprehensiveModuleBufferizeBase<
-          LinalgComprehensiveModuleBufferize> {
+    : public BufferizePassHelper<LinalgComprehensiveModuleBufferizeBase<
+          LinalgComprehensiveModuleBufferize>> {
   LinalgComprehensiveModuleBufferize() {}
 
   LinalgComprehensiveModuleBufferize(
@@ -41,61 +101,35 @@
   void runOnOperation() override;
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<AffineDialect, arith::ArithmeticDialect,
-                bufferization::BufferizationDialect, linalg::LinalgDialect,
-                memref::MemRefDialect, scf::SCFDialect, StandardOpsDialect,
-                tensor::TensorDialect, vector::VectorDialect>();
-    affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    BufferizePassHelper<LinalgComprehensiveModuleBufferizeBase<
+        LinalgComprehensiveModuleBufferize>>::getDependentDialects(registry);
+    // Module bufferization requires additional external models for FuncOp etc.
+    registry.insert<StandardOpsDialect>();
     std_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
   }
 };
-} // end namespace
 
-static void applyEnablingTransformations(ModuleOp moduleOp) {
-  RewritePatternSet patterns(moduleOp.getContext());
-  patterns.add<GeneralizePadTensorOpPattern>(moduleOp.getContext());
-  (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
-}
+struct LinalgComprehensiveFunctionBufferize
+    : public BufferizePassHelper<LinalgComprehensiveFunctionBufferizeBase<
+          LinalgComprehensiveFunctionBufferize>> {
+  LinalgComprehensiveFunctionBufferize() {}
 
-static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
-                                               MemRefType type,
-                                               ArrayRef<Value> dynShape) {
-  Value allocated = b.create<memref::AllocaOp>(
-      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
-  return allocated;
+  LinalgComprehensiveFunctionBufferize(
+      const LinalgComprehensiveFunctionBufferize &p) {}
+
+  void runOnFunction() final;
+};
+} // end namespace
+
+static void applyEnablingTransformations(Operation *op) {
+  RewritePatternSet patterns(op->getContext());
+  patterns.add<GeneralizePadTensorOpPattern>(op->getContext());
+  (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
 }
 
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
   BufferizationOptions options;
-  if (useAlloca) {
-    options.allocationFns->allocationFn = allocationFnUsingAlloca;
-    options.allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
-                                               Value v) {};
-  }
-  // TODO: Change to memref::CopyOp (default memCpyFn).
-  options.allocationFns->memCpyFn = [](OpBuilder &b, Location loc, Value from,
-                                       Value to) {
-    b.create<linalg::CopyOp>(loc, from, to);
-  };
-
-  options.allowReturnMemref = allowReturnMemref;
-  options.allowUnknownOps = allowUnknownOps;
-  options.analysisFuzzerSeed = analysisFuzzerSeed;
-  options.testAnalysisOnly = testAnalysisOnly;
-
-  // Enable InitTensorOp elimination.
-  options.addPostAnalysisStep<
-      linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
-  // TODO: Find a way to enable this step automatically when bufferizing tensor
-  // dialect ops.
-  options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+  this->prepareOptions(options);
 
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
@@ -108,13 +142,30 @@
   if (options.testAnalysisOnly)
     return;
 
-  OpPassManager cleanupPipeline("builtin.module");
-  cleanupPipeline.addPass(createCanonicalizerPass());
-  cleanupPipeline.addPass(createCSEPass());
-  cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
-  (void)runPipeline(cleanupPipeline, moduleOp);
+  this->runCleanupPipeline("builtin.module", moduleOp);
+}
+
+void LinalgComprehensiveFunctionBufferize::runOnFunction() {
+  BufferizationOptions options;
+  this->prepareOptions(options);
+
+  applyEnablingTransformations(getOperation());
+
+  if (failed(runComprehensiveBufferize(getOperation(), options))) {
+    signalPassFailure();
+    return;
+  }
+
+  if (options.testAnalysisOnly)
+    return;
+
+  this->runCleanupPipeline("builtin.func", getOperation());
 }
 
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
   return std::make_unique<LinalgComprehensiveModuleBufferize>();
 }
+
+std::unique_ptr<Pass> mlir::createLinalgComprehensiveFunctionBufferizePass() {
+  return std::make_unique<LinalgComprehensiveFunctionBufferize>();
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s
+
+// Run fuzzer with different seeds.
+// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
+
+// CHECK-LABEL: func @use_tensor_func_arg(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
+func @use_tensor_func_arg(%A : tensor<?xf32>) -> (vector<4xf32>) {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[res:.*]] = vector.transfer_read %[[A_memref]]
+  %0 = vector.transfer_read %A[%c0], %f0 : tensor<?xf32>, vector<4xf32>
+
+  // CHECK: return %[[res]]
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @return_tensor(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
+func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = tensor.dim %[[A]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK: linalg.copy(%[[A_memref]], %[[alloc]])
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @func_without_tensor_args
+func @func_without_tensor_args(%v : vector<10xf32>) -> () {
+  // CHECK: %[[alloc:.*]] = memref.alloc()
+  %0 = linalg.init_tensor [10] : tensor<10xf32>
+
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  %1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>
+
+  %cst = arith.constant 0.0 : f32
+  // CHECK: vector.transfer_read %[[alloc]]
+  %r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>
+
+  vector.print %r : vector<11xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func private @private_func
+func private @private_func(tensor<?xf32>) -> ()
+
+// CHECK-LABEL: func @empty_func()
+func @empty_func() -> () {
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -916,3 +916,32 @@
   }
   return %r : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @func_without_tensor_args
+func @func_without_tensor_args(%v : vector<10xf32>) -> () {
+  // CHECK: %[[alloc:.*]] = memref.alloc()
+  %0 = linalg.init_tensor [10] : tensor<10xf32>
+
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  %1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>
+
+  %cst = arith.constant 0.0 : f32
+  // CHECK: vector.transfer_read %[[alloc]]
+  %r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>
+
+  vector.print %r : vector<11xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func private @private_func
+func private @private_func(tensor<?xf32>) -> ()
+
+// CHECK-LABEL: func @empty_func()
+func @empty_func() -> () {
+  return
+}
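
Usage note (not part of the patch): the sketch below shows how the new standalone entry point, runComprehensiveBufferize(Operation *, const BufferizationOptions &), could be driven from a custom pass. MyBufferizePass and its anchoring on FuncOp are hypothetical names chosen for illustration; only the entry point and the option fields set here come from this patch.

// A minimal sketch, assuming the headers and APIs added/changed above.
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::linalg::comprehensive_bufferize;

namespace {
/// Hypothetical pass that bufferizes the function it is anchored on.
struct MyBufferizePass
    : public PassWrapper<MyBufferizePass, OperationPass<FuncOp>> {
  void runOnOperation() override {
    BufferizationOptions options;
    // Same knobs that the pass options above expose.
    options.allowUnknownOps = true;
    options.allowReturnMemref = false;

    // The new convenience overload creates a fresh BufferizationState
    // internally, runs the inplaceability analysis and the post analysis
    // steps, and then bufferizes the op and all nested ops.
    if (failed(runComprehensiveBufferize(getOperation(), options)))
      signalPassFailure();
  }
};
} // end namespace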