diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -179,6 +179,10 @@ /// Initializer function for dialect-specific analysis state. using DialectStateInitFn = std::function()>; + /// Tensor -> MemRef type converter. + /// Parameters: tensor value, memory space, bufferization options + using UnknownTypeConverterFn = std::function; enum class LayoutMapOption : int8_t { InferLayoutMap = 0, @@ -266,21 +270,11 @@ LayoutMapOption functionBoundaryTypeConversion = LayoutMapOption::InferLayoutMap; - /// This flag controls buffer types on unknown ops (to_memref wrappers) and in - /// other cases where a precise memref type cannot be inferred (e.g., the - /// bufferization of "tensor.cast"). - /// - /// * InferLayoutMap: This option is invalid and cannot be used. - /// * FullyDynamicLayoutMap: Assume that unknown ops have results with fully - /// dynamic layout maps after bufferization. This option is most efficient - /// because any layout map can be casted to a fully dynamic one. - /// * IdentityLayoutMap: Assume that unknown ops have results with static - /// identity layout (i.e., no layout map) after bufferization. This option - /// introduces additional buffer allocs and copies if the unknown op is - /// eventually bufferized to an op that returns a buffer with non-identity - /// layout. - LayoutMapOption unknownTypeConversion = - LayoutMapOption::FullyDynamicLayoutMap; + /// Type converter from tensors to memrefs. This type converter is used if no + /// memref type could be inferred during bufferization. By default, a type + /// converter that returns a memref type with a fully dynamic layout map is + /// used. Must not be null.
+ UnknownTypeConverterFn unknownTypeConverterFn = nullptr; /// Specifies whether dealloc ops should be generated along with alloc ops. If /// not, new memory allocations will leak. @@ -505,20 +499,19 @@ return newOp; } -/// Return a MemRefType to which the `tensorType` can be bufferized. +/// Return a MemRefType to which the type of the given value can be bufferized. /// /// If possible, op bufferization implementations should not use this function /// and instead infer precise memref types for tensor results by themselves. /// -/// Unless a layout map was specified, `options.unknownTypeConverter` determines -/// what kind of layout map will be used. For best composability (without -/// copies), the fully dynamic layout map is used by default. +/// Unless a layout map was specified, `options.unknownTypeConverterFn` +/// determines what kind of layout map will be used. For best composability +/// (without copies), the fully dynamic layout map is used by default. /// /// Note: Canonicalization patterns could clean up layout maps and infer more /// precise layout maps after bufferization. However, many possible /// canonicalizations are currently not implemented.
-BaseMemRefType getMemRefType(TensorType tensorType, - const BufferizationOptions &options, +BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout = {}, unsigned memorySpace = 0); diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -351,8 +351,9 @@ /*defaultImplementation=*/[{ assert(bbArg.getOwner()->getParentOp() == $_op && "bbArg must belong to this op"); - auto tensorType = bbArg.getType().cast(); - return bufferization::getMemRefType(tensorType, options); + assert(bbArg.getType().isa() && + "expected tensor type"); + return bufferization::getMemRefType(bbArg, options); }] >, InterfaceMethod< diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -222,8 +222,17 @@ // BufferizationOptions //===----------------------------------------------------------------------===// +/// Default unknown type converter: fully dynamic layout map, `options` unused. +static BaseMemRefType +defaultUnknownTypeConverter(Value value, unsigned memorySpace, + const BufferizationOptions &options) { + return getMemRefTypeWithFullyDynamicLayout(value.getType().cast(), + memorySpace); +} + // Default constructor for BufferizationOptions.
-BufferizationOptions::BufferizationOptions() = default; +BufferizationOptions::BufferizationOptions() + : unknownTypeConverterFn(defaultUnknownTypeConverter) {} bool BufferizationOptions::isOpAllowed(Operation *op) const { // Special case: If function boundary bufferization is deactivated, do not @@ -528,8 +537,7 @@ /// Return the buffer type for a given Value (tensor) after bufferization. FailureOr bufferization::getBufferType(Value value, const BufferizationOptions &options) { - auto tensorType = value.getType().dyn_cast(); - assert(tensorType && "unexpected non-tensor type"); + assert(value.getType().isa() && "unexpected non-tensor type"); Operation *op = getOwnerOfValue(value); // ToTensorOp: Take buffer type directly from the op. @@ -566,7 +574,7 @@ if (!memorySpace.hasValue()) return op->emitError("could not infer memory space"); - return getMemRefType(tensorType, options, /*layout=*/{}, *memorySpace); + return getMemRefType(value, options, /*layout=*/{}, *memorySpace); } void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, @@ -652,10 +660,11 @@ return isa(bbArg.getOwner()->getParentOp()); } -BaseMemRefType bufferization::getMemRefType(TensorType tensorType, +BaseMemRefType bufferization::getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout, unsigned memorySpace) { + auto tensorType = value.getType().cast(); auto memorySpaceAttr = IntegerAttr::get( IntegerType::get(tensorType.getContext(), 64), memorySpace); @@ -674,17 +683,9 @@ memorySpaceAttr); } - // Case 3: Configured with "fully dynamic layout maps". - if (options.unknownTypeConversion == - BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap) - return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace); - - // Case 4: Configured with "static identity layout maps".
- if (options.unknownTypeConversion == - BufferizationOptions::LayoutMapOption::IdentityLayoutMap) - return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace); - - llvm_unreachable("InferLayoutMap is an invalid option"); + // Case 3: Use the configured unknown type converter. + assert(options.unknownTypeConverterFn && "missing unknownTypeConverterFn"); + return options.unknownTypeConverterFn(value, memorySpace, options); } BaseMemRefType diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -192,8 +192,26 @@ opt.printConflicts = printConflicts; opt.testAnalysisOnly = testAnalysisOnly; opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries; - opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion); + // Configure type converter. + BufferizationOptions::LayoutMapOption unknownTypeConversionOption = + parseLayoutMapOption(unknownTypeConversion); + opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace, + const BufferizationOptions &options) { + auto tensorType = value.getType().cast(); + if (unknownTypeConversionOption == + BufferizationOptions::LayoutMapOption::IdentityLayoutMap) + return bufferization::getMemRefTypeWithStaticIdentityLayout( + tensorType, memorySpace); + assert( + unknownTypeConversionOption == + BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap && + "invalid layout map option"); + return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, + memorySpace); + }; + + // Configure op filter. OpFilter::Entry::FilterFn filterFn = [&](Operation *op) { // Filter may be specified via options.
@@ -372,10 +390,6 @@ const BufferizationOptions &options, bool copyBeforeWrite, const OpFilter *opFilter) { - assert(options.unknownTypeConversion != - BufferizationOptions::LayoutMapOption::InferLayoutMap && - "invalid layout map option"); - if (copyBeforeWrite) { AnalysisState state(options); if (failed(insertTensorCopies(op, state))) @@ -474,8 +488,12 @@ options.allowUnknownOps = true; options.createDeallocs = false; options.enforceAliasingInvariants = false; - options.unknownTypeConversion = - BufferizationOptions::LayoutMapOption::IdentityLayoutMap; + // Use a static identity layout when no memref type can be inferred. + options.unknownTypeConverterFn = [](Value value, unsigned memorySpace, + const BufferizationOptions & /*options*/) { + return getMemRefTypeWithStaticIdentityLayout( + value.getType().cast(), memorySpace); + }; options.opFilter.allowDialect(); return options; } diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -67,7 +67,7 @@ // Compute the new memref type. Type resultMemRefType = - getMemRefType(resultTensorType, options, layout, + getMemRefType(castOp.getResult(), options, layout, sourceMemRefType.getMemorySpaceAsInt()); // Replace the op with a memref.cast. @@ -800,9 +800,8 @@ getBuffer(rewriter, reshapeOp.getShape(), options); if (failed(srcBuffer) || failed(shapeBuffer)) return failure(); - auto resultTensorType = reshapeOp.getResult().getType().cast(); auto resultMemRefType = getMemRefType( - resultTensorType, options, /*layout=*/{}, + reshapeOp.getResult(), options, /*layout=*/{}, srcBuffer->getType().cast().getMemorySpaceAsInt()); replaceOpWithNewBufferizedOp( rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);