diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -78,9 +78,12 @@
 
 BufferizableOpInterface
 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
-  if (isOpAllowed(op))
-    return dyn_cast<BufferizableOpInterface>(op);
-  return nullptr;
+  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
+  if (!bufferizableOp)
+    return nullptr;
+  if (!isOpAllowed(op))
+    return nullptr;
+  return bufferizableOp;
 }
 
 BufferizableOpInterface
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -302,14 +302,16 @@
 public:
   BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                         DenseSet<Operation *> &toMemrefOps,
-                        SmallVector<Operation *> &worklist)
+                        const BufferizationOptions &options)
       : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
-        worklist(worklist) {}
+        options(options) {}
 
 protected:
   void notifyOperationRemoved(Operation *op) override {
     IRRewriter::notifyOperationRemoved(op);
     erasedOps.insert(op);
+    // Erase if present.
+    toMemrefOps.erase(op);
   }
 
   void notifyOperationInserted(Operation *op) override {
@@ -325,9 +327,10 @@
     if (isa<ToTensorOp>(op))
       return;
 
-    // A new bufferizable op was inserted. Add it to the worklist.
-    if (hasTensorSemantics(op))
-      worklist.push_back(op);
+    // Adding new bufferizable ops is not allowed during bufferization. Such ops
+    // would not be analyzed and can lead to surprising behavior.
+    assert((!hasTensorSemantics(op) || !options.isOpAllowed(op)) &&
+           "creating new tensor ops is not allowed during bufferization");
   }
 
 private:
@@ -337,8 +340,10 @@
   /// A set of all to_memref ops.
   DenseSet<Operation *> &toMemrefOps;
 
-  /// The list of bufferizable ops.
-  SmallVector<Operation *> &worklist;
+  /// The bufferization options. Only read inside an assertion, so it is marked
+  /// unused to avoid -Wunused-private-field in NDEBUG (release) builds.
+  LLVM_ATTRIBUTE_UNUSED
+  const BufferizationOptions &options;
 };
 } // namespace
 
@@ -373,18 +378,18 @@
 
   // Bufferize all ops.
   BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
-                                 worklist);
+                                 bufferizationState.getOptions());
   for (unsigned i = 0; i < worklist.size(); ++i) {
     Operation *op = worklist[i];
     // Skip ops that were erased.
     if (erasedOps.contains(op))
       continue;
-    // Skip ops that are not bufferizable.
-    auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
+    // Skip ops that are not bufferizable or not allowed.
+    auto bufferizableOp = options.dynCastBufferizableOp(op);
     if (!bufferizableOp)
       continue;
-    // Continue ops that are not allowed.
-    if (!options.isOpAllowed(op))
+    // Skip ops that no longer have tensor semantics.
+    if (!hasTensorSemantics(op))
       continue;
     // Bufferize the op.
     rewriter.setInsertionPoint(op);
@@ -393,8 +398,6 @@
 
   // Fold all to_memref(to_tensor(x)) pairs.
   for (Operation *op : toMemrefOps) {
-    if (erasedOps.contains(op))
-      continue;
     rewriter.setInsertionPoint(op);
     (void)bufferization::foldToMemrefToTensorPair(rewriter,
                                                   cast<ToMemrefOp>(op));