diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -323,6 +323,10 @@
   /// read by themselves (e.g., ExtractSliceOp).
   bool isValueRead(Value value);
 
+  /// Return true if the buffer of the given tensor value is written to. May be
+  /// called only for values inside fully analyzed functions.
+  bool isValueWritten(Value value);
+
   /// Starting from `value`, follow the use-def chain in reverse, always
   /// selecting the aliasing OpOperands. Find and return Values for which
   /// `condition` evaluates to true. OpOperands of such matching Values are not
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -256,6 +256,7 @@
 /// themselves (e.g., ExtractSliceOp).
 bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
     Value value) {
+  assert(value.getType().isa<TensorType>() && "expected TensorType");
   SmallVector<OpOperand *> workingSet;
   for (OpOperand &use : value.getUses())
     workingSet.push_back(&use);
@@ -273,6 +274,29 @@
   return false;
 }
 
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueWritten(
+    Value value) {
+  assert(value.getType().isa<TensorType>() && "expected TensorType");
+  SmallVector<OpOperand *> workingSet;
+  for (OpOperand &use : value.getUses())
+    workingSet.push_back(&use);
+
+  while (!workingSet.empty()) {
+    OpOperand *uMaybeWriting = workingSet.pop_back_val();
+    if (!isInPlace(*uMaybeWriting))
+      continue;
+    if (bufferizesToMemoryWrite(*uMaybeWriting))
+      return true;
+    OpResult opResult = getAliasingOpResult(*uMaybeWriting);
+    if (!opResult)
+      continue;
+    for (OpOperand &use : getAliasingOpResult(*uMaybeWriting).getUses())
+      workingSet.push_back(&use);
+  }
+
+  return false;
+}
+
 // Starting from `value`, follow the use-def chain in reverse, always selecting
 // the aliasing OpOperands. Find and return Values for which `condition`
 // evaluates to true. OpOperands of such matching Values are not traversed any
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -6,87 +6,68 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Module bufferization is an extension of Comprehensive Bufferize that
+// Module Bufferization is an extension of Comprehensive Bufferize that
 // bufferizes function boundaries. It provides `BufferizableOpInterface`
-// implementations for FuncOp, CallOp and ReturnOp, along with a few helper
-// functions that control the order in which functions are bufferized.
+// implementations for FuncOp, CallOp and ReturnOp.
 //
-// Three cases can occur during bufferization of FuncOps.
+// Module Bufferization is run via `runComprehensiveBufferize(ModuleOp, ...)`.
+// This function analyzes the given module and determines the order of
+// analysis and bufferization: Functions that are called are processed before
+// their respective callers.
 //
-// i. inplaceable function arguments may be reused in place after the
-//    function itself has been bufferized. This is encoded by IR resembling:
+// After analyzing a FuncOp, additional information about its bbArgs is
+// gathered through PostAnalysisSteps and stored in `ModuleBufferizationState`.
 //
-// ```
-//   #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//   func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-//       -> tensor<?xf32> {
-//     %0 = bufferization.to_memref %A : memref<?xf32, #map>
-//     // ... uses of %0
-//     %res = bufferization.to_tensor %0 : memref<?xf32, #map>
-//     return %res : tensor<?xf32>
-//   }
-// ```
+// * `EquivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
+//   tensor return value (if any).
+// * `FuncOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
+//   read/written.
 //
-//   this is the cue for the bufferization of the function foo (and calls
-//   to it) may bufferize to `func @foo(%A: memref<?xf32>)`.
-//   To fully achieve bufferization, an additional analysis is needed to
-//   determine whether function argument/operand pairs bufferize to a
-//   single inplace buffer argument (i.e. functions may return tensors in
-//   arbitrary order that may not match argument numbers).
+// Only tensors that are equivalent to some FuncOp bbArg may be returned.
+// Bufferization currently fails if other tensors (in particular tensors that
+// bufferize out-of-place and result in a new buffer allocation) are returned.
+// In the future, such allocations could be hoisted to the caller.
 //
