diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -207,39 +207,40 @@ // notified of any necessary changes, so there is nothing else to do // here. #ifndef NDEBUG - auto canApply = [&](const Pattern &pattern) { - LLVM_DEBUG({ - logger.getOStream() << "\n"; - logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" - << op->getName() << " -> ("; - llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); - logger.getOStream() << ")' {\n"; - logger.indent(); - }); - return true; - }; - auto onFailure = [&](const Pattern &pattern) { - LLVM_DEBUG(logResult("failure", "pattern failed to match")); - }; - auto onSuccess = [&](const Pattern &pattern) { - LLVM_DEBUG(logResult("success", "pattern applied successfully")); - return success(); - }; - - LogicalResult matchResult = - matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess); - if (succeeded(matchResult)) - LLVM_DEBUG(logResultWithLine("success", "pattern matched")); - else - LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match")); + auto canApply = [&](const Pattern &pattern) { + LLVM_DEBUG({ + logger.getOStream() << "\n"; + logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" + << op->getName() << " -> ("; + llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); + logger.getOStream() << ")' {\n"; + logger.indent(); + }); + return true; + }; + auto onFailure = [&](const Pattern &pattern) { + LLVM_DEBUG(logResult("failure", "pattern failed to match")); + }; + auto onSuccess = [&](const Pattern &pattern) { + LLVM_DEBUG(logResult("success", "pattern applied successfully")); + return success(); + }; #else - LogicalResult matchResult = matcher.matchAndRewrite(op, *this); + function_ref<bool(const Pattern &)> canApply = {}; + function_ref<void(const Pattern &)> onFailure = {}; +
function_ref<LogicalResult(const Pattern &)> onSuccess = {}; #endif - if (succeeded(matchResult)) { - changed = true; - ++numRewrites; - } + LogicalResult matchResult = + matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess); + + if (succeeded(matchResult)) { + LLVM_DEBUG(logResultWithLine("success", "pattern matched")); + changed = true; + ++numRewrites; + } else { + LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match")); + } } return changed;