diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -297,6 +297,13 @@
 /// bufferization is necessary.
 Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
 
+/// Bufferize the given op. `memref::BufferCastOp` and `memref::TensorLoadOp`
+/// ops, as well as ops without tensor OpOperands/OpResults, are skipped and
+/// this function returns success immediately. Otherwise, it calls the
+/// `bufferize` interface method of `BufferizableOpInterface`. An error is
+/// emitted if the op has tensors but does not implement the interface.
+LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
+
 /// PostAnalysisSteps can be registered with `BufferizationOptions` and are
 /// executed after the analysis, but before bufferization. They can be used
 /// implement custom dialect-specific optimizations.
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -24,9 +24,6 @@
 /// Return default allocation callbacks.
 std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
 
-/// Bufferize one particular op.
-LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
-
 /// Register external models implemented for the `BufferizableOpInterface`.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinOps.h" @@ -390,6 +391,31 @@ return operandBuffer; } +LogicalResult +mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op, + BufferizationState &state) { + OpBuilder b(op->getContext()); + + // Skip BufferCast and TensorLoad ops. + if (isa(op)) + return success(); + + // Check if op has tensor results or operands. + auto isaTensor = [](Type t) { return t.isa(); }; + bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); + bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); + if (!hasTensorResult && !hasTensorOperand) + return success(); + + // Bufferize using `BufferizableOpInterface`. + b.setInsertionPoint(op); + if (auto bufferizableOp = dyn_cast(op)) + return bufferizableOp.bufferize(b, state); + + // Other op with tensors. No bufferization method specified. + return op->emitError() << "unsupported op with tensors"; +} + //===----------------------------------------------------------------------===// // Bufferization-specific BlockAndValueMapping support with debugging. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRIR + MLIRMemRef ) add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -927,30 +927,6 @@ // Bufferization entry-point for functions. //===----------------------------------------------------------------------===// -LogicalResult -mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op, - BufferizationState &state) { - OpBuilder b(op->getContext()); - - // Skip BufferCast and TensorLoad ops. - if (isa(op)) - return success(); - - // Check if op has tensor results or operands. - auto isaTensor = [](Type t) { return t.isa(); }; - bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); - bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); - if (!hasTensorResult && !hasTensorOperand) - return success(); - - // Bufferize using `BufferizableOpInterface`. - if (auto bufferizableOp = dyn_cast(op)) - return bufferizableOp.bufferize(b, state); - - // Other op with tensors. No bufferization method specified. 
- return op->emitError() << "unsupported op with tensors"; -} - static LogicalResult bufferizeFuncOpInternals(FuncOp funcOp, BufferizationState &state) { LLVM_DEBUG(llvm::dbgs() << "\n\n"); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6299,6 +6299,7 @@ deps = [ ":BufferizableOpInterfaceIncGen", ":IR", + ":MemRefDialect", ":Support", "//llvm:Support", ],