Bufferization of func.call / func.func: implement getBufferType and use it.

Summary of the changes in this diff:
* CallOpInterface::getBufferType (new): returns the callee FunctionType's
  result type at the same result number as the queried call result. Per its
  own comment, it relies on the callee having been bufferized already. The
  `options` and `fixedTypes` parameters are not used.
* CallOpInterface::bufferize: the buffer types of tensor results now come from
  bufferization::getBufferType(), and a failure is propagated. Previously they
  were read positionally from the callee signature.
  Operands are now collected with push_back instead of being pre-sized and
  indexed.
  The retValMapping / replacementValues bookkeeping is removed, and the old op
  is replaced directly with newCallOp->getResults().
* FuncOpInterface::getBufferType (new): for a function block argument, it
  delegates to getBufferizedFunctionArgType(funcOp, argNumber, options).
* FuncOpInterface::bufferize: the new bbArg type is obtained through
  bufferization::getBufferType(), and a failure is propagated.

NOTE(review): this patch text appears garbled. Newlines inside the diff have
been collapsed onto a few lines. C++ template argument lists also seem to be
missing: for example, `FailureOr`, `SmallVector` and `DenseMap` appear without
`<...>`, and `cast(op)` and `isa(returnType)` appear without a target type.
It will likely neither apply nor compile as-is; regenerate it from the source
tree before use.

NOTE(review): FuncOpInterface::getBufferType checks that the block argument
belongs to the entry block only via assert(). In NDEBUG builds, an argument
of another block would be passed to getBufferizedFunctionArgType using its
index within that other block. Consider returning failure() instead. TODO
confirm against callers.

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -197,68 +197,67 @@ return result; } + FailureOr + getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const DenseMap &fixedTypes) const { + auto callOp = cast(op); + FuncOp funcOp = getCalledFunction(callOp); + assert(funcOp && "expected CallOp to a FuncOp"); + + // The callee was already bufferized, so we can directly take the type from + // its signature. + FunctionType funcType = funcOp.getFunctionType(); + return cast( + funcType.getResult(cast(value).getResultNumber())); + } + /// All function arguments are writable. It is the responsibility of the /// CallOp to insert buffer copies where necessary. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { func::CallOp callOp = cast(op); - unsigned numResults = callOp.getNumResults(); - unsigned numOperands = callOp->getNumOperands(); - FuncOp funcOp = getCalledFunction(callOp); - assert(funcOp && "expected CallOp to a FuncOp"); - FunctionType funcType = funcOp.getFunctionType(); - - // Result types of the bufferized CallOp. - SmallVector resultTypes; - // Replacement values for the existing CallOp. These are usually the results - // of the bufferized CallOp, unless a tensor result folds onto an operand. - SmallVector replacementValues(numResults, Value()); - // For non-tensor results: A mapping from return val indices of the old - // CallOp to return val indices of the bufferized CallOp. - SmallVector> retValMapping(numResults, - std::nullopt); - // Operands of the bufferized CallOp. - SmallVector newOperands(numOperands, Value()); // 1. Compute the result types of the new CallOp.
- for (const auto &it : llvm::enumerate(callOp.getResultTypes())) { - unsigned returnValIdx = it.index(); - Type returnType = it.value(); + SmallVector resultTypes; + for (Value result : callOp.getResults()) { + Type returnType = result.getType(); if (!isa(returnType)) { // Non-tensor values are returned. - retValMapping[returnValIdx] = resultTypes.size(); resultTypes.push_back(returnType); continue; } // Returning a memref. - retValMapping[returnValIdx] = resultTypes.size(); - resultTypes.push_back(funcType.getResult(resultTypes.size())); + FailureOr resultType = + bufferization::getBufferType(result, options); + if (failed(resultType)) + return failure(); + resultTypes.push_back(*resultType); } - // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`. - for (OpOperand &opOperand : callOp->getOpOperands()) { - unsigned idx = opOperand.getOperandNumber(); - Value tensorOperand = opOperand.get(); + // 2. Rewrite tensor operands as memrefs based on type of the already + // bufferized callee. + SmallVector newOperands; + FuncOp funcOp = getCalledFunction(callOp); + assert(funcOp && "expected CallOp to a FuncOp"); + FunctionType funcType = funcOp.getFunctionType(); + for (OpOperand &opOperand : callOp->getOpOperands()) { // Non-tensor operands are just copied. - if (!isa(tensorOperand.getType())) { - newOperands[idx] = tensorOperand; + if (!isa(opOperand.get().getType())) { + newOperands.push_back(opOperand.get()); continue; } // Retrieve buffers for tensor operands. - Value buffer = newOperands[idx]; - if (!buffer) { - FailureOr maybeBuffer = - getBuffer(rewriter, opOperand.get(), options); - if (failed(maybeBuffer)) - return failure(); - buffer = *maybeBuffer; - } + FailureOr maybeBuffer = + getBuffer(rewriter, opOperand.get(), options); + if (failed(maybeBuffer)) + return failure(); + Value buffer = *maybeBuffer; // Caller / callee type mismatch is handled with a CastOp.
- auto memRefType = funcType.getInput(idx); + auto memRefType = funcType.getInput(opOperand.getOperandNumber()); // Since we don't yet have a clear layout story, to_memref may // conservatively turn tensors into more dynamic memref than necessary. // If the memref type of the callee fails, introduce an extra memref.cast @@ -272,22 +271,16 @@ memRefType, buffer); buffer = castBuffer; } - newOperands[idx] = buffer; + newOperands.push_back(buffer); } // 3. Create the new CallOp. Operation *newCallOp = rewriter.create( callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands); newCallOp->setAttrs(callOp->getAttrs()); - // Get replacement values. - for (unsigned i = 0; i < replacementValues.size(); ++i) { - if (replacementValues[i]) - continue; - replacementValues[i] = newCallOp->getResult(*retValMapping[i]); - } // 4. Replace the old op with the new op. - replaceOpWithBufferizedValues(rewriter, callOp, replacementValues); + replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults()); return success(); } @@ -326,6 +319,17 @@ struct FuncOpInterface : public BufferizableOpInterface::ExternalModel { + FailureOr + getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const DenseMap &fixedTypes) const { + auto funcOp = cast(op); + auto bbArg = cast(value); + // Unstructured control flow is not supported. + assert(bbArg.getOwner() == &funcOp.getBody().front() && + "expected that block argument belongs to first block"); + return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options); + } + /// Rewrite function bbArgs and return values into buffer form. This function /// bufferizes the function signature and the ReturnOp. When the entire /// function body has been bufferized, function return types can be switched @@ -384,9 +388,11 @@ bbArgUses.push_back(&use); // Change the bbArg type to memref.
- Type memrefType = - getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options); - bbArg.setType(memrefType); + FailureOr memrefType = + bufferization::getBufferType(bbArg, options); + if (failed(memrefType)) + return failure(); + bbArg.setType(*memrefType); // Replace all uses of the original tensor bbArg. rewriter.setInsertionPointToStart(&frontBlock);