-// ii. results that don't map to an inplaceable function argument are
-//     generally allocated. Since memref semantics wrt ownership of the
-//     underlying memory region are not well-defined, comprehensive
-//     bufferization chooses to perform allocations in a scoped fashion:
-//     returning memrefs is always considered illegal.
-//     Such scenarios are encoded by IR resembling:
+// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg.
+// ```
+// func @foo() -> tensor<?xf32> {
+//   %0 = linalg.init_tensor [...] : tensor<?xf32>
+//   return %0 : tensor<?xf32>
+// }
+// ```
 //
-// ```
-//   #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//   func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-//       -> tensor<?xf32> {
-//     %0 = bufferization.to_memref %A : memref<?xf32, #map>
-//     %1 = memref.dim %0, %c0 : memref<?xf32, #map>
-//     %2 = memref.alloc(%1) : memref<?xf32>
-//     %3 = memref.cast %2 : memref<?xf32> to memref<?xf32, #map>
-//     // ... uses of %3
-//     memref.dealloc %2 : memref<?xf32>
-//     %res = bufferization.to_tensor %3 : memref<?xf32, #map>
-//     return %res : tensor<?xf32>
-//   }
-// ```
+// Module Bufferization implements the following calling convention.
 //
-//   this is the cue for the bufferization of the function foo (and calls
-//   to it) that it must bufferize to `func @foo(%A: memref<?xf32>,
-//   %B: memref<?xf32>)` (i.e. make a cloned
-//   allocation of the result tensor)
-//   To fully achieve bufferization, the alloc/dealloc pair must be lifted
-//   out of the function at each call site.
+// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
+//   be written to in-place.
+// * If a tensor operand of a CallOp is read after the CallOp, the operand of
+//   the CallOp must bufferize out-of-place.
 //
-// iii. as an optimization over ii., it may be possible to reuse an argument
-//      and only want to return a slice.
-//      This may forego allocation by letting *all* callers decide whether to
-//      pass a new *aliasing* memref function argument (i.e. a subview).
-//      Without loss of generality, callers may agree to allocate a new buffer
-//      to avoid this aliasing. Such scenarios are encoded by IR resembling:
+// Example: The tensor.insert op bufferizes in-place because it is allowed to
+// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
+// out-of-place because `%t0` is modified by the callee but read by the
+// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
+// bufferize out-of-place based on results of `FuncOpBbArgReadWriteAnalysis`.
+// ```
+// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
+//   %f = ... : f32
+//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
+//   return %0 : tensor<?xf32>
+// }
 //
-// ```
-//   #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//   func @foo(%arg0: tensor<?xf32> {linalg.inplaceable = true})
-//       -> tensor<4xf32> {
-//     %0 = bufferization.to_memref %arg0 : memref<?xf32, #map>
-//     %1 = memref.subview %0[0] [4] [1] : memref<?xf32, #map> to
-//          memref<4xf32, #map>
-//     // ... inplace computes into %1
-//     %3 = bufferization.to_tensor %1 : memref<4xf32, #map>
-//     return %3 : tensor<4xf32>
-//   }
-// ```
+// func @caller() -> () {
+//   %t0 = ... : tensor<?xf32>
+//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
+//   %2 = tensor.extract %1[...] : tensor<?xf32>
+// }
+// ```
 //
-// Note: In the future, it may be worthwhile to design special bufferization
-// ops to encode the desired semantics at function boundaries for i., ii. and
-// iii.
+// Note: If a function is external, `FuncOpBbArgReadWriteAnalysis` cannot
+// analyze the function body. In such a case, the CallOp analysis conservatively
+// assumes that each tensor OpOperand is both read and written.
+//
+// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
+// as "not reading" and/or "not writing".
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
 
@@ -113,8 +94,19 @@
   /// indices.
   DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
 
