diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -17,6 +17,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Support/ExtensibleRTTI.h" namespace mlir { class BlockAndValueMapping; @@ -240,13 +241,21 @@ /// BufferizationState keeps track of bufferization state and provides access to /// the results of the analysis. -struct BufferizationState { - BufferizationState(ModuleOp moduleOp, AllocationCallbacks &allocationFns) - : aliasInfo(moduleOp), allocationFns(allocationFns) {} +/// +/// Note: BufferizationState uses RTTIExtends, so that extensions of the +/// bufferization (such as ModuleBufferization) can define their own derived +/// structs without having to mention their type IDs/enums here. +struct BufferizationState + : public llvm::RTTIExtends<BufferizationState, llvm::RTTIRoot> { + BufferizationState(Operation *rootOp, AllocationCallbacks &allocationFns) + : aliasInfo(rootOp), allocationFns(allocationFns) {} // BufferizationState should be passed as a reference. BufferizationState(const BufferizationState &) = delete; + /// Type ID for RTTI. + static char ID; + /// Map tensor values to memref buffers. 
void mapBuffer(ValueRange tensors, ValueRange buffers); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -557,3 +557,5 @@ return MemRefType::get(tensorType.getShape(), tensorType.getElementType(), stridedLayout, addressSpace); } + +char mlir::linalg::comprehensive_bufferize::BufferizationState::ID = 0; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -28,12 +28,19 @@ namespace { /// A specialization of BufferizationState that keeps track of additional /// state required for bufferization of function boundaries. -struct ModuleBufferizationState : public BufferizationState { - using BufferizationState::BufferizationState; +struct ModuleBufferizationState + : public llvm::RTTIExtends<ModuleBufferizationState, BufferizationState> { + using llvm::RTTIExtends<ModuleBufferizationState, BufferizationState>::RTTIExtends; + + /// Type ID for RTTI. + static char ID; /// A map for looking up bufferized function types. DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes; }; + +char ModuleBufferizationState::ID = 0; } // namespace static bool isaTensor(Type t) { return t.isa<TensorType>(); } @@ -577,7 +584,9 @@ llvm::append_range(argumentTypes, hoistedArgs.getTypes()); // Get the bufferized FunctionType for funcOp or construct it if not yet // available. - // TODO: Assert that `state` is a ModuleBufferizationState. 
+ assert(isa<ModuleBufferizationState>(state) && + "cannot use CallOp bufferization interface impl without " + "ModuleBufferization"); FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( funcOp, argumentTypes, resultTypes, static_cast<ModuleBufferizationState &>(state).bufferizedFunctionTypes);