diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -525,7 +525,7 @@ /// hooks. class ConversionPatternRewriter final : public PatternRewriter { public: - ConversionPatternRewriter(MLIRContext *ctx); + explicit ConversionPatternRewriter(MLIRContext *ctx); ~ConversionPatternRewriter() override; /// Apply a signature conversion to the entry block of the given region. This @@ -932,14 +932,20 @@ /// provided 'convertedOps' set; note that no actual rewrites are applied to the /// operations on success and only pre-existing operations are added to the set. /// This method only returns failure if there are unreachable blocks in any of -/// the regions nested within 'ops'. -LogicalResult applyAnalysisConversion(ArrayRef ops, - ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet &convertedOps); -LogicalResult applyAnalysisConversion(Operation *op, ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet &convertedOps); +/// the regions nested within 'ops'. There's an additional argument +/// `notifyCallback` which is used for collecting match failure diagnostics +/// generated during the conversion. Diagnostics reported to this +/// callback may only be available in debug mode. 
+LogicalResult applyAnalysisConversion( + ArrayRef ops, ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + DenseSet &convertedOps, + function_ref notifyCallback = nullptr); +LogicalResult applyAnalysisConversion( + Operation *op, ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + DenseSet &convertedOps, + function_ref notifyCallback = nullptr); } // end namespace mlir #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_ diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -851,8 +851,9 @@ namespace mlir { namespace detail { struct ConversionPatternRewriterImpl { - ConversionPatternRewriterImpl(PatternRewriter &rewriter) - : argConverter(rewriter, unresolvedMaterializations) {} + explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter) + : argConverter(rewriter, unresolvedMaterializations), + notifyCallback(nullptr) {} /// Cleanup and destroy any generated rewrite operations. This method is /// invoked when the conversion process fails. @@ -1004,6 +1005,9 @@ /// active. TypeConverter *currentTypeConverter = nullptr; + /// This allows the user to collect the match failure message. + function_ref notifyCallback; + #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -1475,6 +1479,8 @@ Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); logger.startLine() << "** Failure : " << diag.str() << "\n"; + if (notifyCallback) + notifyCallback(diag); }); return failure(); } @@ -1949,7 +1955,16 @@ // Functor that cleans up the rewriter state after a pattern failed to match. 
RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { - LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match")); + LLVM_DEBUG({ + logFailure(rewriterImpl.logger, "pattern failed to match"); + if (rewriterImpl.notifyCallback) { + Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); + diag << "Failed to apply pattern \"" << pattern.getDebugName() + << "\" on op:\n" + << *op; + rewriterImpl.notifyCallback(diag); + } + }); rewriterImpl.resetState(curState); appliedPatterns.erase(&pattern); }; @@ -2333,7 +2348,9 @@ : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} /// Converts the given operations to the conversion target. - LogicalResult convertOperations(ArrayRef ops); + LogicalResult + convertOperations(ArrayRef ops, + function_ref notifyCallback = nullptr); private: /// Converts an operation with the given rewriter. @@ -2410,7 +2427,9 @@ return success(); } -LogicalResult OperationConverter::convertOperations(ArrayRef ops) { +LogicalResult OperationConverter::convertOperations( + ArrayRef ops, + function_ref notifyCallback) { if (ops.empty()) return success(); ConversionTarget &target = opLegalizer.getTarget(); @@ -2428,6 +2447,8 @@ // Convert each operation and discard rewrites on failure. 
ConversionPatternRewriter rewriter(ops.front()->getContext()); ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); + rewriterImpl.notifyCallback = notifyCallback; + for (auto *op : toConvert) if (failed(convert(rewriter, op))) return rewriterImpl.discardRewrites(), failure(); @@ -3275,15 +3296,17 @@ mlir::applyAnalysisConversion(ArrayRef ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet &convertedOps) { + DenseSet &convertedOps, + function_ref notifyCallback) { OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, &convertedOps); - return opConverter.convertOperations(ops); + return opConverter.convertOperations(ops, notifyCallback); } LogicalResult mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet &convertedOps) { + DenseSet &convertedOps, + function_ref notifyCallback) { return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns, - convertedOps); + convertedOps, notifyCallback); }