diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -141,6 +141,10 @@
 
 set(MLIR_BUILD_MLIR_C_DYLIB 0 CACHE BOOL "Builds libMLIR-C shared library.")
 
+configure_file(
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Config/mlir-config.h.cmake
+  ${MLIR_INCLUDE_DIR}/mlir/Config/mlir-config.h)
+
 #-------------------------------------------------------------------------------
 # Python Bindings Configuration
 # Requires:
diff --git a/mlir/include/mlir/Config/mlir-config.h.cmake b/mlir/include/mlir/Config/mlir-config.h.cmake
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Config/mlir-config.h.cmake
@@ -0,0 +1,22 @@
+//===- mlir-config.h - MLIR configuration ------------------------*- C -*-===*//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+/* This file enumerates variables from the MLIR configuration so that they
+   can be in exported headers and won't override package specific directives.
+   This is a C header that can be included in the mlir-c headers. */
+
+#ifndef MLIR_CONFIG_H
+#define MLIR_CONFIG_H
+
+/* Enable expensive checks to detect invalid pattern API usage. The checks are
+   assertions, so they only fire when NDEBUG is not defined. Running with ASAN
+   is recommended for easier debugging. This macro is always defined (to 0 or
+   1); test it with #if, not #ifdef. */
+#cmakedefine01 MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+#endif
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Config/mlir-config.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Rewrite/PatternApplicator.h"
@@ -121,6 +122,19 @@
 
   /// The low-level pattern applicator.
   PatternApplicator matcher;
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  /// Operation finger prints to detect invalid pattern API usage. IR is checked
+  /// against these finger prints after pattern application to detect cases
+  /// where IR was modified directly, bypassing the rewriter API.
+  DenseMap<Operation *, OperationFingerPrint> fingerprints;
+
+  /// Top-level operation of the current greedy rewrite.
+  Operation *fingerprintTopLevel = nullptr;
+
+  /// Drop the stored finger prints of `op` and of its ancestors below the top level.
+  void invalidateFingerPrint(Operation *op);
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 };
 } // namespace
 
@@ -225,15 +239,59 @@
       return success();
     };
 
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+    // Compute finger prints of all ops in the current rewrite scope.
+    auto clearFingerprints =
+        llvm::make_scope_exit([&]() { fingerprints.clear(); });
+    fingerprintTopLevel = config.scope ? config.scope->getParentOp() : op;
+    fingerprintTopLevel->walk(
+        [&](Operation *op) { fingerprints.try_emplace(op, op); });
+
+    // Finger print before applying a pattern.
+    OperationFingerPrint &beforeFingerPrint =
+        fingerprints.find(fingerprintTopLevel)->second;
+
+    // Apply a pattern.
+    LogicalResult matchResult =
+        matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
+
+    // Finger print after applying a pattern.
+    OperationFingerPrint afterFingerPrint(fingerprintTopLevel);
+
+    if (succeeded(matchResult)) {
+      // Pattern application success => IR must have changed.
+      LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
+      assert(beforeFingerPrint != afterFingerPrint &&
+             "pattern returned success but IR did not change");
+      for (const auto &it : fingerprints) {
+        // Skip top-level op, its finger print is never invalidated.
+        if (it.first == fingerprintTopLevel)
+          continue;
+        // Note: Finger print computation may crash when an op was erased
+        // without notifying the rewriter. (Run with ASAN.) It does not crash
+        // if a new op was created at the same memory location. (But then the
+        // finger print should have changed.)
+        assert(it.second == OperationFingerPrint(it.first) &&
+               "operation finger print changed");
+      }
+    } else {
+      // Pattern application failure => IR must not have changed.
+      LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
+      assert(beforeFingerPrint == afterFingerPrint &&
+             "pattern returned failure but IR did change");
+    }
+#else
     LogicalResult matchResult =
         matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
-    if (succeeded(matchResult))
+    if (succeeded(matchResult)) {
       LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
-    else
+    } else {
       LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
+    }
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 #else
     LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
-#endif
+#endif // NDEBUG
 
     if (succeeded(matchResult)) {
       changed = true;
@@ -244,7 +302,18 @@
   return changed;
 }
 
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+void GreedyPatternRewriteDriver::invalidateFingerPrint(Operation *op) {
+  // Invalidate all finger prints until the top level.
+  while (op && op != fingerprintTopLevel) {
+    fingerprints.erase(op);
+    op = op->getParentOp();
+  }
+}
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
+  assert(op && "expected valid op");
   // Gather potential ancestors while looking for a "scope" parent region.
   SmallVector<Operation *, 8> ancestors;
   Region *region = nullptr;
@@ -303,6 +372,11 @@
     logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
                        << ")\n";
   });
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  invalidateFingerPrint(op->getParentOp());
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   if (config.listener)
     config.listener->notifyOperationInserted(op);
   if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
@@ -317,6 +391,11 @@
   });
   if (config.listener)
     config.listener->notifyOperationModified(op);
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  invalidateFingerPrint(op);
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   addToWorklist(op);
 }
 
@@ -346,6 +425,9 @@
   op->walk([this](Operation *operation) {
     removeFromWorklist(operation);
     folder.notifyRemoval(operation);
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+    invalidateFingerPrint(operation);
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   });
 
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)