Index: include/clang/StaticAnalyzer/Core/PathSensitive/SMTSolver.h
===================================================================
--- /dev/null
+++ include/clang/StaticAnalyzer/Core/PathSensitive/SMTSolver.h
@@ -0,0 +1,290 @@
+//== SMTSolver.h ------------------------------------------------*- C++ -*--==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines a generic SMT solver API, which is the base class of
+//  every solver-specific implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSOLVER_H
+#define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSOLVER_H
+
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTContext.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTExpr.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTSort.h"
+
+namespace clang {
+namespace ento {
+
+/// Generic interface to an SMT solver. A concrete backend (such as Z3Solver
+/// in Z3ConstraintManager.cpp) implements the sort and expression factories
+/// together with the assertion, satisfiability-check and model-query
+/// operations declared below.
+///
+/// NOTE(review): sorts and expressions are returned by value as the base
+/// classes SMTSort and SMTExpr. An override that returns a derived object (as
+/// Z3Solver does with Z3Sort/Z3Expr) either fails to compile, if the base is
+/// abstract, or slices the object, making any later downcast of the result
+/// (e.g. via toZ3Expr) undefined behavior. FIXME: return owning handles
+/// (e.g. std::shared_ptr<SMTExpr>) instead.
+class SMTSolver {
+public:
+  SMTSolver() = default;
+  virtual ~SMTSolver() = default;
+
+  /// Print the solver state to llvm::errs(); a debugging aid.
+  LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); }
+
+  /// Return a boolean sort.
+  virtual SMTSort getBoolSort() = 0;
+
+  /// Return an appropriate bitvector sort for the given bitwidth.
+  virtual SMTSort getBitvectorSort(unsigned BitWidth) = 0;
+
+  /// Return a floating-point sort of width 16.
+  virtual SMTSort getFloat16Sort() = 0;
+
+  /// Return a floating-point sort of width 32.
+  virtual SMTSort getFloat32Sort() = 0;
+
+  /// Return a floating-point sort of width 64.
+  virtual SMTSort getFloat64Sort() = 0;
+
+  /// Return a floating-point sort of width 128.
+  virtual SMTSort getFloat128Sort() = 0;
+
+  /// Return the sort of the given expression.
+  virtual SMTSort getSort(const SMTExpr &AST) = 0;
+
+  /// Given a constraint, add it to the solver.
+  virtual void addConstraint(const SMTExpr &Exp) = 0;
+
+  /// Create a bitvector addition operation.
+  virtual SMTExpr mkBVAdd(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector subtraction operation.
+  virtual SMTExpr mkBVSub(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector multiplication operation.
+  virtual SMTExpr mkBVMul(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector signed remainder operation.
+  virtual SMTExpr mkBVSRem(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector unsigned remainder operation.
+  virtual SMTExpr mkBVURem(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector signed division operation.
+  virtual SMTExpr mkBVSDiv(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector unsigned division operation.
+  virtual SMTExpr mkBVUDiv(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector shift left operation.
+  virtual SMTExpr mkBVShl(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector arithmetic shift right operation.
+  virtual SMTExpr mkBVAshr(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector logical shift right operation.
+  virtual SMTExpr mkBVLshr(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector negation operation.
+  virtual SMTExpr mkBVNeg(const SMTExpr &Exp) = 0;
+
+  /// Create a bitvector not operation.
+  virtual SMTExpr mkBVNot(const SMTExpr &Exp) = 0;
+
+  /// Create a bitvector xor operation.
+  virtual SMTExpr mkBVXor(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector or operation.
+  virtual SMTExpr mkBVOr(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector and operation.
+  virtual SMTExpr mkBVAnd(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector unsigned less-than operation.
+  virtual SMTExpr mkBVUlt(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector signed less-than operation.
+  virtual SMTExpr mkBVSlt(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector unsigned greater-than operation.
+  virtual SMTExpr mkBVUgt(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector signed greater-than operation.
+  virtual SMTExpr mkBVSgt(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector unsigned less-equal-than operation.
+  virtual SMTExpr mkBVUle(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector signed less-equal-than operation.
+  virtual SMTExpr mkBVSle(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector unsigned greater-equal-than operation.
+  virtual SMTExpr mkBVUge(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a bitvector signed greater-equal-than operation.
+  virtual SMTExpr mkBVSge(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a boolean not operation.
+  virtual SMTExpr mkNot(const SMTExpr &Exp) = 0;
+
+  /// Create an equality operation.
+  virtual SMTExpr mkEqual(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a boolean and operation.
+  virtual SMTExpr mkAnd(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a boolean or operation.
+  virtual SMTExpr mkOr(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create an if-then-else operation: T if Cond holds, F otherwise.
+  virtual SMTExpr mkIte(const SMTExpr &Cond, const SMTExpr &T,
+                        const SMTExpr &F) = 0;
+
+  /// Create a sign extension of the bitvector Exp by i bits.
+  virtual SMTExpr mkSignExt(unsigned i, const SMTExpr &Exp) = 0;
+
+  /// Create a zero extension of the bitvector Exp by i bits.
+  virtual SMTExpr mkZeroExt(unsigned i, const SMTExpr &Exp) = 0;
+
+  /// Create an extraction of bits High down to Low (inclusive) of Exp.
+  virtual SMTExpr mkExtract(unsigned High, unsigned Low,
+                            const SMTExpr &Exp) = 0;
+
+  /// Create a concatenation of the bitvectors LHS and RHS.
+  virtual SMTExpr mkConcat(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a floating-point negation operation.
+  virtual SMTExpr mkFPNeg(const SMTExpr &Exp) = 0;
+
+  /// Create a predicate that checks whether Exp is infinite.
+  virtual SMTExpr mkFPIsInfinite(const SMTExpr &Exp) = 0;
+
+  /// Create a predicate that checks whether Exp is NaN.
+  virtual SMTExpr mkFPIsNaN(const SMTExpr &Exp) = 0;
+
+  /// Create a predicate that checks whether Exp is a normal number.
+  virtual SMTExpr mkFPIsNormal(const SMTExpr &Exp) = 0;
+
+  /// Create a predicate that checks whether Exp is zero.
+  virtual SMTExpr mkFPIsZero(const SMTExpr &Exp) = 0;
+
+  /// Create a floating-point multiplication operation.
+  virtual SMTExpr mkFPMul(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a floating-point division operation.
+  virtual SMTExpr mkFPDiv(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a floating-point remainder operation.
+  virtual SMTExpr mkFPRem(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a floating-point addition operation.
+  virtual SMTExpr mkFPAdd(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a floating-point subtraction operation.
+  virtual SMTExpr mkFPSub(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a floating-point less-than operation.
+  virtual SMTExpr mkFPLt(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a floating-point greater-than operation.
+  virtual SMTExpr mkFPGt(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a floating-point less-equal-than operation.
+  virtual SMTExpr mkFPLe(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a floating-point greater-equal-than operation.
+  virtual SMTExpr mkFPGe(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Create a floating-point equality operation.
+  virtual SMTExpr mkFPEqual(const SMTExpr &LHS, const SMTExpr &RHS) = 0;
+
+  /// Convert the floating-point expression From to the floating-point sort To.
+  virtual SMTExpr mkFPtoFP(const SMTExpr &From, const SMTSort &To) = 0;
+
+  // NOTE(review): the FP->BV conversions below take a target sort while the
+  // BV->FP conversions take a target width. The natural parameterization is
+  // the other way round (an FP->BV conversion needs a bitvector width, a
+  // BV->FP conversion a floating-point sort); confirm the intended contract.
+
+  /// Convert the floating-point expression From to a signed bitvector.
+  virtual SMTExpr mkFPtoSBV(const SMTExpr &From, const SMTSort &To) = 0;
+
+  /// Convert the floating-point expression From to an unsigned bitvector.
+  virtual SMTExpr mkFPtoUBV(const SMTExpr &From, const SMTSort &To) = 0;
+
+  /// Convert the signed bitvector From to a floating-point value.
+  virtual SMTExpr mkSBVtoFP(const SMTExpr &From, unsigned ToWidth) = 0;
+
+  /// Convert the unsigned bitvector From to a floating-point value.
+  virtual SMTExpr mkUBVtoFP(const SMTExpr &From, unsigned ToWidth) = 0;
+
+  /// Create a constant named Name of the given sort.
+  virtual SMTExpr mkSymbol(const char *Name, SMTSort Sort) = 0;
+
+  /// Return an appropriate floating-point rounding mode.
+  virtual SMTExpr getFloatRoundingMode() = 0;
+
+  /// Return the value of the bitvector numeral Exp.
+  virtual const llvm::APSInt getBitvector(const SMTExpr &Exp) = 0;
+
+  /// Return the value of the boolean constant Exp.
+  virtual bool getBoolean(const SMTExpr &Exp) = 0;
+
+  /// Construct an SMTExpr from a boolean.
+  virtual SMTExpr mkBoolean(const bool b) = 0;
+
+  /// Construct an SMTExpr from a finite APFloat.
+  virtual SMTExpr mkFloat(const llvm::APFloat Float) = 0;
+
+  /// Construct an SMTExpr from an APSInt, as a bitvector of width BitWidth.
+  virtual SMTExpr mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0;
+
+  /// Construct an SMTExpr from an APSInt, using the APSInt's own bit width.
+  SMTExpr mkBitvector(const llvm::APSInt Int) {
+    return mkBitvector(Int, Int.getBitWidth());
+  }
+
+  /// Given an expression, extract the value of this operand in the model.
+  /// Returns false if no value could be extracted.
+  virtual bool getInterpretation(const SMTExpr &Exp, llvm::APSInt &Int) = 0;
+
+  /// Given an expression, extract the value of this operand in the model.
+  /// Returns false if no value could be extracted.
+  virtual bool getInterpretation(const SMTExpr &Exp, llvm::APFloat &Float) = 0;
+
+  /// Construct an SMTExpr from a SymbolData, given its ID, type and bit width.
+  virtual SMTExpr fromData(const SymbolID ID, const QualType &Ty,
+                           uint64_t BitWidth) = 0;
+
+  /// Check if the constraints are satisfiable.
+  /// NOTE(review): a bool cannot represent an inconclusive ("unknown") solver
+  /// result; consider a tri-state return type.
+  virtual bool check() const = 0;
+
+  /// Push the current solver state.
+  virtual void push() = 0;
+
+  /// Pop the previous solver state.
+  virtual void pop(unsigned NumStates = 1) = 0;
+
+  /// Reset the solver and remove all constraints.
+  virtual void reset() const = 0;
+
+  /// Print a textual representation of the solver state to OS.
+  virtual void print(raw_ostream &OS) const = 0;
+};
+
+} // namespace ento
+} // namespace clang
+
+#endif
Index: lib/StaticAnalyzer/Core/Z3ConstraintManager.cpp
===================================================================
--- lib/StaticAnalyzer/Core/Z3ConstraintManager.cpp
+++ lib/StaticAnalyzer/Core/Z3ConstraintManager.cpp
@@ -13,6 +13,7 @@
 #include "clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/SMTContext.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/SMTExpr.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTSolver.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/SMTSort.h"
 
 #include "clang/Config/config.h"
@@ -49,7 +50,9 @@
 namespace {
 
 class Z3Expr;
+class Z3Sort;
 static const Z3Expr &toZ3Expr(const SMTExpr &E);
+static const Z3Sort &toZ3Sort(const SMTSort &S);
 
 class Z3Config {
   friend class Z3Context;
@@ -171,6 +174,10 @@
   }
 }; // end class Z3Sort
 
+static const Z3Sort &toZ3Sort(const SMTSort &S) {
+  return static_cast<const Z3Sort &>(S);
+}
+
 class Z3Expr : public SMTExpr {
   friend class Z3Solver;
 
@@ -314,25 +321,27 @@
           llvm::APFloat::semanticsSizeInBits(RHS));
 }
 
-class Z3Solver {
+class Z3Solver : public SMTSolver {
   friend class Z3ConstraintManager;
 
   Z3Context Context;
 
   Z3_solver Solver;
 
-  Z3Solver() : Solver(Z3_mk_simple_solver(Context.Context)) {
+  Z3Solver() : SMTSolver(), Solver(Z3_mk_simple_solver(Context.Context)) {
     Z3_solver_inc_ref(Context.Context, Solver);
   }
 
 public:
   /// Override implicit copy constructor for correct reference counting.
-  Z3Solver(const Z3Solver &Copy) : Context(Copy.Context), Solver(Copy.Solver) {
+  Z3Solver(const Z3Solver &Copy)
+      : SMTSolver(), Context(Copy.Context), Solver(Copy.Solver) {
     Z3_solver_inc_ref(Context.Context, Solver);
   }
 
   /// Provide move constructor
-  Z3Solver(Z3Solver &&Move) : Context(Move.Context), Solver(nullptr) {
+  Z3Solver(Z3Solver &&Move)
+      : SMTSolver(), Context(Move.Context), Solver(nullptr) {
     *this = std::move(Move);
   }
 
@@ -352,47 +361,36 @@
     Z3_solver_dec_ref(Context.Context, Solver);
   }
 
-  /// Given a constraint, add it to the solver
-  void addConstraint(const Z3Expr &Exp) {
-    Z3_solver_assert(Context.Context, Solver, Exp.AST);
+  void addConstraint(const SMTExpr &Exp) override {
+    Z3_solver_assert(Context.Context, Solver, toZ3Expr(Exp).AST);
   }
 
-  // Return a boolean sort.
-  Z3Sort getBoolSort() {
+  SMTSort getBoolSort() override {
     return Z3Sort(Context, Z3_mk_bool_sort(Context.Context));
   }
 
-  // Return an appropriate bitvector sort for the given bitwidth.
-  Z3Sort getBitvectorSort(unsigned BitWidth) {
+  SMTSort getBitvectorSort(unsigned BitWidth) override {
     return Z3Sort(Context, Z3_mk_bv_sort(Context.Context, BitWidth));
   }
 
   // Return an appropriate floating-point sort for the given bitwidth.
-  Z3Sort getFloatSort(unsigned BitWidth) {
-    Z3_sort Sort;
-
+  SMTSort getFloatSort(unsigned BitWidth) {
     switch (BitWidth) {
-    default:
-      llvm_unreachable("Unsupported floating-point bitwidth!");
-      break;
     case 16:
-      Sort = Z3_mk_fpa_sort_16(Context.Context);
-      break;
+      return getFloat16Sort();
     case 32:
-      Sort = Z3_mk_fpa_sort_32(Context.Context);
-      break;
+      return getFloat32Sort();
     case 64:
-      Sort = Z3_mk_fpa_sort_64(Context.Context);
-      break;
+      return getFloat64Sort();
     case 128:
-      Sort = Z3_mk_fpa_sort_128(Context.Context);
-      break;
+      return getFloat128Sort();
+    default:;
     }
-    return Z3Sort(Context, Sort);
+    llvm_unreachable("Unsupported floating-point bitwidth!");
   }
 
   // Return an appropriate sort, given a QualType
-  Z3Sort MkSort(const QualType &Ty, unsigned BitWidth) {
+  SMTSort mkSort(const QualType &Ty, unsigned BitWidth) {
     if (Ty->isBooleanType())
       return getBoolSort();
 
@@ -402,9 +400,359 @@
     return getBitvectorSort(BitWidth);
   }
 
-  // Return an appropriate sort for the given AST.
-  Z3Sort getSort(Z3_ast AST) {
-    return Z3Sort(Context, Z3_get_sort(Context.Context, AST));
+  SMTSort getSort(const SMTExpr &Exp) override {
+    return Z3Sort(Context, Z3_get_sort(Context.Context, toZ3Expr(Exp).AST));
+  }
+
+  SMTSort getFloat16Sort() override {
+    return Z3Sort(Context, Z3_mk_fpa_sort_16(Context.Context));
+  }
+
+  SMTSort getFloat32Sort() override {
+    return Z3Sort(Context, Z3_mk_fpa_sort_32(Context.Context));
+  }
+
+  SMTSort getFloat64Sort() override {
+    return Z3Sort(Context, Z3_mk_fpa_sort_64(Context.Context));
+  }
+
+  SMTSort getFloat128Sort() override {
+    return Z3Sort(Context, Z3_mk_fpa_sort_128(Context.Context));
+  }
+
+  SMTExpr mkBVNeg(const SMTExpr &Exp) override {
+    return Z3Expr(Context, Z3_mk_bvneg(Context.Context, toZ3Expr(Exp).AST));
+  }
+
+  SMTExpr mkBVNot(const SMTExpr &Exp) override {
+    return Z3Expr(Context, Z3_mk_bvnot(Context.Context, toZ3Expr(Exp).AST));
+  }
+
+  SMTExpr mkNot(const SMTExpr &Exp) override {
+    return Z3Expr(Context, Z3_mk_not(Context.Context, toZ3Expr(Exp).AST));
+  }
+
+  SMTExpr mkBVAdd(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvadd(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVSub(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvsub(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVMul(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvmul(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVSRem(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvsrem(Context.Context, toZ3Expr(LHS).AST,
+                                        toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVURem(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvurem(Context.Context, toZ3Expr(LHS).AST,
+                                        toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVSDiv(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvsdiv(Context.Context, toZ3Expr(LHS).AST,
+                                        toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVUDiv(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvudiv(Context.Context, toZ3Expr(LHS).AST,
+                                        toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVShl(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvshl(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVAshr(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvashr(Context.Context, toZ3Expr(LHS).AST,
+                                        toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVLshr(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvlshr(Context.Context, toZ3Expr(LHS).AST,
+                                        toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVXor(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvxor(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVOr(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvor(Context.Context, toZ3Expr(LHS).AST,
+                                      toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVAnd(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvand(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVUlt(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvult(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVSlt(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvslt(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVUgt(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvugt(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVSgt(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvsgt(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVUle(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvule(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVSle(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvsle(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVUge(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvuge(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkBVSge(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_bvsge(Context.Context, toZ3Expr(LHS).AST,
+                                       toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkAnd(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    Z3_ast Args[2] = {toZ3Expr(LHS).AST, toZ3Expr(RHS).AST};
+    return Z3Expr(Context, Z3_mk_and(Context.Context, 2, Args));
+  }
+
+  SMTExpr mkOr(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    Z3_ast Args[2] = {toZ3Expr(LHS).AST, toZ3Expr(RHS).AST};
+    return Z3Expr(Context, Z3_mk_or(Context.Context, 2, Args));
+  }
+
+  SMTExpr mkEqual(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_eq(Context.Context, toZ3Expr(LHS).AST,
+                                    toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkFPNeg(const SMTExpr &Exp) override {
+    return Z3Expr(Context, Z3_mk_fpa_neg(Context.Context, toZ3Expr(Exp).AST));
+  }
+
+  SMTExpr mkFPIsInfinite(const SMTExpr &Exp) override {
+    return Z3Expr(Context,
+                  Z3_mk_fpa_is_infinite(Context.Context, toZ3Expr(Exp).AST));
+  }
+
+  SMTExpr mkFPIsNaN(const SMTExpr &Exp) override {
+    return Z3Expr(Context,
+                  Z3_mk_fpa_is_nan(Context.Context, toZ3Expr(Exp).AST));
+  }
+
+  SMTExpr mkFPIsNormal(const SMTExpr &Exp) override {
+    return Z3Expr(Context,
+                  Z3_mk_fpa_is_normal(Context.Context, toZ3Expr(Exp).AST));
+  }
+
+  SMTExpr mkFPIsZero(const SMTExpr &Exp) override {
+    return Z3Expr(Context,
+                  Z3_mk_fpa_is_zero(Context.Context, toZ3Expr(Exp).AST));
+  }
+
+  SMTExpr mkFPMul(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    SMTExpr RoundingMode = getFloatRoundingMode();
+    // Z3_mk_fpa_mul takes the rounding mode as its first operand.
+    return Z3Expr(Context,
+                  Z3_mk_fpa_mul(Context.Context, toZ3Expr(RoundingMode).AST,
+                                toZ3Expr(LHS).AST, toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkFPDiv(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    SMTExpr RoundingMode = getFloatRoundingMode();
+    // Z3_mk_fpa_div takes the rounding mode as its first operand.
+    return Z3Expr(Context,
+                  Z3_mk_fpa_div(Context.Context, toZ3Expr(RoundingMode).AST,
+                                toZ3Expr(LHS).AST, toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkFPRem(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_fpa_rem(Context.Context, toZ3Expr(LHS).AST,
+                                         toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkFPAdd(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    SMTExpr RoundingMode = getFloatRoundingMode();
+    // Z3_mk_fpa_add takes the rounding mode as its first operand.
+    return Z3Expr(Context,
+                  Z3_mk_fpa_add(Context.Context, toZ3Expr(RoundingMode).AST,
+                                toZ3Expr(LHS).AST, toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkFPSub(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    SMTExpr RoundingMode = getFloatRoundingMode();
+    // Z3_mk_fpa_sub takes the rounding mode as its first operand.
+    return Z3Expr(Context,
+                  Z3_mk_fpa_sub(Context.Context, toZ3Expr(RoundingMode).AST,
+                                toZ3Expr(LHS).AST, toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkFPLt(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_fpa_lt(Context.Context, toZ3Expr(LHS).AST,
+                                        toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkFPGt(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_fpa_gt(Context.Context, toZ3Expr(LHS).AST,
+                                        toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkFPLe(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_fpa_leq(Context.Context, toZ3Expr(LHS).AST,
+                                         toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkFPGe(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_fpa_geq(Context.Context, toZ3Expr(LHS).AST,
+                                         toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkFPEqual(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_fpa_eq(Context.Context, toZ3Expr(LHS).AST,
+                                        toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkIte(const SMTExpr &Cond, const SMTExpr &T,
+                const SMTExpr &F) override {
+    return Z3Expr(Context, Z3_mk_ite(Context.Context, toZ3Expr(Cond).AST,
+                                     toZ3Expr(T).AST, toZ3Expr(F).AST));
+  }
+
+  SMTExpr mkSignExt(unsigned i, const SMTExpr &Exp) override {
+    return Z3Expr(Context,
+                  Z3_mk_sign_ext(Context.Context, i, toZ3Expr(Exp).AST));
+  }
+
+  SMTExpr mkZeroExt(unsigned i, const SMTExpr &Exp) override {
+    return Z3Expr(Context,
+                  Z3_mk_zero_ext(Context.Context, i, toZ3Expr(Exp).AST));
+  }
+
+  SMTExpr mkExtract(unsigned High, unsigned Low, const SMTExpr &Exp) override {
+    return Z3Expr(Context,
+                  Z3_mk_extract(Context.Context, High, Low, toZ3Expr(Exp).AST));
+  }
+
+  SMTExpr mkConcat(const SMTExpr &LHS, const SMTExpr &RHS) override {
+    return Z3Expr(Context, Z3_mk_concat(Context.Context, toZ3Expr(LHS).AST,
+                                        toZ3Expr(RHS).AST));
+  }
+
+  SMTExpr mkFPtoFP(const SMTExpr &From, const SMTSort &To) override {
+    SMTExpr RoundingMode = getFloatRoundingMode();
+    return Z3Expr(Context, Z3_mk_fpa_to_fp_float(
+                               Context.Context, toZ3Expr(RoundingMode).AST,
+                               toZ3Expr(From).AST, toZ3Sort(To).Sort));
+  }
+
+  SMTExpr mkFPtoSBV(const SMTExpr &From, const SMTSort &To) override {
+    // Floating-point to signed bitvector; the target width is taken from the
+    // bitvector sort To.
+    SMTExpr RoundingMode = getFloatRoundingMode();
+    return Z3Expr(Context,
+                  Z3_mk_fpa_to_sbv(Context.Context, toZ3Expr(RoundingMode).AST,
+                                   toZ3Expr(From).AST,
+                                   To.getBitvectorSortSize()));
+  }
+
+  SMTExpr mkFPtoUBV(const SMTExpr &From, const SMTSort &To) override {
+    // Floating-point to unsigned bitvector; the target width is taken from
+    // the bitvector sort To.
+    SMTExpr RoundingMode = getFloatRoundingMode();
+    return Z3Expr(Context,
+                  Z3_mk_fpa_to_ubv(Context.Context, toZ3Expr(RoundingMode).AST,
+                                   toZ3Expr(From).AST,
+                                   To.getBitvectorSortSize()));
+  }
+
+  SMTExpr mkSBVtoFP(const SMTExpr &From, unsigned ToWidth) override {
+    // Signed bitvector to floating-point; the target sort is the
+    // floating-point sort of width ToWidth.
+    SMTExpr RoundingMode = getFloatRoundingMode();
+    SMTSort Sort = getFloatSort(ToWidth);
+    return Z3Expr(Context, Z3_mk_fpa_to_fp_signed(
+                               Context.Context, toZ3Expr(RoundingMode).AST,
+                               toZ3Expr(From).AST, toZ3Sort(Sort).Sort));
+  }
+
+  SMTExpr mkUBVtoFP(const SMTExpr &From, unsigned ToWidth) override {
+    // Unsigned bitvector to floating-point; the target sort is the
+    // floating-point sort of width ToWidth.
+    SMTExpr RoundingMode = getFloatRoundingMode();
+    SMTSort Sort = getFloatSort(ToWidth);
+    return Z3Expr(Context, Z3_mk_fpa_to_fp_unsigned(
+                               Context.Context, toZ3Expr(RoundingMode).AST,
+                               toZ3Expr(From).AST, toZ3Sort(Sort).Sort));
+  }
+
+  SMTExpr mkBoolean(const bool b) override {
+    return Z3Expr(Context, b ? Z3_mk_true(Context.Context)
+                             : Z3_mk_false(Context.Context));
+  }
+
+  // Keep SMTSolver::mkBitvector(APSInt) visible despite the override below.
+  using SMTSolver::mkBitvector;
+
+  SMTExpr mkBitvector(const llvm::APSInt Int, unsigned BitWidth) override {
+    const SMTSort Sort = getBitvectorSort(BitWidth);
+    return Z3Expr(Context,
+                  Z3_mk_numeral(Context.Context, Int.toString(10).c_str(),
+                                toZ3Sort(Sort).Sort));
+  }
+
+  SMTExpr mkFloat(const llvm::APFloat Float) override {
+    SMTSort Sort =
+        getFloatSort(llvm::APFloat::semanticsSizeInBits(Float.getSemantics()));
+
+    llvm::APSInt Int = llvm::APSInt(Float.bitcastToAPInt(), false);
+    SMTExpr Z3Int = mkBitvector(Int, Int.getBitWidth());
+    return Z3Expr(Context,
+                  Z3_mk_fpa_to_fp_bv(Context.Context, toZ3Expr(Z3Int).AST,
+                                     toZ3Sort(Sort).Sort));
+  }
+
+  SMTExpr mkSymbol(const char *Name, SMTSort Sort) override {
+    return Z3Expr(Context,
+                  Z3_mk_const(Context.Context,
+                              Z3_mk_string_symbol(Context.Context, Name),
+                              toZ3Sort(Sort).Sort));
+  }
+
+  const llvm::APSInt getBitvector(const SMTExpr &Exp) override {
+    // NOTE: the APSInt parsed from Z3's decimal numeral string picks its own
+    // bit width, so callers must resize the result to the width they expect.
+    return llvm::APSInt(
+        Z3_get_numeral_string(Context.Context, toZ3Expr(Exp).AST));
+  }
+
+  bool getBoolean(const SMTExpr &Exp) override {
+    return Z3_get_bool_value(Context.Context, toZ3Expr(Exp).AST) == Z3_L_TRUE;
   }
 
   /// Given a program state, construct the logical conjunction and add it to
@@ -427,7 +775,7 @@
   }
 
   // Return an appropriate floating-point rounding mode.
-  Z3Expr getFloatRoundingMode() {
+  SMTExpr getFloatRoundingMode() override {
     // TODO: Don't assume nearest ties to even rounding mode
     return Z3Expr(Context, Z3_mk_fpa_rne(Context.Context));
   }
@@ -505,8 +833,7 @@
                    const Z3Expr &RHS, bool isSigned) {
     Z3_ast AST;
 
-    assert(getSort(LHS.AST) == getSort(RHS.AST) &&
-           "AST's must have the same sort!");
+    assert(getSort(LHS) == getSort(RHS) && "AST's must have the same sort!");
 
     switch (Op) {
     default:
@@ -634,8 +961,7 @@
                         const Z3Expr &RHS) {
     Z3_ast AST;
 
-    assert(getSort(LHS.AST) == getSort(RHS.AST) &&
-           "AST's must have the same sort!");
+    assert(getSort(LHS) == getSort(RHS) && "AST's must have the same sort!");
 
     switch (Op) {
     default:
@@ -644,13 +970,15 @@
 
     // Multiplicative operators
     case BO_Mul: {
-      Z3Expr RoundingMode = getFloatRoundingMode();
-      AST = Z3_mk_fpa_mul(Context.Context, RoundingMode.AST, LHS.AST, RHS.AST);
+      SMTExpr RoundingMode = getFloatRoundingMode();
+      AST = Z3_mk_fpa_mul(Context.Context, toZ3Expr(RoundingMode).AST, LHS.AST,
+                          RHS.AST);
       break;
     }
     case BO_Div: {
-      Z3Expr RoundingMode = getFloatRoundingMode();
-      AST = Z3_mk_fpa_div(Context.Context, RoundingMode.AST, LHS.AST, RHS.AST);
+      SMTExpr RoundingMode = getFloatRoundingMode();
+      AST = Z3_mk_fpa_div(Context.Context, toZ3Expr(RoundingMode).AST, LHS.AST,
+                          RHS.AST);
       break;
     }
     case BO_Rem:
@@ -659,13 +987,15 @@
 
     // Additive operators
     case BO_Add: {
-      Z3Expr RoundingMode = getFloatRoundingMode();
-      AST = Z3_mk_fpa_add(Context.Context, RoundingMode.AST, LHS.AST, RHS.AST);
+      SMTExpr RoundingMode = getFloatRoundingMode();
+      AST = Z3_mk_fpa_add(Context.Context, toZ3Expr(RoundingMode).AST, LHS.AST,
+                          RHS.AST);
       break;
     }
     case BO_Sub: {
-      Z3Expr RoundingMode = getFloatRoundingMode();
-      AST = Z3_mk_fpa_sub(Context.Context, RoundingMode.AST, LHS.AST, RHS.AST);
+      SMTExpr RoundingMode = getFloatRoundingMode();
+      AST = Z3_mk_fpa_sub(Context.Context, toZ3Expr(RoundingMode).AST, LHS.AST,
+                          RHS.AST);
       break;
     }
 
@@ -701,14 +1031,13 @@
   }
 
-  /// Construct a Z3Expr from a SymbolData, given a Z3_context.
-  Z3Expr fromData(const SymbolID ID, const QualType &Ty, uint64_t BitWidth) {
-    llvm::Twine Name = "$" + llvm::Twine(ID);
-
-    Z3Sort Sort = MkSort(Ty, BitWidth);
-
-    Z3_symbol Symbol = Z3_mk_string_symbol(Context.Context, Name.str().c_str());
-    Z3_ast AST = Z3_mk_const(Context.Context, Symbol, Sort.Sort);
-    return Z3Expr(Context, AST);
+  /// Construct the symbolic constant "$<ID>" for a SymbolData, with a sort
+  /// derived from its type and bit width.
+  SMTExpr fromData(const SymbolID ID, const QualType &Ty,
+                   uint64_t BitWidth) override {
+    // Materialize the name as a std::string: llvm::Twine objects are not meant
+    // to be stored in local variables, as they may refer to temporaries.
+    std::string Name = ("$" + llvm::Twine(ID)).str();
+    return mkSymbol(Name.c_str(), mkSort(Ty, BitWidth));
   }
 
   /// Construct a Z3Expr from a SymbolCast, given a Z3_context.
@@ -742,30 +1071,32 @@
       }
     } else if (FromTy->isRealFloatingType() && ToTy->isRealFloatingType()) {
       if (ToBitWidth != FromBitWidth) {
-        Z3Expr RoundingMode = getFloatRoundingMode();
-        Z3Sort Sort = getFloatSort(ToBitWidth);
-        AST = Z3_mk_fpa_to_fp_float(Context.Context, RoundingMode.AST, Exp.AST,
-                                    Sort.Sort);
+        SMTExpr RoundingMode = getFloatRoundingMode();
+        SMTSort Sort = getFloatSort(ToBitWidth);
+        AST = Z3_mk_fpa_to_fp_float(Context.Context, toZ3Expr(RoundingMode).AST,
+                                    Exp.AST, toZ3Sort(Sort).Sort);
       } else {
         return Exp;
       }
     } else if (FromTy->isIntegralOrEnumerationType() &&
                ToTy->isRealFloatingType()) {
-      Z3Expr RoundingMode = getFloatRoundingMode();
-      Z3Sort Sort = getFloatSort(ToBitWidth);
+      SMTExpr RoundingMode = getFloatRoundingMode();
+      SMTSort Sort = getFloatSort(ToBitWidth);
       AST = FromTy->isSignedIntegerOrEnumerationType()
-                ? Z3_mk_fpa_to_fp_signed(Context.Context, RoundingMode.AST,
-                                         Exp.AST, Sort.Sort)
-                : Z3_mk_fpa_to_fp_unsigned(Context.Context, RoundingMode.AST,
-                                           Exp.AST, Sort.Sort);
+                ? Z3_mk_fpa_to_fp_signed(Context.Context,
+                                         toZ3Expr(RoundingMode).AST, Exp.AST,
+                                         toZ3Sort(Sort).Sort)
+                : Z3_mk_fpa_to_fp_unsigned(Context.Context,
+                                           toZ3Expr(RoundingMode).AST, Exp.AST,
+                                           toZ3Sort(Sort).Sort);
     } else if (FromTy->isRealFloatingType() &&
                ToTy->isIntegralOrEnumerationType()) {
-      Z3Expr RoundingMode = getFloatRoundingMode();
+      SMTExpr RoundingMode = getFloatRoundingMode();
       AST = ToTy->isSignedIntegerOrEnumerationType()
-                ? Z3_mk_fpa_to_sbv(Context.Context, RoundingMode.AST, Exp.AST,
-                                   ToBitWidth)
-                : Z3_mk_fpa_to_ubv(Context.Context, RoundingMode.AST, Exp.AST,
-                                   ToBitWidth);
+                ? Z3_mk_fpa_to_sbv(Context.Context, toZ3Expr(RoundingMode).AST,
+                                   Exp.AST, ToBitWidth)
+                : Z3_mk_fpa_to_ubv(Context.Context, toZ3Expr(RoundingMode).AST,
+                                   Exp.AST, ToBitWidth);
     } else {
       llvm_unreachable("Unsupported explicit type cast!");
     }
@@ -783,39 +1114,38 @@
   /// Construct a Z3Expr from a finite APFloat, given a Z3_context.
   Z3Expr fromAPFloat(const llvm::APFloat &Float) {
     Z3_ast AST;
-    Z3Sort Sort =
+    SMTSort Sort =
         getFloatSort(llvm::APFloat::semanticsSizeInBits(Float.getSemantics()));
 
     llvm::APSInt Int = llvm::APSInt(Float.bitcastToAPInt(), false);
     Z3Expr Z3Int = fromAPSInt(Int);
-    AST = Z3_mk_fpa_to_fp_bv(Context.Context, Z3Int.AST, Sort.Sort);
+    AST = Z3_mk_fpa_to_fp_bv(Context.Context, Z3Int.AST, toZ3Sort(Sort).Sort);
     return Z3Expr(Context, AST);
   }
 
   /// Construct a Z3Expr from an APSInt, given a Z3_context.
   Z3Expr fromAPSInt(const llvm::APSInt &Int) {
-    Z3Sort Sort = getBitvectorSort(Int.getBitWidth());
-    Z3_ast AST =
-        Z3_mk_numeral(Context.Context, Int.toString(10).c_str(), Sort.Sort);
+    SMTSort Sort = getBitvectorSort(Int.getBitWidth());
+    Z3_ast AST = Z3_mk_numeral(Context.Context, Int.toString(10).c_str(),
+                               toZ3Sort(Sort).Sort);
     return Z3Expr(Context, AST);
   }
 
   /// Construct a Z3Expr from an integer, given a Z3_context.
Z3Expr fromInt(const char *Int, uint64_t BitWidth) { - Z3Sort Sort = getBitvectorSort(BitWidth); - Z3_ast AST = Z3_mk_numeral(Context.Context, Int, Sort.Sort); + SMTSort Sort = getBitvectorSort(BitWidth); + Z3_ast AST = Z3_mk_numeral(Context.Context, Int, toZ3Sort(Sort).Sort); return Z3Expr(Context, AST); } - /// Construct an APFloat from a Z3Expr, given the AST representation - bool toAPFloat(const Z3Sort &Sort, const Z3_ast &AST, llvm::APFloat &Float, - bool useSemantics = true) { + bool toAPFloat(const SMTSort &Sort, const SMTExpr &AST, llvm::APFloat &Float, + bool useSemantics) { assert(Sort.isFloatSort() && "Unsupported sort to floating-point!"); llvm::APSInt Int(Sort.getFloatSortSize(), true); const llvm::fltSemantics &Semantics = getFloatSemantics(Sort.getFloatSortSize()); - Z3Sort BVSort = getBitvectorSort(Sort.getFloatSortSize()); + SMTSort BVSort = getBitvectorSort(Sort.getFloatSortSize()); if (!toAPSInt(BVSort, AST, Int, true)) { return false; } @@ -829,35 +1156,15 @@ return true; } - /// Construct an APSInt from a Z3Expr, given the AST representation - bool toAPSInt(const Z3Sort &Sort, const Z3_ast &AST, llvm::APSInt &Int, - bool useSemantics = true) { + bool toAPSInt(const SMTSort &Sort, const SMTExpr &AST, llvm::APSInt &Int, + bool useSemantics) { if (Sort.isBitvectorSort()) { if (useSemantics && Int.getBitWidth() != Sort.getBitvectorSortSize()) { assert(false && "Bitvector types don't match!"); return false; } - uint64_t Value[2]; - // Force cast because Z3 defines __uint64 to be a unsigned long long - // type, which isn't compatible with a unsigned long type, even if they - // are the same size. 
- Z3_get_numeral_uint64(Context.Context, AST, - reinterpret_cast<__uint64 *>(&Value[0])); - if (Sort.getBitvectorSortSize() <= 64) { - Int = llvm::APSInt(llvm::APInt(Int.getBitWidth(), Value[0]), - Int.isUnsigned()); - } else if (Sort.getBitvectorSortSize() == 128) { - Z3Expr ASTHigh = - Z3Expr(Context, Z3_mk_extract(Context.Context, 127, 64, AST)); - Z3_get_numeral_uint64(Context.Context, AST, - reinterpret_cast<__uint64 *>(&Value[1])); - Int = llvm::APSInt(llvm::APInt(Int.getBitWidth(), Value), - Int.isUnsigned()); - } else { - assert(false && "Bitwidth not supported!"); - return false; - } + Int = getBitvector(AST); return true; } @@ -867,11 +1174,8 @@ return false; } - Int = llvm::APSInt( - llvm::APInt(Int.getBitWidth(), - Z3_get_bool_value(Context.Context, AST) == Z3_L_TRUE ? 1 - : 0), - Int.isUnsigned()); + Int = llvm::APSInt(llvm::APInt(Int.getBitWidth(), getBoolean(AST)), + Int.isUnsigned()); return true; } @@ -880,31 +1184,31 @@ /// Given an expression and a model, extract the value of this operand in /// the model. - bool getInterpretation(const Z3Model Model, const Z3Expr &Exp, - llvm::APSInt &Int) { - Z3_func_decl Func = - Z3_get_app_decl(Context.Context, Z3_to_app(Context.Context, Exp.AST)); + bool getInterpretation(const SMTExpr &Exp, llvm::APSInt &Int) override { + Z3Model Model = getModel(); + Z3_func_decl Func = Z3_get_app_decl( + Context.Context, Z3_to_app(Context.Context, toZ3Expr(Exp).AST)); if (Z3_model_has_interp(Context.Context, Model.Model, Func) != Z3_L_TRUE) return false; - Z3_ast Assign = - Z3_model_get_const_interp(Context.Context, Model.Model, Func); - Z3Sort Sort = getSort(Assign); + SMTExpr Assign = Z3Expr( + Context, Z3_model_get_const_interp(Context.Context, Model.Model, Func)); + SMTSort Sort = getSort(Assign); return toAPSInt(Sort, Assign, Int, true); } /// Given an expression and a model, extract the value of this operand in /// the model. 
- bool getInterpretation(const Z3Model Model, const Z3Expr &Exp, - llvm::APFloat &Float) { - Z3_func_decl Func = - Z3_get_app_decl(Context.Context, Z3_to_app(Context.Context, Exp.AST)); + bool getInterpretation(const SMTExpr &Exp, llvm::APFloat &Float) override { + Z3Model Model = getModel(); + Z3_func_decl Func = Z3_get_app_decl( + Context.Context, Z3_to_app(Context.Context, toZ3Expr(Exp).AST)); if (Z3_model_has_interp(Context.Context, Model.Model, Func) != Z3_L_TRUE) return false; - Z3_ast Assign = - Z3_model_get_const_interp(Context.Context, Model.Model, Func); - Z3Sort Sort = getSort(Assign); + Z3Expr Assign = Z3Expr( + Context, Z3_model_get_const_interp(Context.Context, Model.Model, Func)); + SMTSort Sort = getSort(Assign); return toAPFloat(Sort, Assign, Float, true); } @@ -917,13 +1221,15 @@ } - /// Check if the constraints are satisfiable - Z3_lbool check() { return Z3_solver_check(Context.Context, Solver); } + /// Check if the constraints may be satisfiable (Z3_L_UNDEF counts as sat) + bool check() const override { + return Z3_solver_check(Context.Context, Solver) != Z3_L_FALSE; + } /// Push the current solver state - void push() { return Z3_solver_push(Context.Context, Solver); } + void push() override { return Z3_solver_push(Context.Context, Solver); } /// Pop the previous solver state - void pop(unsigned NumStates = 1) { + void pop(unsigned NumStates = 1) override { assert(Z3_solver_get_num_scopes(Context.Context, Solver) >= NumStates); return Z3_solver_pop(Context.Context, Solver, NumStates); } @@ -935,7 +1241,7 @@ } /// Reset the solver and remove all constraints. - void reset() { Z3_solver_reset(Context.Context, Solver); } + void reset() const override { Z3_solver_reset(Context.Context, Solver); } void print(raw_ostream &OS) const { OS << Z3_solver_to_string(Context.Context, Solver); @@ -1003,7 +1309,7 @@ const Z3Expr &Exp); // Generate and check a Z3 model, using the given constraint.
- Z3_lbool checkZ3Model(ProgramStateRef State, const Z3Expr &Exp) const; + bool checkZ3Model(ProgramStateRef State, const Z3Expr &Exp) const; // Generate a Z3Expr that represents the given symbolic expression. // Sets the hasComparison parameter if the expression has a comparison @@ -1165,17 +1471,18 @@ Solver.push(); Solver.addConstraint(Exp); - Z3_lbool isSat = Solver.check(); + bool isSat = Solver.check(); Solver.pop(); Solver.addConstraint(NotExp); - Z3_lbool isNotSat = Solver.check(); + bool isNotSat = Solver.check(); // Zero is the only possible solution - if (isSat == Z3_L_TRUE && isNotSat == Z3_L_FALSE) + if (isSat && !isNotSat) return true; + // Zero is not a solution - else if (isSat == Z3_L_FALSE && isNotSat == Z3_L_TRUE) + if (!isSat && isNotSat) return false; // Zero may be a solution @@ -1202,9 +1509,8 @@ - if (Solver.check() != Z3_L_TRUE) + if (!Solver.check()) return nullptr; - Z3Model Model = Solver.getModel(); // Model does not assign interpretation - if (!Solver.getInterpretation(Model, Exp, Value)) + if (!Solver.getInterpretation(Exp, Value)) return nullptr; // A value has been obtained, check if it is the only value @@ -1297,7 +1603,7 @@ } clang::ento::ConditionTruthVal Z3ConstraintManager::isModelFeasible() { - if (Solver.check() == Z3_L_FALSE) + if (!Solver.check()) return false; return ConditionTruthVal(); @@ -1319,8 +1625,8 @@ return nullptr; } -Z3_lbool Z3ConstraintManager::checkZ3Model(ProgramStateRef State, - const Z3Expr &Exp) const { +bool Z3ConstraintManager::checkZ3Model(ProgramStateRef State, + const Z3Expr &Exp) const { Solver.reset(); Solver.addConstraint(Exp); Solver.addStateConstraints(State); @@ -1393,7 +1699,8 @@ Z3Expr Z3ConstraintManager::getZ3DataExpr(const SymbolID ID, QualType Ty) const { ASTContext &Ctx = getBasicVals().getContext(); - return Solver.fromData(ID, Ty, Ctx.getTypeSize(Ty)); + const SMTExpr &Data = Solver.fromData(ID, Ty, Ctx.getTypeSize(Ty)); + return toZ3Expr(Data); } Z3Expr Z3ConstraintManager::getZ3CastExpr(const Z3Expr &Exp, QualType
FromTy,