+  /// A set of all read BlockArguments of FuncOps.
+  DenseSet<BlockArgument> readBbArgs;
+
+  /// A set of all written-to BlockArguments of FuncOps.
+  DenseSet<BlockArgument> writtenBbArgs;
+
+  /// A set of all fully analyzed FuncOps.
+  DenseSet<FuncOp> analyzedFuncOps;
+
+  // A list of functions in the order in which they are analyzed + bufferized.
   SmallVector<FuncOp> orderedFuncOps;
 
+  // A mapping of FuncOps to their callers.
   DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
 };
 } // namespace
@@ -189,6 +181,42 @@
     return success();
   }
 };
+
+/// Determine which FuncOp bbArgs are read and which are written. If this
+/// PostAnalysisStep is run on a function with unknown ops, it will
+/// conservatively assume that such ops bufferize to a read + write.
+struct FuncOpBbArgReadWriteAnalysis : public PostAnalysisStep {
+  LogicalResult run(Operation *op, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
+                    SmallVector<Operation *> &newOps) override {
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+
+    // Support only single return-terminated block in the function.
+    auto funcOp = cast<FuncOp>(op);
+
+    // If the function has no body, conservatively assume that all args are
+    // read + written.
+    if (funcOp.getBody().empty()) {
+      for (BlockArgument bbArg : funcOp.getArguments()) {
+        moduleState.readBbArgs.insert(bbArg);
+        moduleState.writtenBbArgs.insert(bbArg);
+      }
+
+      return success();
+    }
+
+    for (BlockArgument bbArg : funcOp.getArguments()) {
+      if (!bbArg.getType().isa<TensorType>())
+        continue;
+      if (state.isValueRead(bbArg))
+        moduleState.readBbArgs.insert(bbArg);
+      if (state.isValueWritten(bbArg))
+        moduleState.writtenBbArgs.insert(bbArg);
+    }
+
+    return success();
+  }
+};
 } // namespace
 
 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
@@ -588,26 +616,81 @@
     : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               BufferizationState &state) const {
-    // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
-    // of the matching bbArg may. It is the responsibility of the caller to
-    // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
-    // conservative.
-    return true;
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+    if (!moduleState.analyzedFuncOps.contains(funcOp))
+      // FuncOp not analyzed yet. Assume that OpOperand is read.
+      return true;
+
+    return moduleState.readBbArgs.contains(
+        funcOp.getArgument(opOperand.getOperandNumber()));
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                BufferizationState &state) const {
-    return false;
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+    if (!moduleState.analyzedFuncOps.contains(funcOp))
+      // FuncOp not analyzed yet. Assume that OpOperand is written.
+      return true;
+
+    return moduleState.writtenBbArgs.contains(
+        funcOp.getArgument(opOperand.getOperandNumber()));
   }
 
   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                BufferizationState &state) const {
-    // CallOpInterface is special, it needs to wait for the callee to be
-    // bufferized and needs to inspect the BufferAliasInfo object. It can't
-    // make a proper determination by itself and needs to be conservative.
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+    if (!moduleState.equivalentFuncArgs.count(funcOp))
+      // No equivalence info available for the called function.
+      return OpResult();
+
+    for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults();
+         ++resultIdx)
+      if (moduleState.equivalentFuncArgs[funcOp][resultIdx] ==
+          opOperand.getOperandNumber())
+        return callOp->getOpResult(resultIdx);
+
+    // Note: Returning a non-equivalent tensor from a FuncOp is currently not
+    // supported and will fail bufferization.
     return OpResult();
   }
 
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       BufferizationState &state) const {
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+    if (!moduleState.equivalentFuncArgs.count(funcOp))
+      // No equivalence info available for the called function.
+      return {};
+
+    // Note: Returning a non-equivalent tensor from a FuncOp is currently not
+    // supported and will fail bufferization.
+    int64_t bbArgIdx =
+        moduleState.equivalentFuncArgs[funcOp][opResult.getResultNumber()];
+    return {&op->getOpOperand(bbArgIdx)};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo,
+                                BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
   /// In a first approximation, all the function arguments of a FuncOp are
   /// marked inplaceable. For now, it is the responsibility of the `callOp`
   /// bufferization to allow FuncOp that are inplaceable to write inPlace.
@@ -662,11 +745,12 @@
             getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
       // Return operands that are equivalent to some bbArg, are not
       // returned.
-      Value buffer =
-          *state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx),
-                           /*forceInPlace=*/true);
-      replacementValues[returnValIdx] = buffer;
-      newOperands[*bbArgIdx] = buffer;
+      FailureOr<Value> bufferOrFailure =
+          state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
+      if (failed(bufferOrFailure))
+        return failure();
+      replacementValues[returnValIdx] = *bufferOrFailure;
+      newOperands[*bbArgIdx] = *bufferOrFailure;
       continue;
     }
 
