diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -334,6 +334,27 @@
     finalizeRootUpdate(root);
   }
 
+  /// Notify the pattern rewriter that the pattern is failing to match the given
+  /// operation, and provide a callback to populate a diagnostic with the reason
+  /// why the failure occurred. This method allows for derived rewriters to
+  /// optionally hook into the reason why a pattern failed, and display it to
+  /// users. The default implementation never invokes `reasonCallback` and
+  /// always returns failure, so a pattern can bail out with
+  /// `return rewriter.notifyMatchFailure(op, ...);`.
+  virtual LogicalResult
+  notifyMatchFailure(Operation *op,
+                     function_ref<void(Diagnostic &)> reasonCallback) {
+    return failure();
+  }
+  /// Overload that streams the fixed message `msg` into the diagnostic.
+  LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) {
+    return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; });
+  }
+  /// Overload taking a C string; forwards `msg` wrapped in a Twine.
+  LogicalResult notifyMatchFailure(Operation *op, const char *msg) {
+    return notifyMatchFailure(op, Twine(msg));
+  }
+
 protected:
   explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
   virtual ~PatternRewriter();
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -379,6 +379,14 @@
   /// PatternRewriter hook for updating the root operation in-place.
   void cancelRootUpdate(Operation *op) override;
 
+  /// PatternRewriter hook for notifying match failure reasons.
+  LogicalResult
+  notifyMatchFailure(Operation *op,
+                     function_ref<void(Diagnostic &)> reasonCallback) override;
+  // Re-expose the message-based overloads declared in PatternRewriter; the
+  // override above would otherwise hide them in this class.
+  using PatternRewriter::notifyMatchFailure;
+
   /// Return a reference to the internal implementation.
   detail::ConversionPatternRewriterImpl &getImpl();
 
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -989,6 +989,20 @@
   rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
 }
 
+/// PatternRewriter hook for notifying match failure reasons. The callback is
+/// only invoked inside LLVM_DEBUG, where the rendered reason is written to the
+/// conversion logger; the result is always failure.
+LogicalResult ConversionPatternRewriter::notifyMatchFailure(
+    Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
+  LLVM_DEBUG({
+    // Scratch remark used only to render the reason; it is never emitted.
+    Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
+    reasonCallback(diag);
+    impl->logger.startLine() << "** Failure : " << diag.str() << "\n";
+  });
+  return failure();
+}
+
 /// Return a reference to the internal implementation.
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   return *impl;
diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -272,7 +272,7 @@
                   ConversionPatternRewriter &rewriter) const final {
     // If the type is F32, change the type to F64.
     if (!Type(*op->result_type_begin()).isF32())
-      return matchFailure();
+      return rewriter.notifyMatchFailure(op, "expected single f32 result");
     rewriter.replaceOpWithNewOp(op, rewriter.getF64Type());
     return matchSuccess();
   }