diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -27,6 +27,7 @@ class AnalysisState; struct BufferizationState; struct BufferizationOptions; +class OpFilter; /// A helper type converter class that automatically populates the relevant /// materializations and type conversions for bufferization. @@ -84,8 +85,10 @@ /// Reuse an existing `BufferizationState`. /// /// Note: This function overload is useful for extending the bufferization. -LogicalResult bufferizeOp(Operation *op, - BufferizationState &bufferizationState); +/// +/// If `opFilter` is non-null, ops that it does not allow are not bufferized. +LogicalResult bufferizeOp(Operation *op, BufferizationState &bufferizationState, + const OpFilter *opFilter = nullptr); /// Finalize all buffer allocations: Create alloc/dealloc ops as specified by /// the bufferization options. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -345,9 +345,10 @@ public: BufferizationRewriter(MLIRContext *ctx, DenseSet &erasedOps, DenseSet &toMemrefOps, - const BufferizationOptions &options) + const BufferizationOptions &options, + const OpFilter *opFilter) : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps), - options(options) {} + options(options), opFilter(opFilter) {} protected: void notifyOperationRemoved(Operation *op) override { @@ -370,10 +371,18 @@ if (isa(op)) return; + // Skip non-tensor ops. + if (!hasTensorSemantics(op)) + return; + + // Skip ops excluded by the options or by the (optional) op filter. + if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op))) + return; + // Adding new bufferizable ops is not allowed during bufferization.
Such ops // would not be analyzed and can lead to surprising behavior. - assert((!hasTensorSemantics(op) || !options.isOpAllowed(op)) && - "creating new tensor ops is not allowed during bufferization"); + llvm_unreachable( + "creating new tensor ops is not allowed during bufferization"); } private: @@ -387,12 +396,15 @@ /// Used for debug modes. LLVM_ATTRIBUTE_UNUSED const BufferizationOptions &options; + + /// Optional op filter (may be null). Inserted ops it rejects are skipped. + const OpFilter *opFilter; }; } // namespace -LogicalResult -bufferization::bufferizeOp(Operation *op, - BufferizationState &bufferizationState) { +LogicalResult bufferization::bufferizeOp(Operation *op, + BufferizationState &bufferizationState, + const OpFilter *opFilter) { const auto &options = bufferizationState.getOptions(); assert(options.unknownTypeConversion != BufferizationOptions::LayoutMapOption::InferLayoutMap && @@ -420,7 +432,7 @@ // Bufferize all ops. BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps, - bufferizationState.getOptions()); + bufferizationState.getOptions(), opFilter); for (unsigned i = 0; i < worklist.size(); ++i) { Operation *op = worklist[i]; // Skip ops that were erased. @@ -430,6 +442,9 @@ auto bufferizableOp = options.dynCastBufferizableOp(op); if (!bufferizableOp) continue; + // Skip ops that are rejected by the op filter. + if (opFilter && !opFilter->isOpAllowed(op)) + continue; // Skip ops that no longer have tensor semantics. if (!hasTensorSemantics(op)) continue; @@ -462,6 +477,9 @@ // Continue ops that are not allowed. if (!options.isOpAllowed(op)) continue; + // Likewise skip ops that are rejected by the op filter. + if (opFilter && !opFilter->isOpAllowed(op)) + continue; // Ops without any uses and no side effects will fold away. if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op)) continue;