diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -50,6 +50,12 @@ virtual Operation *findReplacementOp(Operation *op, ValueRange newValues) const; + /// This function is called when a tracked payload op is dropped because no + /// replacement op was found. Derived classes can override this function for + /// custom error handling. The default implementation is a no-op. + virtual void notifyPayloadReplacementNotFound(Operation *op, + ValueRange values) const {} + /// Return "true" if the given op is a new op. bool isNewOp(Operation *op) const; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -204,15 +204,19 @@ Operation *op, ValueRange newValues) { assert(op->getNumResults() == newValues.size() && "invalid number of replacement values"); - if (op->getNumResults() == 0) - return; // Replace value handles. for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) (void)replacePayloadValue(oldValue, newValue); // Replace op handle. - (void)replacePayloadOp(op, findReplacementOp(op, newValues)); + Operation *replacement = findReplacementOp(op, newValues); + if (succeeded(replacePayloadOp(op, replacement))) { + // If the op is tracked but no replacement op was found, send a + // notification. + if (!replacement) + notifyPayloadReplacementNotFound(op, newValues); + } } //===----------------------------------------------------------------------===//