[mlir][transform] Move TrackingListener to TransformInterfaces

Move TrackingListener and ErrorCheckingTrackingListener (declarations and
definitions) from TransformOps.{h,cpp} to TransformInterfaces.{h,cpp}.

* TransformInterfaces.h now includes mlir/IR/PatternMatch.h, since
  TrackingListener derives from RewriterBase::Listener.
* The generated TransformInterfaces.h.inc include is moved above the new
  classes, presumably so that TransformOpInterface is declared before
  TrackingListener uses it.
* TransformInterfaces.cpp now includes mlir/Interfaces/CastInterfaces.h for the
  CastOpInterface check in TrackingListener::findReplacementOp.

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -13,6 +13,7 @@ #include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h" #include "mlir/Dialect/Transform/Utils/RaggedArray.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LogicalResult.h" @@ -69,6 +70,13 @@ SmallVector getConsumedHandleOpOperands(transform::TransformOpInterface transformOp); } // namespace detail +} // namespace transform +} // namespace mlir + +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc" + +namespace mlir { +namespace transform { /// Options controlling the application of transform operations by the /// TransformState. @@ -839,6 +847,125 @@ return RegionScope(*this, region, RegionScope::Isolated()); } +/// A listener that updates a TransformState based on IR modifications. This +/// listener can be used during a greedy pattern rewrite to keep the transform +/// state up-to-date. +class TrackingListener : public RewriterBase::Listener, + public TransformState::Extension { +public: + /// Create a new TrackingListener for usage in the specified transform op. + explicit TrackingListener(TransformState &state, TransformOpInterface op) + : TransformState::Extension(state), transformOp(op) {} + +protected: + /// Return a replacement payload op for the given op, which is going to be + /// replaced with the given values. By default, if all values are defined by + /// the same op, which also has the same type as the given op, that defining + /// op is used as a replacement. + /// + /// A "failure" return value indicates that no replacement operation could be + /// found.
A "nullptr" return value indicates that no replacement op is needed + /// (e.g., handle is dead or was consumed) and that the payload op should + /// be dropped from the mapping. + /// + /// Example: A tracked "linalg.generic" with two results is replaced with two + /// values defined by (another) "linalg.generic". It is reasonable to assume + /// that the replacement "linalg.generic" represents the same "computation". + /// Therefore, the payload op mapping is updated to the defining op of the + /// replacement values. + /// + /// Counter Example: A "linalg.generic" is replaced with values defined by an + /// "scf.for". Without further investigation, the relationship between the + /// "linalg.generic" and the "scf.for" is unclear. They may not represent the + /// same computation; e.g., there may be tiled "linalg.generic" inside the + /// loop body that represents the original computation. Therefore, the + /// TrackingListener is conservative by default: it drops the mapping and + /// triggers the "payload replacement not found" notification. + /// + /// If no replacement op could be found according to the rules mentioned + /// above, this function tries to skip over cast-like ops that implement + /// `CastOpInterface`. + /// + /// Example: A tracked "linalg.generic" is replaced with "linalg.generic", + /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is + /// reasonable to assume that the wrapped "linalg.generic" represents the same + /// computation as the original "linalg.generic". The mapping is updated + /// accordingly. + /// + /// Certain ops (typically also metadata-only ops) are not considered casts, + /// but should be skipped nonetheless. Such ops should implement + /// `FindPayloadReplacementOpInterface` to specify with which operands the + /// lookup should continue. + /// + /// Example: A tracked "linalg.generic" is replaced with "linalg.generic", + /// wrapped in a "tensor.reshape".
A reshape is a metadata-only operation but + /// not a cast. (Implementing `CastOpInterface` would be incorrect and cause + /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface` + /// implementation, the replacement op lookup continues with the wrapped + /// "linalg.generic" and the mapping is updated accordingly. + /// + /// Derived classes may override `findReplacementOp` to specify custom + /// replacement rules. + virtual FailureOr findReplacementOp(Operation *op, + ValueRange newValues) const; + + /// Notify the listener that the pattern failed to match the given operation, + /// and provide a callback to populate a diagnostic with the reason why the + /// failure occurred. + LogicalResult + notifyMatchFailure(Location loc, + function_ref reasonCallback) override; + + /// This function is called when a tracked payload op is dropped because no + /// replacement op was found. Derived classes can implement this function for + /// custom error handling. + virtual void notifyPayloadReplacementNotFound(Operation *op, + ValueRange values) {} + + /// Return the single op that defines all given values (if any). + static Operation *getCommonDefiningOp(ValueRange values); + + /// Return the transform op in which this TrackingListener is used. + TransformOpInterface getTransformOp() const { return transformOp; } + +private: + void notifyOperationRemoved(Operation *op) override; + + void notifyOperationReplaced(Operation *op, ValueRange newValues) override; + + /// The transform op in which this TrackingListener is used. + TransformOpInterface transformOp; +}; + +/// A specialized listener that keeps track of cases in which no replacement +/// payload could be found. The error state of this listener must be checked +/// before the end of its lifetime.
+class ErrorCheckingTrackingListener : public TrackingListener { +public: + using transform::TrackingListener::TrackingListener; + + ~ErrorCheckingTrackingListener() override; + + /// Check and return the current error state of this listener. Afterwards, + /// resets the error state to "success". + DiagnosedSilenceableFailure checkAndResetError(); + + /// Return "true" if this tracking listener had a failure. + bool failed() const; + +protected: + void notifyPayloadReplacementNotFound(Operation *op, + ValueRange values) override; + +private: + /// The error state of this listener. "Success" indicates that no error + /// happened so far. + DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success(); + + /// The number of errors that have been encountered. + int64_t errorCounter = 0; +}; + /// This trait is supposed to be attached to Transform dialect operations that /// can be standalone top-level transforms. Such operations typically contain /// other Transform dialect operations that can be executed following some @@ -1084,14 +1211,6 @@ } }; -} // namespace transform -} // namespace mlir - -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc" - -namespace mlir { -namespace transform { - /// A single result of applying a transform op with `ApplyEachOpTrait` to a /// single payload operation. using ApplyToEachResult = MappedValue; diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -37,125 +37,6 @@ ::llvm::function_ref; -/// A listener that updates a TransformState based on IR modifications. This -/// listener can be used during a greedy pattern rewrite to keep the transform -/// state up-to-date.
-class TrackingListener : public RewriterBase::Listener, - public TransformState::Extension { -public: - /// Create a new TrackingListener for usage in the specified transform op. - explicit TrackingListener(TransformState &state, TransformOpInterface op) - : TransformState::Extension(state), transformOp(op) {} - -protected: - /// Return a replacement payload op for the given op, which is going to be - /// replaced with the given values. By default, if all values are defined by - /// the same op, which also has the same type as the given op, that defining - /// op is used as a replacement. - /// - /// A "failure" return value indicates that no replacement operation could be - /// found. A "nullptr" return value indicates that no replacement op is needed - /// (e.g., handle is dead or was consumed) and that the payload op should - /// be dropped from the mapping. - /// - /// Example: A tracked "linalg.generic" with two results is replaced with two - /// values defined by (another) "linalg.generic". It is reasonable to assume - /// that the replacement "linalg.generic" represents the same "computation". - /// Therefore, the payload op mapping is updated to the defining op of the - /// replacement values. - /// - /// Counter Example: A "linalg.generic" is replaced with values defined by an - /// "scf.for". Without further investigation, the relationship between the - /// "linalg.generic" and the "scf.for" is unclear. They may not represent the - /// same computation; e.g., there may be tiled "linalg.generic" inside the - /// loop body that represents the original computation. Therefore, the - /// TrackingListener is conservative by default: it drops the mapping and - /// triggers the "payload replacement not found" notification. - /// - /// If no replacement op could be found according to the rules mentioned - /// above, this function tries to skip over cast-like ops that implement - /// `CastOpInterface`.
- /// - /// Example: A tracked "linalg.generic" is replaced with "linalg.generic", - /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is - /// reasonable to assume that the wrapped "linalg.generic" represents the same - /// computation as the original "linalg.generic". The mapping is updated - /// accordingly. - /// - /// Certain ops (typically also metadata-only ops) are not considered casts, - /// but should be skipped nonetheless. Such ops should implement - /// `FindPayloadReplacementOpInterface` to specify with which operands the - /// lookup should continue. - /// - /// Example: A tracked "linalg.generic" is replaced with "linalg.generic", - /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but - /// not cast. (Implementing `CastOpInterface` would be incorrect and cause - /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface` - /// implementation, the replacement op lookup continues with the wrapped - /// "linalg.generic" and the mapping is updated accordingly. - /// - /// Derived classes may override `findReplacementOp` to specify custom - /// replacement rules. - virtual FailureOr findReplacementOp(Operation *op, - ValueRange newValues) const; - - /// Notify the listener that the pattern failed to match the given operation, - /// and provide a callback to populate a diagnostic with the reason why the - /// failure occurred. - LogicalResult - notifyMatchFailure(Location loc, - function_ref reasonCallback) override; - - /// This function is called when a tracked payload op is dropped because no - /// replacement op was found. Derived classes can implement this function for - /// custom error handling. - virtual void notifyPayloadReplacementNotFound(Operation *op, - ValueRange values) {} - - /// Return the single op that defines all given values (if any). - static Operation *getCommonDefiningOp(ValueRange values); - - /// Return the transform op in which this TrackingListener is used.
- TransformOpInterface getTransformOp() const { return transformOp; } - -private: - void notifyOperationRemoved(Operation *op) override; - - void notifyOperationReplaced(Operation *op, ValueRange newValues) override; - - /// The transform op in which this TrackingListener is used. - TransformOpInterface transformOp; -}; - -/// A specialized listener that keeps track of cases in which no replacement -/// payload could be found. The error state of this listener must be checked -/// before the end of its lifetime. -class ErrorCheckingTrackingListener : public TrackingListener { -public: - using transform::TrackingListener::TrackingListener; - - ~ErrorCheckingTrackingListener() override; - - /// Check and return the current error state of this listener. Afterwards, - /// resets the error state to "success". - DiagnosedSilenceableFailure checkAndResetError(); - - /// Return "true" if this tracking listener had a failure. - bool failed() const; - -protected: - void notifyPayloadReplacementNotFound(Operation *op, - ValueRange values) override; - -private: - /// The error state of this listener. "Success" indicates that no error - /// happened so far. - DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success(); - - /// The number of errors that have been encountered.
- int64_t errorCounter = 0; -}; - } // namespace transform } // namespace mlir diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -7,10 +7,12 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" + #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" @@ -1155,6 +1157,177 @@ values[resultNumber].data() != nullptr; } +//===----------------------------------------------------------------------===// +// TrackingListener +//===----------------------------------------------------------------------===// + +Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) { + Operation *defOp = nullptr; + for (Value v : values) { + // Skip empty values. + if (!v) + continue; + // A block argument has no defining op, so there is no common defining op. + Operation *valueDefOp = v.getDefiningOp(); + if (!valueDefOp) + return nullptr; + // All non-empty values must be defined by the same op. + if (defOp && defOp != valueDefOp) + return nullptr; + defOp = valueDefOp; + } + return defOp; +} + +FailureOr +transform::TrackingListener::findReplacementOp(Operation *op, + ValueRange newValues) const { + assert(op->getNumResults() == newValues.size() && + "invalid number of replacement values"); + SmallVector values(newValues.begin(), newValues.end()); + + do { + // If the replacement values belong to different ops, drop the mapping. + Operation *defOp = getCommonDefiningOp(values); + if (!defOp) + return failure(); + + // If the defining op has the same type, we take it as a replacement.
+ if (op->getName() == defOp->getName()) + return defOp; + + // Replacing an op with a constant-like equivalent is a common + // canonicalization. + if (defOp->hasTrait()) + return defOp; + + values.clear(); + + // Skip through ops that implement FindPayloadReplacementOpInterface. + if (auto findReplacementOpInterface = + dyn_cast(defOp)) { + values.assign(findReplacementOpInterface.getNextOperands()); + continue; + } + + // Skip through ops that implement CastOpInterface. + if (isa(defOp)) { + values.assign(defOp->getOperands().begin(), defOp->getOperands().end()); + continue; + } + } while (!values.empty()); + + return failure(); +} + +LogicalResult transform::TrackingListener::notifyMatchFailure( + Location loc, function_ref reasonCallback) { + LLVM_DEBUG({ + Diagnostic diag(loc, DiagnosticSeverity::Remark); + reasonCallback(diag); + DBGS() << "Match Failure : " << diag.str() << "\n"; + }); + return failure(); +} + +void transform::TrackingListener::notifyOperationRemoved(Operation *op) { + // TODO: Walk can be removed when D144193 has landed. + op->walk([&](Operation *op) { + // Remove mappings for result values. + for (OpResult value : op->getResults()) + (void)replacePayloadValue(value, nullptr); + // Remove mapping for op. + (void)replacePayloadOp(op, nullptr); + }); +} + +/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors +/// properly dominates `b` and `b` is not inside `a`. +static bool happensBefore(Operation *a, Operation *b) { + do { + if (a->isProperAncestor(b)) + return false; + if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) { + return a->isBeforeInBlock(bAncestor); + } + } while ((a = a->getParentOp())); + return false; +} + +void transform::TrackingListener::notifyOperationReplaced( + Operation *op, ValueRange newValues) { + assert(op->getNumResults() == newValues.size() && + "invalid number of replacement values"); + + // Replace value handles.
+ for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) + (void)replacePayloadValue(oldValue, newValue); + + // Replace op handle. + SmallVector opHandles; + if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) { + // Op is not tracked. + return; + } + auto hasAliveUser = [&]() { + for (Value v : opHandles) + for (Operation *user : v.getUsers()) + if (!happensBefore(user, transformOp)) + return true; + return false; + }; + if (!hasAliveUser()) { + // The op is tracked but the corresponding handles are dead. + (void)replacePayloadOp(op, nullptr); + return; + } + + FailureOr replacement = findReplacementOp(op, newValues); + // If the op is tracked but no replacement op was found, send a + // notification. + if (failed(replacement)) { + notifyPayloadReplacementNotFound(op, newValues); + (void)replacePayloadOp(op, nullptr); + return; + } + + (void)replacePayloadOp(op, *replacement); +} + +transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() { + // The state of the ErrorCheckingTrackingListener must be checked and reset + // if there was an error. This is to prevent errors from accidentally being + // missed.
+ assert(status.succeeded() && "listener state was not checked"); +} + +DiagnosedSilenceableFailure +transform::ErrorCheckingTrackingListener::checkAndResetError() { + DiagnosedSilenceableFailure s = std::move(status); + status = DiagnosedSilenceableFailure::success(); + errorCounter = 0; + return s; +} + +bool transform::ErrorCheckingTrackingListener::failed() const { + return !status.succeeded(); +} + +void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( + Operation *op, ValueRange values) { + if (status.succeeded()) { + status = emitSilenceableFailure( + getTransformOp(), "tracking listener failed to find replacement op"); + } + + status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op"; + for (auto &&[index, value] : llvm::enumerate(values)) + status.attachNote(value.getLoc()) + << "[" << errorCounter << "] replacement value " << index; + + ++errorCounter; +} + //===----------------------------------------------------------------------===// // Utilities for TransformEachOpTrait. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -50,175 +50,6 @@ #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" -//===----------------------------------------------------------------------===// -// TrackingListener -//===----------------------------------------------------------------------===// - -Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) { - Operation *defOp = nullptr; - for (Value v : values) { - // Skip empty values.
- if (!v) - continue; - if (!defOp) { - defOp = v.getDefiningOp(); - continue; - } - if (defOp != v.getDefiningOp()) - return nullptr; - } - return defOp; -} - -FailureOr -transform::TrackingListener::findReplacementOp(Operation *op, - ValueRange newValues) const { - assert(op->getNumResults() == newValues.size() && - "invalid number of replacement values"); - SmallVector values(newValues.begin(), newValues.end()); - - do { - // If the replacement values belong to different ops, drop the mapping. - Operation *defOp = getCommonDefiningOp(values); - if (!defOp) - return failure(); - - // If the defining op has the same type, we take it as a replacement. - if (op->getName() == defOp->getName()) - return defOp; - - // Replacing an op with a constant-like equivalent is a common - // canonicalization. - if (defOp->hasTrait()) - return defOp; - - values.clear(); - - // Skip through ops that implement FindPayloadReplacementOpInterface. - if (auto findReplacementOpInterface = - dyn_cast(defOp)) { - values.assign(findReplacementOpInterface.getNextOperands()); - continue; - } - - // Skip through ops that implement CastOpInterface. - if (isa(defOp)) { - values.assign(defOp->getOperands().begin(), defOp->getOperands().end()); - continue; - } - } while (!values.empty()); - - return failure(); -} - -LogicalResult transform::TrackingListener::notifyMatchFailure( - Location loc, function_ref reasonCallback) { - LLVM_DEBUG({ - Diagnostic diag(loc, DiagnosticSeverity::Remark); - reasonCallback(diag); - DBGS() << "Match Failure : " << diag.str() << "\n"; - }); - return failure(); -} - -void transform::TrackingListener::notifyOperationRemoved(Operation *op) { - // TODO: Walk can be removed when D144193 has landed. - op->walk([&](Operation *op) { - // Remove mappings for result values. - for (OpResult value : op->getResults()) - (void)replacePayloadValue(value, nullptr); - // Remove mapping for op.
- (void)replacePayloadOp(op, nullptr); - }); -} - -/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors -/// properly dominates `b` and `b` is not inside `a`. -static bool happensBefore(Operation *a, Operation *b) { - do { - if (a->isProperAncestor(b)) - return false; - if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) { - return a->isBeforeInBlock(bAncestor); - } - } while ((a = a->getParentOp())); - return false; -} - -void transform::TrackingListener::notifyOperationReplaced( - Operation *op, ValueRange newValues) { - assert(op->getNumResults() == newValues.size() && - "invalid number of replacement values"); - - // Replace value handles. - for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) - (void)replacePayloadValue(oldValue, newValue); - - // Replace op handle. - SmallVector opHandles; - if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) { - // Op is not tracked. - return; - } - auto hasAliveUser = [&]() { - for (Value v : opHandles) - for (Operation *user : v.getUsers()) - if (!happensBefore(user, transformOp)) - return true; - return false; - }; - if (!hasAliveUser()) { - // The op is tracked but the corresponding handles are dead. - (void)replacePayloadOp(op, nullptr); - return; - } - - FailureOr replacement = findReplacementOp(op, newValues); - // If the op is tracked but no replacement op was found, send a - // notification. - if (failed(replacement)) { - notifyPayloadReplacementNotFound(op, newValues); - (void)replacePayloadOp(op, nullptr); - return; - } - - (void)replacePayloadOp(op, *replacement); -} - -transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() { - // The state of the ErrorCheckingTrackingListener must be checked and reset - // if there was an error. This is to prevent errors from accidentally being - // missed.
- assert(status.succeeded() && "listener state was not checked"); -} - -DiagnosedSilenceableFailure -transform::ErrorCheckingTrackingListener::checkAndResetError() { - DiagnosedSilenceableFailure s = std::move(status); - status = DiagnosedSilenceableFailure::success(); - errorCounter = 0; - return s; -} - -bool transform::ErrorCheckingTrackingListener::failed() const { - return !status.succeeded(); -} - -void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( - Operation *op, ValueRange values) { - if (status.succeeded()) { - status = emitSilenceableFailure( - getTransformOp(), "tracking listener failed to find replacement op"); - } - - status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op"; - for (auto &&[index, value] : llvm::enumerate(values)) - status.attachNote(value.getLoc()) - << "[" << errorCounter << "] replacement value " << index; - - ++errorCounter; -} - //===----------------------------------------------------------------------===// // AlternativesOp //===----------------------------------------------------------------------===//