diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Module Bufferization is an extension of Comprehensive Bufferize that
+// Module Bufferization is an extension of One-Shot Bufferize that
 // bufferizes function boundaries. It provides `BufferizableOpInterface`
 // implementations for FuncOp, CallOp and ReturnOp.
 //
@@ -357,14 +357,27 @@
 }
 
 /// Return the index-th bufferized function argument type. This assumes that the
-/// specified argument is a tensor.
+/// specified argument is a tensor. If the tensor is ranked, a layout map may be
+/// specified via the `kBufferLayoutAttrName` arg attribute; the memory space is
+/// kept. Otherwise, `getMemRefType(tensorType, options)` is returned as is.
 static BaseMemRefType
 getBufferizedFunctionArgType(func::FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
   auto tensorType =
       funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
   assert(tensorType && "expected TensorType");
-  return getMemRefType(tensorType, options);
+  BaseMemRefType memrefType = getMemRefType(tensorType, options);
+
+  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+      index, BufferizableOpInterface::kBufferLayoutAttrName);
+  if (!layoutAttr)
+    return memrefType;
+
+  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
+  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
+  return MemRefType::get(
+      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+      layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
 }
 
 /// Gather equivalence info of CallOps.
@@ -451,103 +464,6 @@
   return success();
 }
 
-static void foreachCaller(const FuncCallerMap &callerMap, func::FuncOp callee,
-                          llvm::function_ref<void(Operation *)> doit) {
-  auto itCallers = callerMap.find(callee);
-  if (itCallers == callerMap.end())
-    return;
-  for (Operation *caller : itCallers->second)
-    doit(caller);
-}
-
-/// Postprocess the linalg.buffer_layout annotation across function boundaries.
-/// This is a purely mechanical process that may later become part of a
-/// separate pass with its own layout assignment heuristic.
-static void layoutPostProcessing(ModuleOp moduleOp) {
-  SmallVector<func::FuncOp> orderedFuncOps;
-  DenseMap<func::FuncOp, DenseSet<Operation *>> callerMap;
-  auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
-  (void)res;
-  assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
-
-  for (func::FuncOp funcOp : orderedFuncOps) {
-    DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
-    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-      operandsPerCaller.try_emplace(caller, SmallVector<Value>());
-    });
-
-    SmallVector<Type> argumentTypes;
-    // Iterate on each function argument and check it it was marked with a
-    // desired layout.
-    for (const auto &it :
-         llvm::enumerate(funcOp.getFunctionType().getInputs())) {
-      int argNumber = it.index();
-      Type inputType = it.value();
-      auto memrefType = inputType.dyn_cast<MemRefType>();
-      auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
-          argNumber, BufferizableOpInterface::kBufferLayoutAttrName);
-      AffineMap desiredLayoutMap =
-          layoutAttr ? layoutAttr.getValue() : AffineMap();
-      AffineMap currentLayoutMap =
-          memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
-      if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
-        argumentTypes.push_back(inputType);
-        foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-          operandsPerCaller.find(caller)->getSecond().push_back(
-              caller->getOperand(argNumber));
-        });
-        continue;
-      }
-
-      // Compute the buffer type with desired layout and add to input argument
-      // types.
-      MemRefType desiredMemrefType = MemRefType::get(
-          memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
-      argumentTypes.push_back(desiredMemrefType);
-
-      // If funcOp's body is not empty, change the bbArg type and propagate.
-      if (!funcOp.getBody().empty()) {
-        BlockArgument bbArg = funcOp.getArgument(argNumber);
-        bbArg.setType(desiredMemrefType);
-        OpBuilder b(bbArg.getContext());
-        b.setInsertionPointToStart(bbArg.getOwner());
-        assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) &&
-               "layoutPostProcessing: cast incompatible");
-        // Cast back to the original memrefType and let it canonicalize.
-        Value cast =
-            b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
-        bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
-      }
-
-      // Cast to desired buffer type on all callers to `funcOp`.
-      // TODO: on the callee side, this may even have to trigger a copy to
-      // change the layout. For now let the memref::CastOp fail to verify in
-      // such cases.
-      auto castArg = [&](Operation *caller) {
-        OpBuilder b(caller);
-        assert(
-            memref::CastOp::areCastCompatible(
-                caller->getOperand(argNumber).getType(), desiredMemrefType) &&
-            "layoutPostProcessing.2: cast incompatible");
-        Value newOperand = b.create<memref::CastOp>(
-            funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
-        operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
-      };
-      foreachCaller(callerMap, funcOp, castArg);
-    }
-
-    // Set operands with cast buffer on all callers to `funcOp`.
-    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-      caller->setOperands(operandsPerCaller.lookup(caller));
-    });
-
-    // Finally set the funcOp type to update the arguments.
-    auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
-                                         funcOp.getFunctionType().getResults());
-    funcOp.setType(newFuncType);
-  }
-}
-
 namespace mlir {
 namespace linalg {
 namespace comprehensive_bufferize {
@@ -1111,10 +1027,6 @@
   if (failed(finalizeBuffers(moduleOp, options)))
     return failure();
 
-  // Perform a post-processing pass of layout modification at function boundary
-  // according to the kBufferLayoutAttrName.
-  layoutPostProcessing(moduleOp);
-
   // Post-pass cleanup of inplaceable and buffer_layout attributes.
   moduleOp.walk([&](func::FuncOp op) {
     for (BlockArgument bbArg : op.getArguments())