diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -48,8 +48,8 @@ protected: /// Return a replacement payload op for the given op, which is going to be /// replaced with the given values. By default, if all values are defined by - /// the same newly-created op, which also has the same type as the given op, - /// that defining op is used as a replacement. + /// the same op, which also has the same type as the given op, that defining + /// op is used as a replacement. virtual Operation *findReplacementOp(Operation *op, ValueRange newValues) const; @@ -66,22 +66,14 @@ virtual void notifyPayloadReplacementNotFound(Operation *op, ValueRange values) {} - /// Return "true" if the given op is a new op. - bool isNewOp(Operation *op) const; - /// Return the single op that defines all given values (if any). static Operation *getCommonDefiningOp(ValueRange values); private: - void notifyOperationInserted(Operation *op) override; - void notifyOperationRemoved(Operation *op) override; void notifyOperationReplaced(Operation *op, ValueRange newValues) override; - /// Ops that were newly created during the transform. - DenseMap<OperationName, DenseSet<Operation *>> newOps; - /// The transform op in which this TrackingListener is used. TransformOpInterface transformOp; }; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -180,20 +180,9 @@ if (op->getName() != defOp->getName()) return nullptr; - // If the replacement op is not a new op, drop the mapping. 
- if (!isNewOp(defOp)) - return nullptr; - return defOp; } -bool transform::TrackingListener::isNewOp(Operation *op) const { - auto it = newOps.find(op->getName()); - if (it == newOps.end()) - return false; - return it->second.contains(op); -} - LogicalResult transform::TrackingListener::notifyMatchFailure( Location loc, function_ref<void(Diagnostic &)> reasonCallback) { LLVM_DEBUG({ @@ -204,17 +193,9 @@ return failure(); } -void transform::TrackingListener::notifyOperationInserted(Operation *op) { - newOps[op->getName()].insert(op); -} - void transform::TrackingListener::notifyOperationRemoved(Operation *op) { // TODO: Walk can be removed when D144193 has landed. op->walk([&](Operation *op) { - // Keep set of new ops up-to-date. - auto it = newOps.find(op->getName()); - if (it != newOps.end()) - it->second.erase(op); // Remove mappings for result values. for (OpResult value : op->getResults()) (void)replacePayloadValue(value, nullptr);