diff --git a/llvm/include/llvm/Analysis/ConstraintSystem.h b/llvm/include/llvm/Analysis/ConstraintSystem.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Analysis/ConstraintSystem.h
@@ -0,0 +1,67 @@
+//===- ConstraintSystem.h - A system of linear constraints. --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_CONSTRAINTSYSTEM_H
+#define LLVM_ANALYSIS_CONSTRAINTSYSTEM_H
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <string>
+
+namespace llvm {
+
+class ConstraintSystem {
+  /// Current linear constraints in the system.
+  /// An entry of the form c0, c1, ... cn represents the following constraint:
+  ///   c0 >= v0 * c1 + .... + v{n-1} * cn
+  SmallVector<SmallVector<int64_t, 8>, 4> Constraints;
+
+  /// Current greatest common divisor for all coefficients in the system.
+  /// Note: it starts at 1 and is only ever updated to
+  /// GreatestCommonDivisor(X, GCD), so it currently always remains 1.
+  uint32_t GCD = 1;
+
+  /// Eliminate the variable in column 1 using Fourier–Motzkin elimination.
+  /// Returns false if computing a new coefficient overflows; the system is
+  /// left unchanged in that case.
+  bool eliminateUsingFM();
+
+  /// Print the constraints in the system, using \p Names as variable names.
+  void dump(ArrayRef<std::string> Names) const;
+
+  /// Print the constraints in the system, using x1...xn as variable names.
+  void dump() const;
+
+  /// Returns true if there may be a solution for the constraints in the system.
+  bool mayHaveSolutionImpl();
+
+public:
+  /// Add a constraint row, laid out as described for Constraints. All rows
+  /// must have the same number of entries.
+  void addVariableRow(const SmallVector<int64_t, 8> &R) {
+    assert(Constraints.empty() || R.size() == Constraints.back().size());
+    for (const auto &C : R) {
+      auto A = std::abs(C);
+      GCD = APIntOps::GreatestCommonDivisor({32, (uint32_t)A}, {32, GCD})
+                .getZExtValue();
+    }
+    Constraints.push_back(R);
+  }
+
+  /// Returns true if there may be a solution for the constraints in the system.
+  /// This is conservative: it also returns true when an overflow prevents the
+  /// system from being solved completely.
+  bool mayHaveSolution();
+};
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_CONSTRAINTSYSTEM_H
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -39,6 +39,7 @@
   CodeMetrics.cpp
   ConstantFolding.cpp
+  ConstraintSystem.cpp
   DDG.cpp
   Delinearization.cpp
   DemandedBits.cpp
   DependenceAnalysis.cpp
diff --git a/llvm/lib/Analysis/ConstraintSystem.cpp b/llvm/lib/Analysis/ConstraintSystem.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Analysis/ConstraintSystem.cpp
@@ -0,0 +1,158 @@
+//===- ConstraintSystem.cpp - A system of linear constraints. ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ConstraintSystem.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+
+#include <algorithm>
+#include <string>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "constraint-system"
+
+bool ConstraintSystem::eliminateUsingFM() {
+  // Implementation of Fourier–Motzkin elimination, with some tricks from the
+  // paper Pugh, William. "The Omega test: a fast and practical integer
+  // programming algorithm for dependence
+  // analysis."
+  // Supercomputing'91: Proceedings of the 1991 ACM/
+  // IEEE conference on Supercomputing. IEEE, 1991.
+  //
+  // The variable in column 1 is eliminated: rows with a zero coefficient in
+  // that column are kept (without the column) and every pair of rows with
+  // opposite-sign coefficients is combined so that the column cancels out.
+  assert(!Constraints.empty() &&
+         "should only be called for non-empty constraint systems");
+  unsigned NumVariables = Constraints[0].size();
+  SmallVector<SmallVector<int64_t, 8>, 4> NewSystem;
+
+  unsigned NumConstraints = Constraints.size();
+  uint32_t NewGCD = 1;
+  // FIXME do not use copy
+  for (unsigned R1 = 0; R1 < NumConstraints; R1++) {
+    if (Constraints[R1][1] == 0) {
+      SmallVector<int64_t, 8> NR;
+      NR.push_back(Constraints[R1][0]);
+      for (unsigned i = 2; i < NumVariables; i++) {
+        NR.push_back(Constraints[R1][i]);
+      }
+      NewSystem.push_back(std::move(NR));
+      continue;
+    }
+
+    // FIXME do not use copy
+    for (unsigned R2 = R1 + 1; R2 < NumConstraints; R2++) {
+      // FIXME: can we do better than just dropping things here?
+      if (Constraints[R2][1] == 0)
+        continue;
+
+      if ((Constraints[R1][1] < 0 && Constraints[R2][1] < 0) ||
+          (Constraints[R1][1] > 0 && Constraints[R2][1] > 0))
+        continue;
+
+      unsigned LowerR = R1;
+      unsigned UpperR = R2;
+      if (Constraints[UpperR][1] < 0)
+        std::swap(LowerR, UpperR);
+
+      // UpperR now has a positive and LowerR a negative coefficient in column
+      // 1; scale both rows by positive factors so that column 1 cancels. Use
+      // the overflow helpers from MathExtras.h instead of the GCC/Clang-only
+      // __builtin_*_overflow builtins so this also builds with e.g. MSVC.
+      int64_t LowerMult;
+      if (MulOverflow(Constraints[LowerR][1], int64_t(-1), LowerMult))
+        return false;
+      LowerMult /= GCD;
+      int64_t UpperMult = Constraints[UpperR][1] / GCD;
+
+      SmallVector<int64_t, 8> NR;
+      for (unsigned I = 0; I < NumVariables; I++) {
+        if (I == 1)
+          continue;
+
+        int64_t M1, M2, N;
+        if (MulOverflow(Constraints[UpperR][I], LowerMult, M1))
+          return false;
+        if (MulOverflow(Constraints[LowerR][I], UpperMult, M2))
+          return false;
+        if (AddOverflow(M1, M2, N))
+          return false;
+        NR.push_back(N);
+
+        NewGCD = APIntOps::GreatestCommonDivisor({32, (uint32_t)NR.back()},
+                                                 {32, NewGCD})
+                     .getZExtValue();
+      }
+      NewSystem.push_back(std::move(NR));
+    }
+  }
+  Constraints = std::move(NewSystem);
+  GCD = NewGCD;
+
+  return true;
+}
+
+bool ConstraintSystem::mayHaveSolutionImpl() {
+  while (!Constraints.empty() && Constraints[0].size() > 1) {
+    if (!eliminateUsingFM())
+      return true;
+  }
+
+  if (Constraints.empty() || Constraints[0].size() > 1)
+    return true;
+
+  return all_of(Constraints, [](auto &R) { return R[0] >= 0; });
+}
+
+void ConstraintSystem::dump(ArrayRef<std::string> Names) const {
+  if (Constraints.empty())
+    return;
+
+  for (auto &Row : Constraints) {
+    SmallVector<std::string, 16> Parts;
+    for (unsigned I = 1, S = Row.size(); I < S; ++I) {
+      if (Row[I] == 0)
+        continue;
+      std::string Coefficient = "";
+      if (Row[I] != 1)
+        Coefficient = std::to_string(Row[I]) + " * ";
+      Parts.push_back(Coefficient + Names[I - 1]);
+    }
+    // A row without non-zero variable coefficients is a constant constraint;
+    // print 0 as its left-hand side instead of asserting.
+    if (Parts.empty())
+      Parts.push_back("0");
+    LLVM_DEBUG(dbgs() << join(Parts, std::string(" + "))
+                      << " <= " << std::to_string(Row[0]) << "\n");
+  }
+}
+
+void ConstraintSystem::dump() const {
+  LLVM_DEBUG(dbgs() << "---\n");
+  // An empty system has no row to derive the variable names from.
+  if (Constraints.empty())
+    return;
+
+  SmallVector<std::string, 16> Names;
+  for (unsigned i = 1; i < Constraints.back().size(); ++i)
+    Names.push_back("x" + std::to_string(i));
+  dump(Names);
+}
+
+bool ConstraintSystem::mayHaveSolution() {
+  // Only build the textual dump when debug output is actually enabled.
+  LLVM_DEBUG(dump());
+  bool HasSolution = mayHaveSolutionImpl();
+  LLVM_DEBUG(dbgs() << (HasSolution ? "sat" : "unsat") << "\n");
+  return HasSolution;
+}
diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -23,6 +23,7 @@
   CaptureTrackingTest.cpp
   CFGTest.cpp
   CGSCCPassManagerTest.cpp
+  ConstraintSystemTest.cpp
   DDGTest.cpp
   DivergenceAnalysisTest.cpp
   DomTreeUpdaterTest.cpp
diff --git a/llvm/unittests/Analysis/ConstraintSystemTest.cpp b/llvm/unittests/Analysis/ConstraintSystemTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Analysis/ConstraintSystemTest.cpp
@@ -0,0 +1,96 @@
+//===--- ConstraintSystemTest.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ConstraintSystem.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+TEST(ConstraintSolverTest, TestSolutionChecks) {
+  {
+    ConstraintSystem CS;
+    // x + y <= 10, x >= 5, y >= 6, x <= 10, y <= 10
+    CS.addVariableRow({10, 1, 1});
+    CS.addVariableRow({-5, -1, 0});
+    CS.addVariableRow({-6, 0, -1});
+    CS.addVariableRow({10, 1, 0});
+    CS.addVariableRow({10, 0, 1});
+
+    EXPECT_FALSE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+    // x + y <= 10, x >= 2, y >= 3, x <= 10, y <= 10
+    CS.addVariableRow({10, 1, 1});
+    CS.addVariableRow({-2, -1, 0});
+    CS.addVariableRow({-3, 0, -1});
+    CS.addVariableRow({10, 1, 0});
+    CS.addVariableRow({10, 0, 1});
+
+    EXPECT_TRUE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+    // x + y <= 10, x >= 10, y >= 10; does not have a solution.
+    CS.addVariableRow({10, 1, 1});
+    CS.addVariableRow({-10, -1, 0});
+    CS.addVariableRow({-10, 0, -1});
+
+    EXPECT_FALSE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+    // x + y >= 20, x >= 10, y >= 10; does HAVE a solution.
+    CS.addVariableRow({-20, -1, -1});
+    CS.addVariableRow({-10, -1, 0});
+    CS.addVariableRow({-10, 0, -1});
+
+    EXPECT_TRUE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+
+    // 2x + y + 3z <= 10, 2x + y >= 10, z >= 1
+    CS.addVariableRow({10, 2, 1, 3});
+    CS.addVariableRow({-10, -2, -1, 0});
+    CS.addVariableRow({-1, 0, 0, -1});
+
+    EXPECT_FALSE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+
+    // 2x + y + 3z <= 10, 2x + y >= 10
+    CS.addVariableRow({10, 2, 1, 3});
+    CS.addVariableRow({-10, -2, -1, 0});
+
+    EXPECT_TRUE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+    // An empty system is trivially satisfiable.
+    EXPECT_TRUE(CS.mayHaveSolution());
+  }
+
+  {
+    ConstraintSystem CS;
+    // 0 <= -1 (no variable terms); does not have a solution.
+    CS.addVariableRow({-1, 0, 0});
+
+    EXPECT_FALSE(CS.mayHaveSolution());
+  }
+}
+} // namespace
diff --git a/llvm/utils/convert-constraint-log-to-z3.py b/llvm/utils/convert-constraint-log-to-z3.py
new file mode 100755
--- /dev/null
+++ b/llvm/utils/convert-constraint-log-to-z3.py
@@ -0,0 +1,74 @@
+#!/usr/bin/env python
+
+"""
+Helper script to convert the log generated by '-debug-only=constraint-system'
+to a Python script that uses Z3 to verify the decisions using Z3's Python API.
+
+Example usage:
+
+> cat path/to/file.log
+---
+x1 + -1 * x2 <= 0
+x2 + -1 * x3 <= 0
+-1 * x1 + x3 <= -1
+unsat
+
+> ./convert-constraint-log-to-z3.py path/to/file.log > check.py && python ./check.py
+
+> cat check.py
+from z3 import *
+x1 = Int("x1")
+x2 = Int("x2")
+x3 = Int("x3")
+s = Solver()
+s.add(x1 + -1 * x2 <= 0)
+s.add(x2 + -1 * x3 <= 0)
+s.add(-1 * x1 + x3 <= -1)
+assert(s.check() == unsat)
+print("all checks passed")
+"""
+
+
+import argparse
+import re
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert constraint log to script to verify using Z3.')
+    parser.add_argument('log_file', metavar='log', type=str,
+                        help='constraint-system log file')
+    args = parser.parse_args()
+
+    content = ''
+    with open(args.log_file, 'rt') as f:
+        content = f.read()
+
+    groups = content.split('---')
+    var_re = re.compile(r'x\d+')
+
+    print('from z3 import *')
+    for group in groups:
+        constraints = [g.strip() for g in group.split('\n') if g.strip() != '']
+        variables = set()
+        for c in constraints[:-1]:
+            for m in var_re.finditer(c):
+                variables.add(m.group())
+        if len(variables) == 0:
+            continue
+        expected = constraints[-1].strip()
+        if expected not in ('sat', 'unsat'):
+            raise ValueError(
+                'expected a group to end with sat or unsat, got: {}'.format(expected))
+        # Sort so that the generated script is deterministic.
+        for v in sorted(variables):
+            print('{} = Int("{}")'.format(v, v))
+        print('s = Solver()')
+        for c in constraints[:-1]:
+            print('s.add({})'.format(c))
+        print('assert(s.check() == {})'.format(expected))
+    print('print("all checks passed")')
+
+
+if __name__ == '__main__':
+    main()