diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -19,6 +19,9 @@ namespace mlir { class OpBuilder; +namespace func { +class FuncOp; +} namespace bufferization { @@ -250,6 +253,11 @@ /// Initializer function for analysis state. using AnalysisStateInitFn = std::function; /// Tensor -> MemRef type converter. + /// Parameters: Value, memory space, func op, bufferization options + using FunctionArgTypeConverterFn = + std::function; + /// Tensor -> MemRef type converter. /// Parameters: Value, memory space, bufferization options using UnknownTypeConverterFn = std::function; @@ -313,7 +321,8 @@ /// OpOperands out-of-place. bool enforceAliasingInvariants = true; - /// This flag controls buffer types on function signatures. + /// This function controls buffer types on function signatures. Sets + /// `functionArgTypeConverterFn` and `inferFunctionResultLayout` accordingly. /// /// * InferLayoutMap: All function parameter types have a fully dynamic layout /// map, but function result types are inferred from the body of the @@ -326,13 +335,25 @@ /// additional buffer allocs and copies because layout maps cannot be casted /// away. /// - /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect. - /// /// Note: Inferred layout maps may not be desireable when interacting with /// external functions, because the generated function signatures will be less /// predictable. - LayoutMapOption functionBoundaryTypeConversion = - LayoutMapOption::InferLayoutMap; + void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption); + + /// Type converter from tensors to memrefs. This type converter is used to + /// determine bufferized function argument types. 
By default, a type + /// converter that returns a memref type with a fully dynamic layout map is + /// used. + /// + /// If `bufferizeFunctionBoundaries` is not set, this function isn't used. + FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr; + + /// If true, function result types are inferred from the body of the function. + /// Otherwise, function result type is determined by + /// `functionArgTypeConverterFn`. + /// + /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect. + bool inferFunctionResultLayout = true; /// Type converter from tensors to memrefs. This type converter is used if no /// memref type could be inferred during bufferization. By default, a type diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -322,17 +322,29 @@ // BufferizationOptions //===----------------------------------------------------------------------===// +namespace { + +/// Default function arg type converter: Use a fully dynamic layout map. +BaseMemRefType +defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace, + func::FuncOp funcOp, + const BufferizationOptions &options) { + return getMemRefTypeWithFullyDynamicLayout(type, memorySpace); +} /// Default unknown type converter: Use a fully dynamic layout map. -static BaseMemRefType +BaseMemRefType defaultUnknownTypeConverter(Value value, Attribute memorySpace, const BufferizationOptions &options) { return getMemRefTypeWithFullyDynamicLayout(value.getType().cast(), memorySpace); } +} // namespace + // Default constructor for BufferizationOptions. 
BufferizationOptions::BufferizationOptions() - : unknownTypeConverterFn(defaultUnknownTypeConverter) {} + : functionArgTypeConverterFn(defaultFunctionArgTypeConverter), + unknownTypeConverterFn(defaultUnknownTypeConverter) {} bool BufferizationOptions::isOpAllowed(Operation *op) const { // Special case: If function boundary bufferization is deactivated, do not @@ -362,6 +374,21 @@ return nullptr; } +void BufferizationOptions::setFunctionBoundaryTypeConversion( + LayoutMapOption layoutMapOption) { + functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace, + func::FuncOp funcOp, + const BufferizationOptions &options) { + if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) + return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, + memorySpace); + return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, + memorySpace); + }; + inferFunctionResultLayout = + layoutMapOption == LayoutMapOption::InferLayoutMap; +} + //===----------------------------------------------------------------------===// // Helper functions for BufferizableOpInterface //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -38,8 +38,8 @@ options.testAnalysisOnly = getTestAnalysisOnly(); options.printConflicts = getPrintConflicts(); if (getFunctionBoundaryTypeConversion().has_value()) - options.functionBoundaryTypeConversion = - *getFunctionBoundaryTypeConversion(); + options.setFunctionBoundaryTypeConversion( + *getFunctionBoundaryTypeConversion()); ArrayRef payloadOps = state.getPayloadOps(getTarget()); for (Operation *target : payloadOps) { diff --git 
a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -208,8 +208,8 @@ opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic); opt.copyBeforeWrite = copyBeforeWrite; opt.createDeallocs = createDeallocs; - opt.functionBoundaryTypeConversion = - parseLayoutMapOption(functionBoundaryTypeConversion); + opt.setFunctionBoundaryTypeConversion( + parseLayoutMapOption(functionBoundaryTypeConversion)); if (mustInferMemorySpace) opt.defaultMemorySpace = std::nullopt; opt.printConflicts = printConflicts; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -55,8 +55,7 @@ /// Return the index-th bufferized function argument type. This assumes that the /// specified argument is a tensor. If the tensor is ranked, a layout map may be -/// specified by the user. If no layout map is specified, the default layout map -/// (as per `options.functionBoundaryTypeConversion`) is used. +/// specified by the user (as per `options.functionArgTypeConverterFn`). static BaseMemRefType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options) { @@ -64,17 +63,8 @@ funcOp.getFunctionType().getInput(index).dyn_cast(); assert(tensorType && "expected TensorType"); - BaseMemRefType memrefType; - if (options.functionBoundaryTypeConversion == - LayoutMapOption::IdentityLayoutMap) { - memrefType = getMemRefTypeWithStaticIdentityLayout( - tensorType, *options.defaultMemorySpace); - } else { - // Note: Layout maps on function parameters cannot be inferred. 
The best we - // can do at the moment is "fully dynamic". - memrefType = getMemRefTypeWithFullyDynamicLayout( - tensorType, *options.defaultMemorySpace); - } + BaseMemRefType memrefType = options.functionArgTypeConverterFn( + tensorType, *options.defaultMemorySpace, funcOp, options); auto layoutAttr = funcOp.getArgAttrOfType( index, BufferizationDialect::kBufferLayoutAttrName); @@ -423,16 +413,10 @@ continue; } - BaseMemRefType resultType; - if (options.functionBoundaryTypeConversion == - LayoutMapOption::IdentityLayoutMap) { - resultType = getMemRefTypeWithStaticIdentityLayout( - tensorType, *options.defaultMemorySpace); - } else { - // Note: If `InferLayoutMap`, cast are later folded away. - resultType = getMemRefTypeWithFullyDynamicLayout( - tensorType, *options.defaultMemorySpace); - } + // Note: If `inferFunctionResultLayout = true`, casts are later folded + // away. + BaseMemRefType resultType = options.functionArgTypeConverterFn( + tensorType, *options.defaultMemorySpace, funcOp, options); Value toMemrefOp = rewriter.create( loc, resultType, returnVal); returnValues.push_back(toMemrefOp); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -433,8 +433,7 @@ /*opFilter=*/nullptr, statistics))) return failure(); // Change buffer return types to more precise layout maps. 
- if (options.functionBoundaryTypeConversion == - LayoutMapOption::InferLayoutMap) + if (options.inferFunctionResultLayout) foldMemRefCasts(funcOp); } diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -37,7 +37,7 @@ // TODO(springerm): To spot memory leaks more easily, returning dense allocs // should be disallowed. options.allowReturnAllocs = true; - options.functionBoundaryTypeConversion = LayoutMapOption::IdentityLayoutMap; + options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap); options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, const BufferizationOptions &options) { return getMemRefTypeWithStaticIdentityLayout(