diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -497,7 +497,12 @@
 /// hooks.
 class ConversionPatternRewriter final : public PatternRewriter {
 public:
-  ConversionPatternRewriter(MLIRContext *ctx);
+  /// If `notifyCallback` is non-null, it is invoked with the diagnostic of
+  /// every match failure reported through this rewriter. It is stored as a
+  /// non-owning `function_ref`: the callable must outlive the rewriter.
+  explicit ConversionPatternRewriter(
+      MLIRContext *ctx,
+      function_ref<void(Diagnostic &)> notifyCallback = nullptr);
   ~ConversionPatternRewriter() override;
 
   /// Apply a signature conversion to the entry block of the given region. This
@@ -905,13 +910,18 @@
 /// operations on success and only pre-existing operations are added to the set.
 /// This method only returns failure if there are unreachable blocks in any of
 /// the regions nested within 'ops'.
-LogicalResult applyAnalysisConversion(ArrayRef<Operation *> ops,
-                                      ConversionTarget &target,
-                                      const FrozenRewritePatternSet &patterns,
-                                      DenseSet<Operation *> &convertedOps);
-LogicalResult applyAnalysisConversion(Operation *op, ConversionTarget &target,
-                                      const FrozenRewritePatternSet &patterns,
-                                      DenseSet<Operation *> &convertedOps);
+/// If `notifyCallback` is non-null, it is invoked with the diagnostic of every
+/// match failure a pattern reports through `notifyMatchFailure`.
+LogicalResult applyAnalysisConversion(
+    ArrayRef<Operation *> ops, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns,
+    DenseSet<Operation *> &convertedOps,
+    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+LogicalResult applyAnalysisConversion(
+    Operation *op, ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns,
+    DenseSet<Operation *> &convertedOps,
+    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -850,8 +850,11 @@
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl {
-  ConversionPatternRewriterImpl(PatternRewriter &rewriter)
-      : argConverter(rewriter, unresolvedMaterializations) {}
+  explicit ConversionPatternRewriterImpl(
+      PatternRewriter &rewriter,
+      function_ref<void(Diagnostic &)> notifyCallback = nullptr)
+      : argConverter(rewriter, unresolvedMaterializations),
+        notifyCallback(notifyCallback) {}
 
   /// Cleanup and destroy any generated rewrite operations. This method is
   /// invoked when the conversion process fails.
@@ -1003,6 +1006,11 @@
   /// active.
   TypeConverter *currentTypeConverter = nullptr;
 
+  /// Optional callback that receives the diagnostic of every match failure
+  /// reported through `notifyMatchFailure`. It is a non-owning reference, so
+  /// the callable must outlive this object.
+  function_ref<void(Diagnostic &)> notifyCallback;
+
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -1470,11 +1478,18 @@
 
 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
-  LLVM_DEBUG({
-    Diagnostic diag(loc, DiagnosticSeverity::Remark);
-    reasonCallback(diag);
-    logger.startLine() << "** Failure : " << diag.str() << "\n";
-  });
+  // Building the diagnostic runs `reasonCallback`, so only do it when the
+  // message has a consumer: the `-debug` log and/or `notifyCallback`.
+  bool shouldLog = false;
+  LLVM_DEBUG(shouldLog = true);
+  if (!shouldLog && !notifyCallback)
+    return failure();
+
+  Diagnostic diag(loc, DiagnosticSeverity::Remark);
+  reasonCallback(diag);
+  LLVM_DEBUG(logger.startLine() << "** Failure : " << diag.str() << "\n");
+  if (notifyCallback)
+    notifyCallback(diag);
   return failure();
 }
 
@@ -1482,9 +1497,10 @@
 // ConversionPatternRewriter
 //===----------------------------------------------------------------------===//
 
-ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
+ConversionPatternRewriter::ConversionPatternRewriter(
+    MLIRContext *ctx, function_ref<void(Diagnostic &)> notifyCallback)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(*this)) {}
+      impl(new detail::ConversionPatternRewriterImpl(*this, notifyCallback)) {}
 ConversionPatternRewriter::~ConversionPatternRewriter() {}
 
 void ConversionPatternRewriter::replaceOpWithIf(
@@ -2332,7 +2348,10 @@
       : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
 
   /// Converts the given operations to the conversion target.
-  LogicalResult convertOperations(ArrayRef<Operation *> ops);
+  /// `notifyCallback`, if non-null, receives each reported match failure.
+  LogicalResult
+  convertOperations(ArrayRef<Operation *> ops,
+                    function_ref<void(Diagnostic &)> notifyCallback = nullptr);
 
 private:
   /// Converts an operation with the given rewriter.
@@ -2409,7 +2428,9 @@
   return success();
 }
 
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+LogicalResult OperationConverter::convertOperations(
+    ArrayRef<Operation *> ops,
+    function_ref<void(Diagnostic &)> notifyCallback) {
   if (ops.empty())
     return success();
   ConversionTarget &target = opLegalizer.getTarget();
@@ -2425,7 +2446,7 @@
   }
 
   // Convert each operation and discard rewrites on failure.
-  ConversionPatternRewriter rewriter(ops.front()->getContext());
+  ConversionPatternRewriter rewriter(ops.front()->getContext(), notifyCallback);
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
@@ -3270,15 +3291,17 @@
 mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
                               ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              DenseSet<Operation *> &convertedOps) {
+                              DenseSet<Operation *> &convertedOps,
+                              function_ref<void(Diagnostic &)> notifyCallback) {
   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
                                  &convertedOps);
-  return opConverter.convertOperations(ops);
+  return opConverter.convertOperations(ops, notifyCallback);
 }
 LogicalResult
 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
-                              DenseSet<Operation *> &convertedOps) {
+                              DenseSet<Operation *> &convertedOps,
+                              function_ref<void(Diagnostic &)> notifyCallback) {
   return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
-                                 convertedOps);
+                                 convertedOps, notifyCallback);
 }