diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -237,6 +237,12 @@ /// Return true if the given op has a tensor result or a tensor operand. static bool hasTensorSemantics(Operation *op) { + if (auto funcOp = dyn_cast(op)) { + bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor); + bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor); + return hasTensorArg || hasTensorResult; + } + bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); return hasTensorResult || hasTensorOperand; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -337,16 +337,6 @@ } } // namespace -static bool isaTensor(Type t) { return t.isa(); } - -/// If `value` is a memref::CastOp, return its source. Otherwise, return -/// `value` directly. -static Value getNonCastedValue(Value value) { - while (auto castOp = value.getDefiningOp()) - value = castOp.source(); - return value; -} - /// Remove the attribute that triggers inplace bufferization on a FuncOp /// argument `bbArg`. static void removeBufferizationFuncArguments(BlockArgument bbArg) { @@ -366,26 +356,15 @@ SymbolTable::lookupNearestSymbolFrom(callOp, sym)); } -/// Return the FunctionType with `argumentTypes` and `resultTypes` where each -/// tensor is replaced by the corresponding buffer type. -/// In order for all the callers to agree, this *must* bufferize to the most -/// dynamic buffer type supported. -/// A later pass across all CallOps in the module can decide whether to simplify -/// the types of to version according to some cost model. -static FunctionType -getBufferizedFunctionType(MLIRContext *ctx, TypeRange argumentTypes, - TypeRange resultTypes, - const BufferizationOptions &options) { - auto rewrite = [&](Type t) -> Type { - // TODO: non-zero address space. - // TODO: layout information if relevant. - if (auto tensorType = t.dyn_cast()) - return getMemRefType(tensorType, options); - return t; - }; - auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite)); - auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite)); - return FunctionType::get(ctx, argTypes, retTypes); +/// Return the index-th bufferized function argument type. This assumes that the +/// specified argument is a tensor. +static BaseMemRefType +getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, + const BufferizationOptions &options) { + auto tensorType = + funcOp.getFunctionType().getInput(index).dyn_cast(); + assert(tensorType && "expected TensorType"); + return getMemRefType(tensorType, options); } /// Gather equivalence info of CallOps. @@ -415,150 +394,6 @@ }); } -/// Rewrite the `funcOp` arguments analysis return values and terminator into -/// buffer form (using the canonical memref layout for now), according to the -/// inPlace-bufferizable information of the function arguments. -/// -/// This relies on a buffer equivalence analysis of each return operand. When a -/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be -/// dropped from the return values and becomes inplaceable at all callers. This -/// assumes all CallOp perform the necessary work to clone operands so as to -/// make them inplaceable. Reliance on this logic will need to be relaxed in the -/// future. -/// -/// Note: Returning a memref currently fails bufferization. If such memrefs -/// originate from an op with an Alloc effect, they could be hoisted in the -/// future. -static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, - RewriterBase &rewriter, - BufferizationState &state) { - const FuncAnalysisState &funcState = - getFuncAnalysisState(state.getAnalysisState()); - - // If nothing to do then we are done. - if (!llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) && - !llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor)) - return success(); - - // Get the bufferized FunctionType for funcOp or construct it if not yet - // available. - // TODO: Atm we have 3 cases: - // 1. if a function is called from within the Module, it must have bufferized - // to inplaceable tensor results. - // 2. if it is bodiless, it must have bufferized and is not allowed to have - // result tensors. - // 3. if it is not called internally, it still must bufferize to inplaceable - // tensor results and we construct it now (e.g. top-level function called - // externally). - // -> Figure out a better layering. - TypeRange resultTypes; - - // Corner case: Bodiless FuncOp - // ============================ - // The body of such functions is assumed opaque and we can't know the - // bufferization contract they want to enforce atm. - // As a consequence, only support functions that don't return any tensor atm. - if (funcOp.getBody().empty()) { - if (llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor)) - return funcOp->emitError() << "cannot bufferize bodiless function that " - << "returns a tensor"; - FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), funcOp.getFunctionType().getInputs(), - funcOp.getFunctionType().getResults(), state.getOptions()); - funcOp.setType(bufferizedFuncType); - return success(); - } - - // Support only single return-terminated block in the function. - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - assert(returnOp && "expected func with single return op"); - - // 1. For each FuncOp result, keep track of which inplace argument it reuses. - SmallVector returnValues; - for (OpOperand &returnOperand : returnOp->getOpOperands()) { - Value returnVal = returnOperand.get(); - - // If not a renturn tensor type just forward it. - if (!returnVal.getType().isa()) { - returnValues.push_back(returnVal); - continue; - } - - // If return operand is equivalent to some bbArg, no need to return it. - auto funcOpIt = funcState.equivalentFuncArgs.find(funcOp); - if (funcOpIt != funcState.equivalentFuncArgs.end() && - funcOpIt->second.count(returnOperand.getOperandNumber())) - continue; - - // Cast values at the call site if necessary. - returnValues.push_back( - getNonCastedValue(*state.getBuffer(rewriter, returnOperand))); - } - - // 2. Rewrite the terminator without the inPlace bufferizable values. - ValueRange retValues{returnValues}; - FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), funcOp.getFunctionType().getInputs(), - retValues.getTypes(), state.getOptions()); - OpBuilder b(returnOp); - b.create(returnOp.getLoc(), returnValues); - returnOp->erase(); - - // 3. Rewrite the bbArgs. - // Iterate on the original `numArgs` and replace them in order. - // This guarantees the argument order still matches after the rewrite. - Block &frontBlock = funcOp.getBody().front(); - unsigned numArgs = frontBlock.getNumArguments(); - for (unsigned idx = 0; idx < numArgs; ++idx) { - auto bbArg = frontBlock.getArgument(0); - auto tensorType = bbArg.getType().dyn_cast(); - // Non-tensor types are just forwarded. - if (!tensorType) { - frontBlock.addArgument(bbArg.getType(), bbArg.getLoc()); - bbArg.replaceAllUsesWith(frontBlock.getArguments().back()); - frontBlock.eraseArgument(0); - continue; - } - - // Get the buffer type from the bufferized function type. - Type memrefType = bufferizedFuncType.getInput(idx); - Value memref = frontBlock.addArgument(memrefType, bbArg.getLoc()); - OpBuilder b(funcOp->getContext()); - b.setInsertionPointToStart(&frontBlock); - // Replace all uses of bbArg through a ToMemRefOp. - for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) { - if (auto toMemrefOp = - dyn_cast(use.getOwner())) { - if (memref.getType() != toMemrefOp.memref().getType()) { - // Type has changed, insert a cast. - assert(memref::CastOp::areCastCompatible( - memref.getType(), toMemrefOp.memref().getType()) && - "bufferizeFuncOpBoundary: cast incompatible"); - auto castOp = b.create( - funcOp.getLoc(), toMemrefOp.memref().getType(), memref); - toMemrefOp.memref().replaceAllUsesWith(castOp); - } else { - // Type did not change, replace directly. - toMemrefOp.memref().replaceAllUsesWith(memref); - } - } - } - // Replace all remaining uses by a to_tensor. - if (!bbArg.use_empty()) { - auto toTensorOp = - b.create(funcOp.getLoc(), memref); - bbArg.replaceAllUsesWith(toTensorOp); - } - frontBlock.eraseArgument(0); - // TODO: add support to erase aliasInfo entries if deemed necessary. - } - - // 4. Rewrite the FuncOp type to buffer form. - funcOp.setType(bufferizedFuncType); - - return success(); -} - /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by /// callee-caller order (i.e. callees without callers first). /// Store the map of FuncOp to all its callers in `callerMap`. @@ -826,9 +661,8 @@ return BufferRelation::Equivalent; } - /// In a first approximation, all the function arguments of a FuncOp are - /// marked inplaceable. For now, it is the responsibility of the `callOp` - /// bufferization to allow FuncOp that are inplaceable to write inPlace. + /// All function arguments are writable. It is the responsibility of the + /// CallOp to insert buffer copies where necessary. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { func::CallOp callOp = cast(op); @@ -871,7 +705,7 @@ for (const auto &it : llvm::enumerate(callOp.getResultTypes())) { unsigned returnValIdx = it.index(); Type returnType = it.value(); - if (!isaTensor(returnType)) { + if (!returnType.isa()) { // Non-tensor values are returned. retValMapping[returnValIdx] = resultTypes.size(); resultTypes.push_back(returnType); @@ -903,12 +737,10 @@ funcOp.getFunctionType().getResult(resultTypes.size())); } - // 2. Compute bufferized FunctionType. - SmallVector argumentTypes{callOp->getOperandTypes()}; - // Get the bufferized FunctionType for funcOp or construct it if not yet - // available. - FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), argumentTypes, resultTypes, options); + // 2. Get the bufferized FunctionType of the called function. Recursive or + // circular call graphs are not currently supported, so we can be sure that + // the called function was already bufferized. + FunctionType bufferizedFuncType = funcOp.getFunctionType(); // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`. for (OpOperand &opOperand : callOp->getOpOperands()) { @@ -993,15 +825,136 @@ assert(isa(returnOp->getParentOp()) && "only support FuncOp parent for ReturnOp"); #endif // NDEBUG + + // ReturnOps are bufferized as part of FuncOps. return failure(); } }; struct FuncOpInterface : public BufferizableOpInterface::ExternalModel { + /// Rewrite function bbArgs and return values into buffer form (using the + /// canonical memref layout for now). This function bufferizes the function + /// signature and the ReturnOp. When the entire function body has been + /// bufferized, function return types can be switched to more concise memref + /// types as part of `foldMemRefCasts`. + /// + /// When a tensor function argument is known to be equivalent to a tensor + /// result, it is dropped from the return values. + /// + /// All function bbArgs are writable unless they are explicitly marked as + /// read-only. Callers must insert copies when needed. + /// + /// Note: Returning a memref is possible, but corresponding CallOp + /// bufferizations fail unless `allowReturnAllocs`. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { - return failure(); + auto funcOp = cast(op); + FunctionType funcType = funcOp.getFunctionType(); + const FuncAnalysisState &moduleState = + getFuncAnalysisState(state.getAnalysisState()); + const BufferizationOptions &options = state.getOptions(); + + // Construct the bufferized function type. + SmallVector argTypes; + for (const auto &it : llvm::enumerate(funcType.getInputs())) { + Type argType = it.value(); + if (auto tensorType = argType.dyn_cast()) { + argTypes.push_back( + getBufferizedFunctionArgType(funcOp, it.index(), options)); + continue; + } + argTypes.push_back(argType); + } + + // Bodiless functions are assumed opaque and we cannot know the + // bufferization contract they want to enforce. As a consequence, only + // support functions that don't return any tensors atm. + if (funcOp.getBody().empty()) { + SmallVector retTypes; + for (Type resultType : funcType.getResults()) { + if (resultType.isa()) + return funcOp->emitError() << "cannot bufferize bodiless function " + << "that returns a tensor"; + retTypes.push_back(resultType); + } + funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes)); + return success(); + } + + // TODO: Support functions with multiple returns. + func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + assert(returnOp && "expected func with single return op"); + + // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg. + Block &frontBlock = funcOp.getBody().front(); + for (BlockArgument &bbArg : frontBlock.getArguments()) { + auto tensorType = bbArg.getType().dyn_cast(); + // Non-tensor types stay the same. + if (!tensorType) + continue; + + // Collect all uses of the bbArg. + SmallVector bbArgUses; + for (OpOperand &use : bbArg.getUses()) + bbArgUses.push_back(&use); + + // Change the bbArg type to memref. + Type memrefType = + getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options); + bbArg.setType(memrefType); + + // Replace all uses of the original tensor bbArg. + rewriter.setInsertionPointToStart(&frontBlock); + if (!bbArgUses.empty()) { + // Insert to_tensor because the remaining function body has not been + // bufferized yet. + Value toTensorOp = + rewriter.create(funcOp.getLoc(), bbArg); + for (OpOperand *use : bbArgUses) + use->set(toTensorOp); + } + } + + // 2. For each result, keep track of which inplace argument it reuses. + SmallVector returnValues; + for (OpOperand &returnOperand : returnOp->getOpOperands()) { + Value returnVal = returnOperand.get(); + + // If not a tensor type just forward it. + if (!returnVal.getType().isa()) { + returnValues.push_back(returnVal); + continue; + } + + // If return operand is equivalent to some bbArg, no need to return it. + if (Optional equivBbArgIdx = getEquivalentFuncArgIdx( + funcOp, moduleState, returnOperand.getOperandNumber())) { + rewriter.setInsertionPoint(returnOp); + Location loc = returnOp.getLoc(); + Value toMemrefOp = rewriter.create( + loc, getMemRefType(returnVal.getType().cast(), options), + returnVal); + BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx); + // Note: This copy will fold away. It must be inserted here to ensure + // that `returnVal` still has at least one use and does not fold away. + if (failed( + createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options))) + return funcOp->emitError("could not generate copy for bbArg"); + continue; + } + + returnValues.push_back(*state.getBuffer(rewriter, returnOperand)); + } + + // 3. Rewrite the terminator without the in-place bufferizable values. + returnOp.operandsMutable().assign(returnValues); + + // 4. Rewrite the FuncOp type to buffer form. + funcOp.setType(FunctionType::get(op->getContext(), argTypes, + ValueRange(returnValues).getTypes())); + + return success(); } /// Return `true` if the given function argument is writable. @@ -1057,6 +1010,34 @@ setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state)); } +/// Fold return values that are memref casts and update function return types. +/// +/// During FuncOp bufferization, the exact type of the returned memrefs (if any) +/// is not known yet. Therefore, the bufferization uses memref types with the +/// most generic layout map as function return types. After bufferizing the +/// entire function body, a more concise memref type can potentially be used for +/// the return type of the function. +static void foldMemRefCasts(FuncOp funcOp) { + if (funcOp.getBody().empty()) + return; + + func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + SmallVector resultTypes; + + for (OpOperand &operand : returnOp->getOpOperands()) { + if (auto castOp = operand.get().getDefiningOp()) { + operand.set(castOp.source()); + resultTypes.push_back(castOp.source().getType()); + } else { + resultTypes.push_back(operand.get().getType()); + } + } + + auto newFuncType = FunctionType::get( + funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes); + funcOp.setType(newFuncType); +} + LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize( ModuleOp moduleOp, OneShotBufferizationOptions options) { IRRewriter rewriter(moduleOp.getContext()); @@ -1107,15 +1088,11 @@ // Bufferize functions. for (FuncOp funcOp : orderedFuncOps) { - // No body => no analysis. - if (!funcOp.getBody().empty()) - if (failed(bufferizeOp(funcOp, bufferizationState))) - return failure(); - // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. - if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, bufferizationState))) + if (failed(bufferizeOp(funcOp, bufferizationState))) return failure(); + foldMemRefCasts(funcOp); } // Check result. diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -11,6 +11,7 @@ // ----- +// expected-error @+2 {{op was not bufferized}} // expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}} func private @foo() -> tensor @@ -212,6 +213,7 @@ // ----- +// expected-error @+2 {{op was not bufferized}} // expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}} func private @foo(%t : tensor) -> (f32, tensor, f32)