TrackingListener::findReplacementOp now returns FailureOr<Operation *>.

- failure(): no replacement op could be found; the caller in TransformOps.cpp
  calls notifyPayloadReplacementNotFound and passes nullptr to replacePayloadOp.
- success(nullptr): no replacement op is needed; the payload op is dropped from
  the mapping without a "replacement not found" notification.
- success(op): the returned op is passed to replacePayloadOp.

The overrides/wrappers in TestTensorTransforms.cpp and
TestTransformDialectExtension.cpp are updated to the new signature.

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -53,6 +53,11 @@ /// the same op, which also has the same type as the given op, that defining /// op is used as a replacement. /// + /// A "failure" return value indicates that no replacement operation could be + /// found. A "nullptr" return value indicates that no replacement op is needed + /// (e.g., handle is dead or was consumed) and that the payload op should + /// be dropped from the mapping. + /// /// Example: A tracked "linalg.generic" with two results is replaced with two /// values defined by (another) "linalg.generic". It is reasonable to assume /// that the replacement "linalg.generic" represents the same "computation". @@ -91,8 +96,8 @@ /// /// Derived classes may override `findReplacementOp` to specify custom /// replacement rules. - virtual Operation *findReplacementOp(Operation *op, - ValueRange newValues) const; + virtual FailureOr<Operation *> findReplacementOp(Operation *op, + ValueRange newValues) const; /// Notify the listener that the pattern failed to match the given operation, /// and provide a callback to populate a diagnostic with the reason why the diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -70,7 +70,7 @@ return defOp; } -Operation * +FailureOr<Operation *> transform::TrackingListener::findReplacementOp(Operation *op, ValueRange newValues) const { assert(op->getNumResults() == newValues.size() && @@ -81,7 +81,7 @@ // If the replacement values belong to different ops, drop the mapping. Operation *defOp = getCommonDefiningOp(values); if (!defOp) - return nullptr; + return failure(); // If the defining op has the same type, we take it as a replacement.
if (op->getName() == defOp->getName()) @@ -108,7 +108,7 @@ } } while (!values.empty()); - return nullptr; + return failure(); } LogicalResult transform::TrackingListener::notifyMatchFailure( @@ -173,12 +173,16 @@ return; } - Operation *replacement = findReplacementOp(op, newValues); + FailureOr<Operation *> replacement = findReplacementOp(op, newValues); // If the op is tracked but no replacement op was found, send a // notification. - if (!replacement) + if (failed(replacement)) { notifyPayloadReplacementNotFound(op, newValues); - (void)replacePayloadOp(op, replacement); + (void)replacePayloadOp(op, nullptr); + return; + } + + (void)replacePayloadOp(op, *replacement); } transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() { diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -302,7 +302,10 @@ // Expose `findReplacementOp` as a public function, so that it can be tested.
Operation *getReplacementOp(Operation *op, ValueRange newValues) const { - return findReplacementOp(op, newValues); + auto replacementOp = findReplacementOp(op, newValues); + if (failed(replacementOp)) + return nullptr; + return *replacementOp; } }; } // namespace diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -696,15 +696,15 @@ using transform::TrackingListener::TrackingListener; protected: - Operation *findReplacementOp(Operation *op, - ValueRange newValues) const override { + FailureOr<Operation *> + findReplacementOp(Operation *op, ValueRange newValues) const override { if (newValues.size() != 1) - return nullptr; + return failure(); Operation *replacement = newValues[0].getDefiningOp(); if (!replacement) - return nullptr; + return failure(); if (replacement->getName().getStringRef() != "test.update_mapping") - return nullptr; + return failure(); return replacement; } };