diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -67,7 +67,7 @@ /// Run the post analysis step. This function may modify the IR, but must keep /// `aliasInfo` (inside `state`) consistent. Newly created operations and /// operations that should be re-analyzed must be stored in `newOps`. - virtual LogicalResult run(FuncOp funcOp, BufferizationState &state, + virtual LogicalResult run(Operation *op, BufferizationState &state, SmallVector<Operation *> &newOps) = 0; }; @@ -296,8 +296,8 @@ /// BufferizationState keeps track of bufferization state and provides access to /// the results of the analysis. struct BufferizationState { - BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options) - : aliasInfo(moduleOp), options(options) {} + BufferizationState(Operation *op, const BufferizationOptions &options) + : aliasInfo(op), options(options) {} // BufferizationState should be passed as a reference. BufferizationState(const BufferizationState &) = delete; diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -19,10 +19,13 @@ struct BufferizationOptions; struct BufferizationState; +/// Bufferize the given op. +LogicalResult runComprehensiveBufferize(Operation *op, + const BufferizationOptions &options); + -/// Bufferize the given function. Does not bufferize the function boundary. +/// Bufferize `op` and all nested ops. Does not bufferize the function boundary.
-// TODO: This function is meant to be called from ModuleBufferize and not can -// not yet be called standalone. -LogicalResult runComprehensiveBufferize(FuncOp funcOp, +/// Reuses an existing BufferizationState object. +LogicalResult runComprehensiveBufferize(Operation *op, const BufferizationOptions &options, BufferizationState &state); diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h @@ -23,7 +23,7 @@ namespace linalg_ext { struct InitTensorEliminationStep : public PostAnalysisStep { - /// Try to eliminate InitTensorOps inside `funcOp`. + /// Try to eliminate InitTensorOps inside `op`. /// /// * `rewriteFunc` generates the replacement for the InitTensorOp. /// * Only InitTensorOps that are anchored on a matching OpOperand as per @@ -34,18 +34,18 @@ /// * The result of `rewriteFunc` must usually be analyzed for inplacability. /// This analysis can be skipped with `skipAnalysis`. LogicalResult eliminateInitTensors( - FuncOp funcOp, BufferizationState &state, + Operation *op, BufferizationState &state, std::function anchorMatchFunc, std::function rewriteFunc, SmallVector &newOps); }; -/// Try to eliminate InitTensorOps inside funcOp that are anchored on an +/// Try to eliminate InitTensorOps inside `op` that are anchored on an /// InsertSliceOp, i.e., if it is eventually inserted into another tensor /// (and some other conditions are met). 
struct InsertSliceAnchoredInitTensorEliminationStep : public InitTensorEliminationStep { - LogicalResult run(FuncOp funcOp, BufferizationState &state, + LogicalResult run(Operation *op, BufferizationState &state, SmallVector &newOps) override; }; diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h @@ -20,7 +20,7 @@ namespace tensor_ext { struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep { - LogicalResult run(FuncOp funcOp, BufferizationState &state, + LogicalResult run(Operation *op, BufferizationState &state, SmallVector &newOps) override; }; diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -56,11 +56,11 @@ /// This pass implements a cross-dialect bufferization approach and performs an /// analysis to determine which op operands and results may be bufferized in the -/// same buffers. The analysis is performed on topologically sorted CallOp and -/// FuncOp within a module. It provides analyses and bufferization across -/// function boundaries. Within a function boundary, the analysis is performed -/// on SSA use-def chains starting from function operands that are annotated -/// with the 'inplaceable' attribute. +/// same buffers. +std::unique_ptr createLinalgComprehensiveFunctionBufferizePass(); + +/// This pass is an extension of the comprehensive function bufferization. It +/// performs a function call analysis and bufferizes an entire module. 
std::unique_ptr createLinalgComprehensiveModuleBufferizePass(); /// Create a pass to convert Linalg operations which work on tensors to use diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -22,17 +22,43 @@ let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"]; } +def LinalgComprehensiveFunctionBufferize : + FunctionPass<"linalg-comprehensive-function-bufferize"> { + let summary = "Bufferize (tensor into memref) for a function."; + let description = [{ + This pass implements a cross-dialect bufferization approach and performs an + analysis to determine which op operands and results may be bufferized in the + same buffers. The analysis is performed on SSA use-def chains starting from + function operands that are annotated with the 'inplaceable' attribute. + }]; + let options = [ + Option<"testAnalysisOnly", "test-analysis-only", "bool", + /*default=*/"false", + "Only runs inplaceability analysis (for testing purposes only)">, + Option<"allowReturnMemref", "allow-return-memref", "bool", + /*default=*/"false", + "Allows the return of memrefs (for testing purposes only)">, + Option<"allowUnknownOps", "allow-unknown-ops", "bool", + /*default=*/"false", + "Allows unknown (not bufferizable) ops in the input IR.">, + Option<"useAlloca", "use-alloca", "bool", + /*default=*/"false", + "Use stack allocations for memrefs (for testing purposes only)">, + Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned", + /*default=*/"0", + "Analyze ops in random order with a given seed (fuzzer)"> + ]; + let constructor = "mlir::createLinalgComprehensiveFunctionBufferizePass()"; +} + def LinalgComprehensiveModuleBufferize : Pass<"linalg-comprehensive-module-bufferize", "ModuleOp"> { let summary = "Bufferize (tensor into memref) for a Module."; let description = [{ - This pass implements a cross-dialect 
bufferization approach and performs an - analysis to determine which op operands and results may be bufferized in the - same buffers. The analysis is performed on topologically sorted CallOp and + This pass is an extension of linalg-comprehensive-function-bufferize. It + performs an inter-function analysis on topologically sorted CallOp and FuncOp within a module. It provides analyses and bufferization across - function boundaries. Within a function boundary, the analysis is performed - on SSA use-def chains starting from function operands that are annotated - with the 'inplaceable' attribute. + function boundaries. }]; let options = [ Option<"testAnalysisOnly", "test-analysis-only", "bool", diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -465,9 +465,11 @@ auto isaTensor = [](Type t) { return t.isa(); }; bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); + bool isFuncOp = isa(op); // No tensor results or operands: Simply bufferize all nested ops. - if (!hasTensorResult && !hasTensorOperand) { + // Always pass through FuncOps. + if (!hasTensorResult && !hasTensorOperand && !isFuncOp) { for (Region ®ion : op->getRegions()) if (failed(bufferize(®ion, state))) return failure(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -624,10 +624,10 @@ /// Assert that the current bufferization decisions are consistent. 
static LogicalResult -checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo, +checkAliasInfoConsistency(Operation *root, const DominanceInfo &domInfo, const BufferizationAliasInfo &aliasInfo) { Operation *inconsistentOp = nullptr; - WalkResult walkResult = funcOp.walk([&](Operation *op) { + WalkResult walkResult = root->walk([&](Operation *op) { if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa<TensorType>()) { @@ -669,21 +669,27 @@ } LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( - FuncOp funcOp, const BufferizationOptions &options, + Operation *op, const BufferizationOptions &options) { + BufferizationState state(op, options); + return runComprehensiveBufferize(op, options, state); +} + +LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( + Operation *op, const BufferizationOptions &options, BufferizationState &state) { - DominanceInfo domInfo(funcOp); + DominanceInfo domInfo(op); BufferizationAliasInfo &aliasInfo = state.aliasInfo; - if (funcOp.body().empty()) - return success(); + if (isa<FuncOp>(op) && cast<FuncOp>(op).body().empty()) + return success(); // External function: nothing to analyze or bufferize. - if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo))) + if (failed(checkAliasInfoConsistency(op, domInfo, aliasInfo))) return failure(); // Collect ops so we can build our own reverse traversal. SmallVector<Operation *> ops; - funcOp.walk([&](Operation *op) { + op->walk([&](Operation *op) { // No tensors => no buffers. if (none_of(op->getOperandTypes(), isaTensor) && none_of(op->getResultTypes(), isaTensor)) @@ -699,7 +705,7 @@ for (const std::unique_ptr<PostAnalysisStep> &step : options.postAnalysisSteps) { SmallVector<Operation *> newOps; - if (failed(step->run(funcOp, state, newOps))) + if (failed(step->run(op, state, newOps))) return failure(); // Analyze ops that were created by the PostAnalysisStep.
if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo))) @@ -708,16 +714,12 @@ // Annotate operations if we only want to report the analysis. if (options.testAnalysisOnly) { - annotateOpsWithBufferizationMarkers(funcOp, aliasInfo); + annotateOpsWithBufferizationMarkers(op, aliasInfo); return success(); } - // Bufferize all ops in funcOp. - OpBuilder b(funcOp.getContext()); - auto bufferizableOp = - dyn_cast(funcOp.getOperation()); - assert(bufferizableOp && "must use ModuleBufferization"); - if (failed(bufferizableOp.bufferize(b, state))) + // Bufferize op and all nested ops. + if (failed(bufferize(op, state))) return failure(); // Erase all obsolete ops. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -407,14 +407,14 @@ /// OpOperand, that eventually ends at a single InitTensorOp. LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: InitTensorEliminationStep::eliminateInitTensors( - FuncOp funcOp, BufferizationState &state, + Operation *op, BufferizationState &state, std::function anchorMatchFunc, std::function rewriteFunc, SmallVector &newOps) { - OpBuilder b(funcOp->getContext()); + OpBuilder b(op->getContext()); BufferizationAliasInfo &aliasInfo = state.aliasInfo; - WalkResult status = funcOp->walk([&](Operation *op) { + WalkResult status = op->walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { // Is this a matching OpOperand? if (!anchorMatchFunc(operand)) @@ -501,10 +501,10 @@ /// out-of-place due to RaW conflicts. 
LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: InsertSliceAnchoredInitTensorEliminationStep::run( - FuncOp funcOp, BufferizationState &state, + Operation *op, BufferizationState &state, SmallVector &newOps) { return eliminateInitTensors( - funcOp, state, + op, state, [&](OpOperand &operand) { auto insertSliceOp = dyn_cast(operand.getOwner()); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -435,10 +435,10 @@ } // namespace mlir LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext:: - InplaceInsertSliceOpAnalysis::run(FuncOp funcOp, BufferizationState &state, + InplaceInsertSliceOpAnalysis::run(Operation *op, BufferizationState &state, SmallVector &newOps) { auto &tensorState = getTensorBufferizationState(state); - funcOp.walk([&](InsertSliceOp insertSliceOp) { + op->walk([&](InsertSliceOp insertSliceOp) { // A copy of the source buffer is needed if either: // - The producer of `source` is not inplace. This is the case where a // slice is computed out of place into the inplace full tensor. 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -29,39 +29,18 @@ using namespace mlir::linalg; using namespace mlir::linalg::comprehensive_bufferize; -namespace { -struct LinalgComprehensiveModuleBufferize - : public LinalgComprehensiveModuleBufferizeBase< - LinalgComprehensiveModuleBufferize> { - LinalgComprehensiveModuleBufferize() {} - - LinalgComprehensiveModuleBufferize( - const LinalgComprehensiveModuleBufferize &p) {} - - void runOnOperation() override; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - affine_ext::registerBufferizableOpInterfaceExternalModels(registry); - arith_ext::registerBufferizableOpInterfaceExternalModels(registry); - bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry); - linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); - scf_ext::registerBufferizableOpInterfaceExternalModels(registry); - std_ext::registerBufferizableOpInterfaceExternalModels(registry); - tensor_ext::registerBufferizableOpInterfaceExternalModels(registry); - vector_ext::registerBufferizableOpInterfaceExternalModels(registry); - } -}; -} // end namespace - -static void applyEnablingTransformations(ModuleOp moduleOp) { - RewritePatternSet patterns(moduleOp.getContext()); - patterns.add(moduleOp.getContext()); - (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); +static void registerDependentDialects(DialectRegistry ®istry) { + registry.insert(); + affine_ext::registerBufferizableOpInterfaceExternalModels(registry); + arith_ext::registerBufferizableOpInterfaceExternalModels(registry); + bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry); + 
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); + scf_ext::registerBufferizableOpInterfaceExternalModels(registry); + tensor_ext::registerBufferizableOpInterfaceExternalModels(registry); + vector_ext::registerBufferizableOpInterfaceExternalModels(registry); } static Optional allocationFnUsingAlloca(OpBuilder &b, Location loc, @@ -72,8 +51,7 @@ return allocated; } -void LinalgComprehensiveModuleBufferize::runOnOperation() { - BufferizationOptions options; +static void prepareOptions(BufferizationOptions &options, bool useAlloca) { if (useAlloca) { options.allocationFns->allocationFn = allocationFnUsingAlloca; options.allocationFns->deallocationFn = [](OpBuilder &b, Location loc, @@ -85,17 +63,68 @@ b.create(loc, from, to); }; - options.allowReturnMemref = allowReturnMemref; - options.allowUnknownOps = allowUnknownOps; - options.analysisFuzzerSeed = analysisFuzzerSeed; - options.testAnalysisOnly = testAnalysisOnly; - // Enable InitTensorOp elimination. options.addPostAnalysisStep< linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>(); // TODO: Find a way to enable this step automatically when bufferizing tensor // dialect ops. 
options.addPostAnalysisStep(); +} + +static void populateCleanupPipeline(OpPassManager &cleanupPipeline) { + cleanupPipeline.addPass(createCanonicalizerPass()); + cleanupPipeline.addPass(createCSEPass()); + cleanupPipeline.addPass(createLoopInvariantCodeMotionPass()); +} + +namespace { +struct LinalgComprehensiveModuleBufferize + : public LinalgComprehensiveModuleBufferizeBase< + LinalgComprehensiveModuleBufferize> { + LinalgComprehensiveModuleBufferize() {} + + LinalgComprehensiveModuleBufferize( + const LinalgComprehensiveModuleBufferize &p) {} + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registerDependentDialects(registry); + registry.insert(); + std_ext::registerBufferizableOpInterfaceExternalModels(registry); + } +}; + +struct LinalgComprehensiveFunctionBufferize + : public LinalgComprehensiveFunctionBufferizeBase< + LinalgComprehensiveFunctionBufferize> { + LinalgComprehensiveFunctionBufferize() {} + + LinalgComprehensiveFunctionBufferize( + const LinalgComprehensiveFunctionBufferize &p) {} + + void runOnFunction() final; + + void getDependentDialects(DialectRegistry ®istry) const override { + registerDependentDialects(registry); + } +}; +} // end namespace + +static void applyEnablingTransformations(Operation *op) { + RewritePatternSet patterns(op->getContext()); + patterns.add(op->getContext()); + (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); +} + +void LinalgComprehensiveModuleBufferize::runOnOperation() { + BufferizationOptions options; + prepareOptions(options, useAlloca); + + options.allowReturnMemref = allowReturnMemref; + options.allowUnknownOps = allowUnknownOps; + options.analysisFuzzerSeed = analysisFuzzerSeed; + options.testAnalysisOnly = testAnalysisOnly; ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); @@ -109,12 +138,38 @@ return; OpPassManager cleanupPipeline("builtin.module"); - cleanupPipeline.addPass(createCanonicalizerPass()); - 
cleanupPipeline.addPass(createCSEPass()); - cleanupPipeline.addPass(createLoopInvariantCodeMotionPass()); + populateCleanupPipeline(cleanupPipeline); (void)runPipeline(cleanupPipeline, moduleOp); } +void LinalgComprehensiveFunctionBufferize::runOnFunction() { + BufferizationOptions options; + prepareOptions(options, useAlloca); + + options.allowReturnMemref = allowReturnMemref; + options.allowUnknownOps = allowUnknownOps; + options.analysisFuzzerSeed = analysisFuzzerSeed; + options.testAnalysisOnly = testAnalysisOnly; + + applyEnablingTransformations(getOperation()); + + if (failed(runComprehensiveBufferize(getOperation(), options))) { + signalPassFailure(); + return; + } + + if (options.testAnalysisOnly) + return; + + OpPassManager cleanupPipeline("builtin.func"); + populateCleanupPipeline(cleanupPipeline); + (void)runPipeline(cleanupPipeline, getOperation()); +} + std::unique_ptr mlir::createLinalgComprehensiveModuleBufferizePass() { return std::make_unique(); } + +std::unique_ptr mlir::createLinalgComprehensiveFunctionBufferizePass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s + +// Run fuzzer with different seeds. 
+// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null +// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null +// RUN: mlir-opt %s -linalg-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null + +// CHECK-LABEL: func @use_tensor_func_arg( +// CHECK-SAME: %[[A:.*]]: tensor<?xf32> +func @use_tensor_func_arg(%A : tensor<?xf32>) -> (vector<4xf32>) { + %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f32 + + // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]] + // CHECK: %[[res:.*]] = vector.transfer_read %[[A_memref]] + %0 = vector.transfer_read %A[%c0], %f0 : tensor<?xf32>, vector<4xf32> + + // CHECK: return %[[res]] + return %0 : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: func @return_tensor( +// CHECK-SAME: %[[A:.*]]: tensor<?xf32> +func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) { + %c0 = arith.constant 0 : index + + // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]] + // CHECK: %[[dim:.*]] = tensor.dim %[[A]] + // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) + // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] + // CHECK: linalg.copy(%[[A_memref]], %[[alloc]]) + // CHECK: vector.transfer_write %{{.*}}, %[[alloc]] + %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32> + + // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]] + // CHECK: return %[[res_tensor]] + return %0 : tensor<?xf32> +}