diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -490,17 +490,16 @@
 namespace comprehensive_bufferize {
 namespace std_ext {
 
-/// Return the index of the parent function's bbArg that is equivalent to the
-/// given ReturnOp operand (if any).
+/// Return the index of the bbArg in the given FuncOp that is equivalent to the
+/// specified return value (if any).
 static Optional<int64_t>
-getEquivalentFuncArgIdx(ModuleBufferizationState &state,
-                        OpOperand &returnOperand) {
-  FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
-  if (!state.equivalentFuncArgs[funcOp].count(returnOperand.getOperandNumber()))
+getEquivalentFuncArgIdx(FuncOp funcOp, ModuleBufferizationState &state,
+                        int64_t returnValIdx) {
+  if (!state.equivalentFuncArgs[funcOp].count(returnValIdx))
     // Return value has no equivalent bbArg.
     return None;
 
-  return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
+  return state.equivalentFuncArgs[funcOp][returnValIdx];
 }
 
 struct CallOpInterface
@@ -529,6 +528,7 @@
                           BufferizationState &state) const {
     CallOp callOp = cast<CallOp>(op);
     unsigned numResults = callOp.getNumResults();
+    unsigned numOperands = callOp->getNumOperands();
     FuncOp funcOp = getCalledFunction(callOp);
     assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
            "expected CallOp to a FuncOp");
@@ -542,54 +542,48 @@
     // For non-tensor results: A mapping from return val indices of the old
     // CallOp to return val indices of the bufferized CallOp.
     SmallVector<Optional<int64_t>> retValMapping(numResults, None);
-
-    if (funcOp.body().empty()) {
-      // The callee is bodiless / external, so we cannot inspect it and we
-      // cannot assume anything. We can just assert that it does not return a
-      // tensor as this would have to bufferize to "return a memref", whose
-      // semantics is ill-defined.
-      for (int i = 0; i < numResults; ++i) {
-        Type returnType = callOp.getResult(i).getType();
-        if (isaTensor(returnType))
-          return callOp->emitError()
-                 << "cannot bufferize bodiless function that returns a tensor";
+    // Operands of the bufferized CallOp.
+    SmallVector<Value> newOperands(numOperands, Value());
+
+    // Based on previously gathered equivalence information, we know if a
+    // tensor result folds onto an operand. These are the only tensor value
+    // results that are supported at the moment.
+    //
+    // For tensor return values that do not fold onto an operand, additional
+    // work is needed (TODO) to either:
+    // * hoist a result into an inplaceable operand or
+    // * devise a better representation to truly return a buffer.
+    //
+    // Note: If a function has no body, no equivalence information is
+    // available. Consequently, a tensor return value cannot be proven to fold
+    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
+    // the moment.
+
+    // 1. Compute the result types of the new CallOp. Tensor results that are
+    // equivalent to a FuncOp bbArg are no longer returned.
+    for (auto it : llvm::enumerate(callOp.getResultTypes())) {
+      unsigned returnValIdx = it.index();
+      Type returnType = it.value();
+      if (!isaTensor(returnType)) {
+        // Non-tensor values are returned.
+        retValMapping[returnValIdx] = resultTypes.size();
         resultTypes.push_back(returnType);
-        retValMapping[i] = i;
+        continue;
       }
-    } else {
-      // The callee has a body. Based on previously gathered equivalence
-      // information, we know if a tensor result folds onto an operand. These
-      // are the only tensor value returns that are supported at the moment.
-      //
-      // For tensors return values that do not fold onto an operand, additional
-      // work is needed (TODO) to either:
-      // * hoist a result into an inplaceable operand or
-      // * devise a better representation to truly return a buffer.
-      ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      assert(returnOp && "expected func with single return op");
-
-      // For each FuncOp result, keep track of which inplace argument it reuses.
-      for (OpOperand &returnOperand : returnOp->getOpOperands()) {
-        unsigned returnIdx = returnOperand.getOperandNumber();
-        Type returnType = returnOperand.get().getType();
-        if (!isaTensor(returnType)) {
-          // Non-tensor values are returned.
-          retValMapping[returnIdx] = resultTypes.size();
-          resultTypes.push_back(returnType);
-          continue;
-        }
-
-        if (Optional<int64_t> bbArgIdx =
-                getEquivalentFuncArgIdx(moduleState, returnOperand)) {
-          // Return operands that are equivalent to some bbArg, are not
-          // returned.
-          replacementValues[returnIdx] =
-              state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
-          continue;
-        }
-
-        llvm_unreachable("returning non-equivalent tensors not supported");
+
+      if (Optional<int64_t> bbArgIdx =
+              getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
+        // Return values that are equivalent to some bbArg are not returned.
+        Value buffer =
+            state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
+        replacementValues[returnValIdx] = buffer;
+        newOperands[*bbArgIdx] = buffer;
+        continue;
       }
+
+      return callOp->emitError(
+          "call to FuncOp that returns non-equivalent tensors not supported");
     }
 
     // 2. Compute bufferized FunctionType.
@@ -601,23 +595,26 @@
         moduleState.bufferizedFunctionTypes);
 
     // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
-    SmallVector<Value> newOperands;
-    newOperands.reserve(callOp->getNumOperands());
     for (OpOperand &opOperand : callOp->getOpOperands()) {
+      unsigned idx = opOperand.getOperandNumber();
       Value tensorOperand = opOperand.get();
+
+      // Non-tensor operands are just copied.
       if (!tensorOperand.getType().isa<TensorType>()) {
-        newOperands.push_back(tensorOperand);
+        newOperands[idx] = tensorOperand;
         continue;
      }
 
-      // Tensor operands are guaranteed to have been buferized.
-      int64_t idx = opOperand.getOperandNumber();
-      Value buffer = state.lookupBuffer(rewriter, tensorOperand);
+      // Retrieve buffers for tensor operands. Tensor operand buffers, whose
+      // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
+      // already stored in `newOperands` during Step 1.
+      Value buffer = newOperands[idx]
+                         ? newOperands[idx]
+                         : state.lookupBuffer(rewriter, tensorOperand);
 
      // Caller / callee type mistmatch is handled with a CastOp.
      auto memRefType = bufferizedFuncType.getInput(idx);
-      // Since we don't yet have a clear layout story, buffer_cast may
+      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memref than necessary.
      // If the memref type of the callee fails, introduce an extra memref.cast
      // that will either canonicalize away or fail compilation until we can do
@@ -627,20 +624,21 @@
           memRefType, buffer);
        buffer = castBuffer;
      }
-      newOperands.push_back(buffer);
+      newOperands[idx] = buffer;
    }
 
    // 4. Create the new CallOp.
    Operation *newCallOp = rewriter.create<CallOp>(
        callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
-
-    // 5. Replace the old op with the new op.
+    // Get replacement values for non-tensor / non-equivalent results.
    for (int i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }
+
+    // 5. Replace the old op with the new op.
    state.replaceOp(rewriter, callOp, replacementValues);
 
    return success();
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -187,3 +187,26 @@
   return %r1, %r2 : vector<5xf32>, vector<5xf32>
 }
+
+// -----
+
+func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+
+func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
+  // expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
+  call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+  return
+}
+
+// -----
+
+func @foo(%t : tensor<5xf32>) -> (tensor<5xf32>) {
+  %0 = linalg.init_tensor [5] : tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+func @call_to_func_returning_non_equiv_tensor(%t : tensor<5xf32>) {
+  // expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
+  call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return
+}
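
For illustration, the rewrite implemented above can be pictured on a small before/after example: when a callee's tensor result is equivalent to one of its bbArgs, Step 1 drops that result from the bufferized CallOp and the caller reuses the operand's buffer as the replacement value. The sketch below uses invented function names and omits the memref layout attributes that comprehensive bufferize would normally attach; it is an approximation of the expected output under those assumptions, not test-verified IR.

// Before bufferization: @callee returns its bbArg unchanged, so the tensor
// result is equivalent to the operand.
func @callee(%t : tensor<5xf32>) -> tensor<5xf32> {
  return %t : tensor<5xf32>
}

func @caller(%t : tensor<5xf32>) {
  %0 = call @callee(%t) : (tensor<5xf32>) -> tensor<5xf32>
  return
}

// After bufferization (sketch): the tensor result folds onto the operand, so
// the bufferized CallOp returns no value and all uses of %0 are replaced by
// the operand's buffer %m.
func @callee(%m : memref<5xf32>) {
  return
}

func @caller(%m : memref<5xf32>) {
  call @callee(%m) : (memref<5xf32>) -> ()
  return
}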