@@ -696,11 +780,15 @@
     // Retrieve buffers for tensor operands. Tensor operand buffers, who's
    // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
     // already stored in `newOperands` during Step 1.
-    Value buffer = newOperands[idx] ? newOperands[idx]
-                                    : *state.getBuffer(rewriter, opOperand,
-                                                       /*forceInPlace=*/true);
+    Value buffer = newOperands[idx];
+    if (!buffer) {
+      FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
+      if (failed(bufferOrFailure))
+        return failure();
+      buffer = *bufferOrFailure;
+    }
 
-    // Caller / callee type mistmatch is handled with a CastOp.
+    // Caller / callee type mismatch is handled with a CastOp.
     auto memRefType = bufferizedFuncType.getInput(idx);
     // Since we don't yet have a clear layout story, to_memref may
     // conservatively turn tensors into more dynamic memref than necessary.
@@ -774,7 +862,6 @@
     auto funcOp = cast<FuncOp>(op);
     BlockArgument bbArg = value.dyn_cast<BlockArgument>();
     assert(bbArg && "expected BlockArgument");
-    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
 
     // "linalg.inplaceable" overrides other writability decisions. This is
     // currently used for testing only.
@@ -783,16 +870,8 @@
                 BufferizableOpInterface::kInplaceableAttrName))
       return inplaceAttr.getValue();
 
-    // In a first approximation:
-    // =========================
-    // If the function is called, we can allocate on the caller side which lets
-    // us force inplace arguments at function boundaries.
-    // TODO: do not rely on this behavior.
-    if (moduleState.callerMap.find(funcOp) != moduleState.callerMap.end())
-      return true;
-
-    // All other function arguments are not writable.
-    return false;
+    // All function arguments are writable by default.
+    return true;
   }
 
   bool isAllocationHoistingBarrier(Operation *op) const { return true; }
