diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -497,7 +497,7 @@
 /// hooks.
 class ConversionPatternRewriter final : public PatternRewriter {
 public:
-  ConversionPatternRewriter(MLIRContext *ctx);
+  explicit ConversionPatternRewriter(MLIRContext *ctx);
   ~ConversionPatternRewriter() override;
 
   /// Apply a signature conversion to the entry block of the given region. This
@@ -904,14 +904,20 @@
 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
 /// operations on success and only pre-existing operations are added to the set.
 /// This method only returns failure if there are unreachable blocks in any of
-/// the regions nested within 'ops'.
-LogicalResult applyAnalysisConversion(ArrayRef<Operation *> ops,
-                                      ConversionTarget &target,
-                                      const FrozenRewritePatternSet &patterns,
-                                      DenseSet<Operation *> &convertedOps);
-LogicalResult applyAnalysisConversion(Operation *op, ConversionTarget &target,
-                                      const FrozenRewritePatternSet &patterns,
-                                      DenseSet<Operation *> &convertedOps);
+/// the regions nested within 'ops'. The optional `notifyCallback` receives
+/// remark diagnostics describing pattern match failures hit during conversion.
+/// Those diagnostics are only produced under LLVM_DEBUG, i.e. in non-NDEBUG
+/// builds with debug output enabled; otherwise the callback is never invoked.
+LogicalResult applyAnalysisConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns,
+    DenseSet<Operation *> &convertedOps,
+    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+LogicalResult applyAnalysisConversion(
+    Operation *op, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns,
+    DenseSet<Operation *> &convertedOps,
+    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -850,8 +850,9 @@
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl {
-  ConversionPatternRewriterImpl(PatternRewriter &rewriter)
-      : argConverter(rewriter, unresolvedMaterializations) {}
+  explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
+      : argConverter(rewriter, unresolvedMaterializations),
+        notifyCallback(nullptr) {}
 
   /// Cleanup and destroy any generated rewrite operations. This method is
   /// invoked when the conversion process fails.
@@ -953,6 +954,16 @@
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback);
 
+  /// Registers a callback notified of pattern match failures. The callable is
+  /// not owned and must outlive this rewriter.
+  void setNotifyCallback(function_ref<void(Diagnostic &)> callback);
+
+  /// Returns true if a notification callback has been registered.
+  bool hasNotifyCallback() const;
+
+  /// Invokes the registered callback; callers must check hasNotifyCallback().
+  void notifyMatchResult(Diagnostic &diag);
+
   //===--------------------------------------------------------------------===//
   // State
   //===--------------------------------------------------------------------===//
@@ -1003,6 +1014,9 @@
   /// active.
   TypeConverter *currentTypeConverter = nullptr;
 
+  /// Optional, non-owning callback receiving match failure diagnostics.
+  function_ref<void(Diagnostic &)> notifyCallback;
+
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -1474,10 +1488,25 @@
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
     logger.startLine() << "** Failure : " << diag.str() << "\n";
+    if (hasNotifyCallback())
+      notifyMatchResult(diag);
   });
   return failure();
 }
 
+void ConversionPatternRewriterImpl::setNotifyCallback(
+    function_ref<void(Diagnostic &)> callback) {
+  notifyCallback = callback;
+}
+
+bool ConversionPatternRewriterImpl::hasNotifyCallback() const {
+  return static_cast<bool>(notifyCallback);
+}
+
+void ConversionPatternRewriterImpl::notifyMatchResult(Diagnostic &diag) {
+  notifyCallback(diag);
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriter
 //===----------------------------------------------------------------------===//
@@ -1948,7 +1977,15 @@
   // Functor that cleans up the rewriter state after a pattern failed to match.
   RewriterState curState = rewriterImpl.getCurrentState();
   auto onFailure = [&](const Pattern &pattern) {
-    LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
+    LLVM_DEBUG({
+      logFailure(rewriterImpl.logger, "pattern failed to match");
+      if (rewriterImpl.hasNotifyCallback()) {
+        Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
+        diag << "Failed to apply pattern \"" << pattern.getDebugName()
+             << "\" on " << *op;
+        rewriterImpl.notifyMatchResult(diag);
+      }
+    });
     rewriterImpl.resetState(curState);
     appliedPatterns.erase(&pattern);
   };
@@ -2332,7 +2369,9 @@
       : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
 
   /// Converts the given operations to the conversion target.
-  LogicalResult convertOperations(ArrayRef<Operation *> ops);
+  LogicalResult
+  convertOperations(ArrayRef<Operation *> ops,
+                    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
 
 private:
   /// Converts an operation with the given rewriter.
@@ -2409,7 +2448,9 @@
   return success();
 }
 
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+LogicalResult OperationConverter::convertOperations(
+    ArrayRef<Operation *> ops,
+    function_ref<void(Diagnostic &)> notifyCallback) {
   if (ops.empty())
     return success();
   ConversionTarget &target = opLegalizer.getTarget();
@@ -2427,6 +2468,8 @@
   // Convert each operation and discard rewrites on failure.
   ConversionPatternRewriter rewriter(ops.front()->getContext());
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+  rewriterImpl.setNotifyCallback(notifyCallback);
+
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
       return rewriterImpl.discardRewrites(), failure();
@@ -3270,15 +3313,17 @@
 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
                               ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              DenseSet<Operation *> &convertedOps) {
+                              DenseSet<Operation *> &convertedOps,
+                              function_ref<void(Diagnostic &)> notifyCallback) {
   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
                                  &convertedOps);
-  return opConverter.convertOperations(ops);
+  return opConverter.convertOperations(ops, notifyCallback);
 }
 LogicalResult
 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              DenseSet<Operation *> &convertedOps) {
+                              DenseSet<Operation *> &convertedOps,
+                              function_ref<void(Diagnostic &)> notifyCallback) {
   return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
-                                 convertedOps);
+                                 convertedOps, notifyCallback);
 }