diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -490,6 +490,19 @@
 namespace comprehensive_bufferize {
 namespace std_ext {
 
+/// Return the index of the parent function's bbArg that is equivalent to the
+/// given ReturnOp operand (if any).
+static Optional<int64_t>
+getEquivalentFuncArgIdx(ModuleBufferizationState &state,
+                        OpOperand &returnOperand) {
+  FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
+  if (!state.equivalentFuncArgs[funcOp].count(
+          returnOperand.getOperandNumber()))
+    // Return value has no equivalent bbArg.
+    return None;
+
+  return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
+}
+
 struct CallOpInterface
     : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
@@ -515,57 +528,67 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     CallOp callOp = cast<CallOp>(op);
+    unsigned numResults = callOp.getNumResults();
     FuncOp funcOp = getCalledFunction(callOp);
     assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
-           "expected Callop to a FuncOp");
+           "expected CallOp to a FuncOp");
     ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
 
-    // 1. Filter return types:
-    //   - if the callee is bodiless / external, we cannot inspect it and we
-    //     cannot assume anything. We can just assert that it does not return a
-    //     tensor as this would have to bufferize to "return a memref", whose
-    //     semantics is ill-defined.
-    //   - if the callee has a body, we perform inter-procedural equivalence
-    //     analysis. When successful, a result folds onto an operand. When
-    //     unsuccessful, additional work is needed (TODO) to either:
-    //       * hoist a result into an inplaceable operand or
-    //       * devise a better representation to truly return a buffer.
+    // Result types of the bufferized CallOp.
     SmallVector<Type> resultTypes;
+    // Replacement values for the existing CallOp. These are usually the
+    // results of the bufferized CallOp, unless a tensor result folds onto an
+    // operand.
+    SmallVector<Value> replacementValues(numResults, Value());
+    // For non-tensor results: A mapping from return val indices of the old
+    // CallOp to return val indices of the bufferized CallOp.
+    SmallVector<Optional<int64_t>> retValMapping(numResults, None);
+
     if (funcOp.body().empty()) {
-      if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
-        return callOp->emitError()
-               << "cannot bufferize bodiless function that returns a tensor";
+      // The callee is bodiless / external, so we cannot inspect it and we
+      // cannot assume anything. We can just assert that it does not return a
+      // tensor as this would have to bufferize to "return a memref", whose
+      // semantics is ill-defined.
+      for (unsigned i = 0; i < numResults; ++i) {
+        Type returnType = callOp.getResult(i).getType();
+        if (isaTensor(returnType))
+          return callOp->emitError()
+                 << "cannot bufferize bodiless function that returns a tensor";
+        resultTypes.push_back(returnType);
+        retValMapping[i] = i;
+      }
     } else {
+      // The callee has a body. Based on previously gathered equivalence
+      // information, we know if a tensor result folds onto an operand. These
+      // are the only tensor value returns that are supported at the moment.
+      //
+      // For tensor return values that do not fold onto an operand, additional
+      // work is needed (TODO) to either:
+      //   * hoist a result into an inplaceable operand or
+      //   * devise a better representation to truly return a buffer.
       ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
       assert(returnOp && "expected func with single return op");
 
       // For each FuncOp result, keep track of which inplace argument it reuses.
       for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+        unsigned returnIdx = returnOperand.getOperandNumber();
         Type returnType = returnOperand.get().getType();
         if (!isaTensor(returnType)) {
+          // Non-tensor values are returned.
+          retValMapping[returnIdx] = resultTypes.size();
           resultTypes.push_back(returnType);
           continue;
         }
 
-        // If return operand is equivalent to some bbArg, no need to return it.
-        if (moduleState.equivalentFuncArgs[funcOp].count(
-                returnOperand.getOperandNumber())) {
-          int64_t idx =
-              moduleState
-                  .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
-          Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
-          Value buffer = state.lookupBuffer(rewriter, callOp->getOperand(idx));
-          // Add a ToTensorOp to kill all uses of the CallOp return.
-          // Replace all uses of the CallOp results so we can erase the CallOp.
-          // This ToTensorOp must fold/DCE away or bufferization should be
-          // considered failed.
-          Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
-              callOp.getLoc(), buffer);
-          oldRes.replaceAllUsesWith(toTensorOp);
+        if (Optional<int64_t> bbArgIdx =
+                getEquivalentFuncArgIdx(moduleState, returnOperand)) {
+          // Return operands that are equivalent to some bbArg are not
+          // returned.
+          replacementValues[returnIdx] =
+              state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
           continue;
         }
 
-        resultTypes.push_back(returnType);
+        llvm_unreachable("returning non-equivalent tensors not supported");
       }
     }
 
@@ -612,8 +635,13 @@
         callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
     newCallOp->setAttrs(callOp->getAttrs());
 
-    // 5. Delete the op at the end of bufferization.
-    callOp->erase();
+    // 5. Replace the old op with the new op.
+    for (unsigned i = 0; i < replacementValues.size(); ++i) {
+      if (replacementValues[i])
+        continue;
+      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
+    }
+    state.replaceOp(rewriter, callOp, replacementValues);
 
     return success();
   }
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1000,6 +1000,34 @@
 
 // -----
 
+// CHECK-LABEL: func @inner_func(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
+  // CHECK-NOT: copy
+  %f = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: memref.store %{{.*}}, %[[arg0]]
+  %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+  // CHECK: %[[load:.*]] = memref.load %[[arg0]]
+  %1 = tensor.extract %0[%c1] : tensor<?xf32>
+  // CHECK: return %[[load]] : f32
+  return %0, %1 : tensor<?xf32>, f32
+}
+
+// CHECK-LABEL: func @call_func_with_non_tensor_return(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @call_func_with_non_tensor_return(
+    %t0: tensor<?xf32> {linalg.inplaceable = true}) -> (f32, tensor<?xf32>) {
+  // CHECK-NOT: copy
+  // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
+  %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
+  // CHECK: return %[[call]] : f32
+  return %1, %0 : f32, tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @func_without_tensor_args
 func @func_without_tensor_args(%v : vector<10xf32>) -> () {
   // CHECK: %[[alloc:.*]] = memref.alloc()