@@ -854,6 +933,8 @@
     // Collect bbArg/return value information after the analysis.
     options->postAnalysisSteps.emplace_back(
         std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
+    options->postAnalysisSteps.emplace_back(
+        std::make_unique<FuncOpBbArgReadWriteAnalysis>());
 
     // Gather equivalence info for CallOps.
     equivalenceAnalysis(funcOp, aliasInfo, moduleState);
@@ -862,6 +943,9 @@
     if (failed(analyzeOp(funcOp, *options, state, aliasInfo)))
      return failure();
 
+    // Mark op as fully analyzed.
+    moduleState.analyzedFuncOps.insert(funcOp);
+
     // Add annotations to function arguments.
     if (options->testAnalysisOnly)
       annotateOpsWithBufferizationMarkers(funcOp, state);
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -630,7 +630,7 @@
   // of %r1 is read.
   // CHECK: scf.for
   // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
   // CHECK-NEXT: scf.yield
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
   // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "false"]}
@@ -642,7 +642,7 @@
   // %r1 bufferizes inplace fine.
   // CHECK: scf.for
   // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
   // CHECK-NEXT: scf.yield
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
   // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]}
@@ -655,7 +655,7 @@
   // of %r3 is read.
   // CHECK: linalg.tiled_loop
   // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
   // CHECK-NEXT: linalg.yield
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
   // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "false"]}
@@ -669,7 +669,7 @@
   // %r3 bufferizes inplace fine.
   // CHECK: linalg.tiled_loop
   // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
   // CHECK-NEXT: linalg.yield
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
   // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -410,7 +410,9 @@
 // CHECK: %[[A:.*]] = memref.get_global @__constant_4xi32 : memref<4xi32>
   %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
 
-// CHECK: %[[B:.*]] = memref.cast %[[A]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: linalg.copy(%[[A]], %[[alloc]])
 // CHECK: call @some_external_func(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
   call @some_external_func(%A) : (tensor<4xi32>) -> ()
 
@@ -430,7 +432,9 @@
 // CHECK: %[[A:.*]] = memref.get_global @__constant_4xi32 : memref<4xi32>
   %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
 
-// CHECK: %[[B:.*]] = memref.cast %[[A]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: linalg.copy(%[[A]], %[[alloc]])
 // CHECK: call @some_external_func_within_scf_execute(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
   scf.execute_region {
     call @some_external_func_within_scf_execute(%A) : (tensor<4xi32>) -> ()
@@ -488,16 +492,19 @@
                           %lb : index, %ub : index, %step : index)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
-// CHECK-NEXT: call @scf_for_with_tensor_insert_slice(%[[A]], %[[B]], %[[C]]
+// CHECK: call @scf_for_with_tensor_insert_slice(%[[A]], %[[B]], %[[C]]
   %r0:2 = call
     @scf_for_with_tensor_insert_slice(%A, %B, %C, %lb, %ub, %step)
       : (tensor<?xf32>, tensor<?xf32>, tensor<4xf32>, index, index, index)
         -> (tensor<?xf32>, tensor<?xf32>)
 
-  // %r0#0 is actually %B after inplaceable results are swapped in the callee.
-// CHECK-NEXT: call @some_external_func(%[[B]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
+  // %r0#0 requires a copy because we have no idea what the function is doing.
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+// CHECK: linalg.copy(%[[B]], %[[alloc]])
+// CHECK-NEXT: call @some_external_func(%[[casted]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
   call @some_external_func(%r0#0) : (tensor<?xf32>) -> ()
 
-// CHECK-NEXT: return
+// CHECK: return
   return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -710,8 +717,16 @@
 func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false},
             %B : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false},
             %C : tensor<?xf32> {linalg.inplaceable = false}) {
-// CHECK-NEXT: %[[CASTED_B:.*]] = memref.cast %[[B]] : memref<?xf32> to memref<?xf32, #[[$DYN_1D_MAP]]>
-// CHECK-NEXT: call @callee(%[[A]], %[[CASTED_B]], %[[C]])
+// CHECK: %[[ALLOC_C:.*]] = memref.alloc
+// CHECK: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]]
+// CHECK: %[[ALLOC_B:.*]] = memref.alloc
+// CHECK: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
+// CHECK: %[[ALLOC_A:.*]] = memref.alloc
+// CHECK: linalg.copy(%[[A]], %[[ALLOC_A]])
+// CHECK: linalg.copy(%[[B]], %[[ALLOC_B]])
+// CHECK: linalg.copy(%[[C]], %[[ALLOC_C]])
+// CHECK: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
+// CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]])
   call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
   return
 }
@@ -1022,7 +1037,10 @@
   %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
     // TODO: There should be a memory copy here. This is a bug in CallOp
     // bufferization.
-    // CHECK: call @inner_func_2(%[[arg0]])
+    // CHECK: %[[alloc:.*]] = memref.alloc
+    // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+    // CHECK: linalg.copy(%[[arg0]], %[[alloc]])
+    // CHECK: call @inner_func_2(%[[casted]])
     %3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
     scf.yield %t1 : tensor<?xf32>
   }