diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -141,6 +141,13 @@
 
 set(MLIR_BUILD_MLIR_C_DYLIB 0 CACHE BOOL "Builds libMLIR-C shared library.")
 
+set(MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 0 CACHE BOOL
+  "Enable expensive checks to detect invalid pattern API usage. Requires assertions.")
+
+configure_file(
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Config/mlir-config.h.cmake
+  ${MLIR_INCLUDE_DIR}/mlir/Config/mlir-config.h)
+
 #-------------------------------------------------------------------------------
 # Python Bindings Configuration
 # Requires:
diff --git a/mlir/include/mlir/Config/mlir-config.h.cmake b/mlir/include/mlir/Config/mlir-config.h.cmake
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Config/mlir-config.h.cmake
@@ -0,0 +1,22 @@
+//===- mlir-config.h - MLIR configuration ------------------------*- C -*-===*//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+/* This file enumerates variables from the MLIR configuration so that they
+   can be in exported headers and won't override package specific directives.
+   This is a C header that can be included in the mlir-c headers. */
+
+#ifndef MLIR_CONFIG_H
+#define MLIR_CONFIG_H
+
+/* Enable expensive checks to detect invalid pattern API usage. Requires
+   assertions (i.e., NDEBUG must not be defined). Running with ASAN is
+   recommended for easier debugging. This macro is always defined (to 0 or 1),
+   so test it with `#if`, not `#ifdef`. */
+#cmakedefine01 MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+#endif
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Config/mlir-config.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Rewrite/PatternApplicator.h"
@@ -121,6 +122,16 @@
 
   /// The low-level pattern applicator.
   PatternApplicator matcher;
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  /// Operation finger prints to detect invalid pattern API usage. IR is checked
+  /// against these finger prints after pattern application to detect cases
+  /// where IR was modified directly, bypassing the rewriter API.
+  DenseMap<Operation *, OperationFingerPrint> fingerprints;
+
+  /// Top-level op of the finger-printed scope (nullptr if none or erased).
+  Operation *fingerprintTopLevel = nullptr;
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 };
 } // namespace
 
@@ -225,15 +236,69 @@
       return success();
     };
 
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+    // Compute finger prints of all ops in the current rewrite scope. They are
+    // discarded (and the top-level op reset) after this pattern application.
+    auto clearFingerprints = llvm::make_scope_exit([&]() {
+      fingerprints.clear();
+      fingerprintTopLevel = nullptr;
+    });
+    fingerprintTopLevel = config.scope ? config.scope->getParentOp() : op;
+    fingerprintTopLevel->walk([&](Operation *nestedOp) {
+      fingerprints.try_emplace(nestedOp, nestedOp);
+    });
+
+    // Finger print before applying a pattern. Copied rather than referenced:
+    // the map entry may be erased while the pattern runs.
+    OperationFingerPrint beforeFingerPrint =
+        fingerprints.find(fingerprintTopLevel)->second;
+
+    // Apply a pattern.
+    LogicalResult matchResult =
+        matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
+    if (succeeded(matchResult)) {
+      LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
+    } else {
+      LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
+    }
+
+    if (!fingerprintTopLevel) {
+      // The pattern erased the top-level op (see notifyOperationRemoved): the
+      // IR certainly changed and there is nothing left to finger print.
+      assert(succeeded(matchResult) &&
+             "pattern returned failure but IR did change");
+    } else if (succeeded(matchResult)) {
+      // Pattern application success => IR must have changed.
+      assert(beforeFingerPrint != OperationFingerPrint(fingerprintTopLevel) &&
+             "pattern returned success but IR did not change");
+      // Ops whose finger prints were not invalidated via the rewriter API must
+      // be unchanged; otherwise IR was modified behind the rewriter's back.
+      fingerprintTopLevel->walk([&](Operation *nestedOp) {
+        if (nestedOp == fingerprintTopLevel)
+          return;
+        auto it = fingerprints.find(nestedOp);
+        if (it != fingerprints.end()) {
+          assert(it->second == OperationFingerPrint(nestedOp) &&
+                 "operation finger print changed");
+        }
+      });
+    } else {
+      // Pattern application failure => IR must not have changed.
+      assert(beforeFingerPrint == OperationFingerPrint(fingerprintTopLevel) &&
+             "pattern returned failure but IR did change");
+    }
+#else
     LogicalResult matchResult =
         matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
-    if (succeeded(matchResult))
+    if (succeeded(matchResult)) {
       LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
-    else
+    } else {
       LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
+    }
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 #else
     LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
-#endif
+#endif // NDEBUG
 
     if (succeeded(matchResult)) {
       changed = true;
@@ -245,6 +310,7 @@
 }
 
 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
+  assert(op && "expected valid op");
   // Gather potential ancestors while looking for a "scope" parent region.
   SmallVector<Operation *, 8> ancestors;
   Region *region = nullptr;
@@ -303,6 +369,13 @@
     logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
                        << ")\n";
   });
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  // Invalidate the finger prints of the inserted op (its address may have been
+  // reused) and of its ancestors up to (but excluding) the top-level op.
+  for (Operation *fpOp = op; fpOp && fpOp != fingerprintTopLevel;
+       fpOp = fpOp->getParentOp())
+    fingerprints.erase(fpOp);
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   if (config.listener)
     config.listener->notifyOperationInserted(op);
   if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
@@ -317,6 +390,13 @@
   });
   if (config.listener)
     config.listener->notifyOperationModified(op);
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  // Invalidate the finger prints of the modified op and of its ancestors up to
+  // (but excluding) the top-level op.
+  for (Operation *fpOp = op; fpOp && fpOp != fingerprintTopLevel;
+       fpOp = fpOp->getParentOp())
+    fingerprints.erase(fpOp);
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   addToWorklist(op);
 }
 
@@ -350,6 +430,21 @@
 
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.erase(op);
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  // Invalidate the finger prints of the erased op and of all ops nested in it:
+  // their memory is about to be freed and may be reused by new ops. If the
+  // top-level op itself is erased, there is nothing left to check.
+  op->walk([&](Operation *nestedOp) {
+    if (nestedOp == fingerprintTopLevel)
+      fingerprintTopLevel = nullptr;
+    fingerprints.erase(nestedOp);
+  });
+  // Also invalidate the ancestors up to (but excluding) the top-level op.
+  for (Operation *fpOp = op->getParentOp();
+       fpOp && fpOp != fingerprintTopLevel; fpOp = fpOp->getParentOp())
+    fingerprints.erase(fpOp);
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 }
 
 void GreedyPatternRewriteDriver::notifyOperationReplaced(