diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -258,7 +258,8 @@ // TODO: This does not handle cyclic function call graphs etc. static void equivalenceAnalysis(func::FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - FuncAnalysisState &funcState) { + OneShotAnalysisState &state) { + FuncAnalysisState &funcState = getFuncAnalysisState(state); funcOp->walk([&](func::CallOp callOp) { func::FuncOp calledFunction = getCalledFunction(callOp); assert(calledFunction && "could not retrieved called func::FuncOp"); @@ -270,6 +271,8 @@ for (auto it : funcState.equivalentFuncArgs[calledFunction]) { int64_t returnIdx = it.first; int64_t bbargIdx = it.second; + if (!state.isInPlace(callOp->getOpOperand(bbargIdx))) + continue; Value returnVal = callOp.getResult(returnIdx); Value argVal = callOp->getOperand(bbargIdx); aliasInfo.unionEquivalenceClasses(returnVal, argVal); @@ -409,7 +412,7 @@ funcState.startFunctionAnalysis(funcOp); // Gather equivalence info for CallOps. - equivalenceAnalysis(funcOp, aliasInfo, funcState); + equivalenceAnalysis(funcOp, aliasInfo, state); // Analyze funcOp. if (failed(analyzeOp(funcOp, state))) diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -196,8 +196,9 @@ // CHECK: %[[call:.*]] = call @inner_func(%[[casted]]) %0, %1 = call @inner_func(%t0) : (tensor) -> (tensor, f32) - // Note: The tensor return value has folded away. 
- // CHECK: return %[[call]] : f32 + // Note: The tensor return value cannot be folded away because the + // CallOp was bufferized out-of-place. + // CHECK: return %[[call]], %[[alloc]] : f32, memref return %1, %0 : f32, tensor }