diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h --- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h +++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h @@ -22,8 +22,7 @@ /// replacements. class TrackingListener : public transform::TrackingListener { public: - explicit TrackingListener(transform::TransformState &state) - : transform::TrackingListener(state) {} + using transform::TrackingListener::TrackingListener; protected: Operation *findReplacementOp(Operation *op, diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -39,8 +39,9 @@ class TrackingListener : public RewriterBase::Listener, public TransformState::Extension { public: - explicit TrackingListener(TransformState &state) - : TransformState::Extension(state) {} + /// Create a new TrackingListener for use in the specified transform op. + explicit TrackingListener(TransformState &state, TransformOpInterface op) + : TransformState::Extension(state), transformOp(op) {} protected: /// Return a replacement payload op for the given op, which is going to be @@ -71,6 +72,9 @@ /// Ops that were newly created during the transform. DenseMap> newOps; + + /// The transform op in which this TrackingListener is used.
+ TransformOpInterface transformOp; }; } // namespace transform diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1764,7 +1764,7 @@ transform::TransformState &state) { tensor::PadOp hoistedPadOp; SmallVector transposeOps; - TrackingListener listener(state); + TrackingListener listener(state, *this); IRRewriter rewriter(target->getContext(), &listener); FailureOr result = hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(), @@ -3014,7 +3014,7 @@ if (getVectorizePadding()) linalg::populatePadOpVectorizationPatterns(patterns); - TrackingListener listener(state); + TrackingListener listener(state, *this); GreedyRewriteConfig config; config.listener = &listener; if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config))) diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -200,6 +200,19 @@ }); } +/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors +/// precedes `b` or an ancestor of `b` in the same block. Returns false if `a == b` or if `a` contains `b`. +static bool happensBefore(Operation *a, Operation *b) { + do { + if (a->isProperAncestor(b)) + return false; + if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) { + return a->isBeforeInBlock(bAncestor); + } + } while ((a = a->getParentOp())); + return false; +} + void transform::TrackingListener::notifyOperationReplaced( Operation *op, ValueRange newValues) { assert(op->getNumResults() == newValues.size() && @@ -210,13 +223,30 @@ (void)replacePayloadValue(oldValue, newValue); // Replace op handle.
- Operation *replacement = findReplacementOp(op, newValues); - if (succeeded(replacePayloadOp(op, replacement))) { - // If the op is tracked but no replacement op was found, send a - // notification. - if (!replacement) - notifyPayloadReplacementNotFound(op, newValues); + SmallVector opHandles; + if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) { + // Op is not tracked. + return; + } + auto hasAliveUser = [&]() { + for (Value v : opHandles) + for (Operation *user : v.getUsers()) + if (!happensBefore(user, transformOp)) + return true; + return false; + }; + if (!hasAliveUser()) { + // Every user of every handle happens before `transformOp`, so the handles are dead: drop the mapping. + (void)replacePayloadOp(op, nullptr); + return; } + + Operation *replacement = findReplacementOp(op, newValues); + // If the op is tracked but no replacement op was found, send a + // notification. + if (!replacement) + notifyPayloadReplacementNotFound(op, newValues); + (void)replacePayloadOp(op, replacement); } //===----------------------------------------------------------------------===// @@ -339,7 +369,7 @@ if (!failed) { // We will be using the clones, so cancel their scheduled deletion. deleteClones.release(); - TrackingListener listener(state); + TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); for (const auto &kvp : llvm::zip(originals, clones)) { Operation *original = std::get<0>(kvp);