Index: cfe/trunk/include/clang/StaticAnalyzer/Core/PathSensitive/SMTSolver.h
===================================================================
--- cfe/trunk/include/clang/StaticAnalyzer/Core/PathSensitive/SMTSolver.h
+++ cfe/trunk/include/clang/StaticAnalyzer/Core/PathSensitive/SMTSolver.h
@@ -0,0 +1,546 @@
+//== SMTSolver.h ------------------------------------------------*- C++ -*--==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a generic SMT solver API, which serves as the base class
+// for every solver-specific SMT class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSOLVER_H
+#define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSOLVER_H
+
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTExpr.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTSort.h"
+
+namespace clang {
+namespace ento {
+
+/// Abstract, solver-independent SMT interface.
+///
+/// Backends implement the pure-virtual sort accessors and mk*/get*/from*
+/// primitives. The non-virtual helpers in this class (getFloatSort, mkSort,
+/// fromUnOp, fromBinOp, fromCast, ...) encode Clang operators, casts and
+/// types using only those primitives.
+class SMTSolver {
+public:
+  SMTSolver() = default;
+  virtual ~SMTSolver() = default;
+
+  /// Debugging aid: write the output of print() to llvm::errs().
+  LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); }
+
+  /// Return the floating-point sort for the given bitwidth (16, 32, 64 or
+  /// 128); any other width is unreachable.
+  SMTSortRef getFloatSort(unsigned BitWidth) {
+    switch (BitWidth) {
+    case 16:
+      return getFloat16Sort();
+    case 32:
+      return getFloat32Sort();
+    case 64:
+      return getFloat64Sort();
+    case 128:
+      return getFloat128Sort();
+    default:;
+    }
+    llvm_unreachable("Unsupported floating-point bitwidth!");
+  }
+
+  /// Return an appropriate sort for a QualType: the boolean sort for bool, a
+  /// floating-point sort of BitWidth bits for real floating types, and a
+  /// bitvector sort of BitWidth bits for everything else.
+  SMTSortRef mkSort(const QualType &Ty, unsigned BitWidth) {
+    if (Ty->isBooleanType())
+      return getBoolSort();
+
+    if (Ty->isRealFloatingType())
+      return getFloatSort(BitWidth);
+
+    return getBitvectorSort(BitWidth);
+  }
+
+  /// Construct an SMTExprRef from a unary operator. UO_Minus and UO_Not map
+  /// to bitvector negation/complement, UO_LNot to boolean negation; any other
+  /// opcode is unreachable.
+  SMTExprRef fromUnOp(const UnaryOperator::Opcode Op, const SMTExprRef &Exp) {
+    switch (Op) {
+    case UO_Minus:
+      return mkBVNeg(Exp);
+
+    case UO_Not:
+      return mkBVNot(Exp);
+
+    case UO_LNot:
+      return mkNot(Exp);
+
+    default:;
+    }
+    llvm_unreachable("Unimplemented opcode");
+  }
+
+  /// Construct an SMTExprRef from a floating-point unary operator. UO_Minus
+  /// maps to floating-point negation; UO_LNot is delegated to fromUnOp(); any
+  /// other opcode is unreachable.
+  SMTExprRef fromFloatUnOp(const UnaryOperator::Opcode Op,
+                           const SMTExprRef &Exp) {
+    switch (Op) {
+    case UO_Minus:
+      return mkFPNeg(Exp);
+
+    case UO_LNot:
+      return fromUnOp(Op, Exp);
+
+    default:;
+    }
+    llvm_unreachable("Unimplemented opcode");
+  }
+
+  /// Fold ASTs left to right with BO_LAnd (conjunction) or BO_LOr
+  /// (disjunction). An empty list yields the operator's identity element:
+  /// true for BO_LAnd, false for BO_LOr. Any other opcode is unreachable.
+  SMTExprRef fromNBinOp(const BinaryOperator::Opcode Op,
+                        const std::vector<SMTExprRef> &ASTs) {
+    if (Op != BO_LAnd && Op != BO_LOr)
+      llvm_unreachable("Unimplemented opcode");
+
+    const bool IsAnd = Op == BO_LAnd;
+
+    // Reading ASTs.front() of an empty vector is undefined behaviour, so
+    // return the neutral element of the operator instead.
+    if (ASTs.empty())
+      return mkBoolean(IsAnd);
+
+    SMTExprRef Res = ASTs.front();
+    for (std::size_t I = 1, E = ASTs.size(); I != E; ++I)
+      Res = IsAnd ? mkAnd(Res, ASTs[I]) : mkOr(Res, ASTs[I]);
+    return Res;
+  }
+
+  /// Construct an SMTExprRef from a binary operator on bitvector operands
+  /// (boolean operands for BO_LAnd/BO_LOr). isSigned selects the signed or
+  /// unsigned variant of division, remainder, right shift and comparisons.
+  SMTExprRef fromBinOp(const SMTExprRef &LHS, const BinaryOperator::Opcode Op,
+                       const SMTExprRef &RHS, bool isSigned) {
+    assert(*getSort(LHS) == *getSort(RHS) && "AST's must have the same sort!");
+
+    // First, the operators whose encoding does not depend on signedness.
+    switch (Op) {
+    case BO_Mul:
+      return mkBVMul(LHS, RHS);
+    case BO_Add:
+      return mkBVAdd(LHS, RHS);
+    case BO_Sub:
+      return mkBVSub(LHS, RHS);
+    case BO_Shl:
+      return mkBVShl(LHS, RHS);
+    case BO_And:
+      return mkBVAnd(LHS, RHS);
+    case BO_Xor:
+      return mkBVXor(LHS, RHS);
+    case BO_Or:
+      return mkBVOr(LHS, RHS);
+    case BO_EQ:
+      return mkEqual(LHS, RHS);
+    case BO_NE:
+      // Encoded as the logical negation of the corresponding equality.
+      return fromUnOp(UO_LNot, fromBinOp(LHS, BO_EQ, RHS, isSigned));
+    case BO_LAnd:
+      return mkAnd(LHS, RHS);
+    case BO_LOr:
+      return mkOr(LHS, RHS);
+    default:
+      break;
+    }
+
+    // Then the operators that have distinct signed and unsigned primitives.
+    if (isSigned) {
+      switch (Op) {
+      case BO_Div:
+        return mkBVSDiv(LHS, RHS);
+      case BO_Rem:
+        return mkBVSRem(LHS, RHS);
+      case BO_Shr:
+        return mkBVAshr(LHS, RHS);
+      case BO_LT:
+        return mkBVSlt(LHS, RHS);
+      case BO_GT:
+        return mkBVSgt(LHS, RHS);
+      case BO_LE:
+        return mkBVSle(LHS, RHS);
+      case BO_GE:
+        return mkBVSge(LHS, RHS);
+      default:
+        break;
+      }
+    } else {
+      switch (Op) {
+      case BO_Div:
+        return mkBVUDiv(LHS, RHS);
+      case BO_Rem:
+        return mkBVURem(LHS, RHS);
+      case BO_Shr:
+        return mkBVLshr(LHS, RHS);
+      case BO_LT:
+        return mkBVUlt(LHS, RHS);
+      case BO_GT:
+        return mkBVUgt(LHS, RHS);
+      case BO_LE:
+        return mkBVUle(LHS, RHS);
+      case BO_GE:
+        return mkBVUge(LHS, RHS);
+      default:
+        break;
+      }
+    }
+
+    llvm_unreachable("Unimplemented opcode");
+  }
+
+  /// Construct an SMTExprRef comparing the floating-point expression LHS
+  /// against the floating-point category RHS. Only BO_EQ and BO_NE are
+  /// supported.
+  SMTExprRef fromFloatSpecialBinOp(const SMTExprRef &LHS,
+                                   const BinaryOperator::Opcode Op,
+                                   const llvm::APFloat::fltCategory &RHS) {
+    switch (Op) {
+    // Equality operators
+    case BO_EQ:
+      switch (RHS) {
+      case llvm::APFloat::fcInfinity:
+        return mkFPIsInfinite(LHS);
+
+      case llvm::APFloat::fcNaN:
+        return mkFPIsNaN(LHS);
+
+      case llvm::APFloat::fcNormal:
+        // APFloat's fcNormal category covers every finite non-zero value,
+        // subnormals included, while an IEEE "is normal" predicate (such as
+        // Z3_mk_fpa_is_normal) excludes subnormals. The four categories
+        // partition all values, so encode fcNormal as "not infinite, not NaN
+        // and not zero".
+        return mkNot(mkOr(mkFPIsInfinite(LHS),
+                          mkOr(mkFPIsNaN(LHS), mkFPIsZero(LHS))));
+
+      case llvm::APFloat::fcZero:
+        return mkFPIsZero(LHS);
+      }
+      llvm_unreachable("Unsupported floating-point category");
+
+    case BO_NE:
+      return fromFloatUnOp(UO_LNot, fromFloatSpecialBinOp(LHS, BO_EQ, RHS));
+
+    default:;
+    }
+
+    llvm_unreachable("Unimplemented opcode");
+  }
+
+  /// Construct an SMTExprRef from a floating-point binary operator. Arithmetic,
+  /// relational and equality operators map to the FP primitives; BO_NE is the
+  /// negation of BO_EQ; BO_LAnd/BO_LOr are delegated to fromBinOp(). Any other
+  /// opcode is unreachable.
+  SMTExprRef fromFloatBinOp(const SMTExprRef &LHS,
+                            const BinaryOperator::Opcode Op,
+                            const SMTExprRef &RHS) {
+    assert(*getSort(LHS) == *getSort(RHS) && "AST's must have the same sort!");
+
+    switch (Op) {
+    // Multiplicative operators
+    case BO_Mul:
+      return mkFPMul(LHS, RHS);
+
+    case BO_Div:
+      return mkFPDiv(LHS, RHS);
+
+    case BO_Rem:
+      return mkFPRem(LHS, RHS);
+
+    // Additive operators
+    case BO_Add:
+      return mkFPAdd(LHS, RHS);
+
+    case BO_Sub:
+      return mkFPSub(LHS, RHS);
+
+    // Relational operators
+    case BO_LT:
+      return mkFPLt(LHS, RHS);
+
+    case BO_GT:
+      return mkFPGt(LHS, RHS);
+
+    case BO_LE:
+      return mkFPLe(LHS, RHS);
+
+    case BO_GE:
+      return mkFPGe(LHS, RHS);
+
+    // Equality operators
+    case BO_EQ:
+      return mkFPEqual(LHS, RHS);
+
+    case BO_NE:
+      return fromFloatUnOp(UO_LNot, fromFloatBinOp(LHS, BO_EQ, RHS));
+
+    // Logical operators
+    case BO_LAnd:
+    case BO_LOr:
+      return fromBinOp(LHS, Op, RHS, false);
+
+    default:;
+    }
+
+    llvm_unreachable("Unimplemented opcode");
+  }
+
+  /// Construct an SMTExprRef that casts Exp from FromTy (FromBitWidth bits)
+  /// to ToTy (ToBitWidth bits).
+  SMTExprRef fromCast(const SMTExprRef &Exp, QualType ToTy, uint64_t ToBitWidth,
+                      QualType FromTy, uint64_t FromBitWidth) {
+    // Casts that keep the value a bitvector: integer <-> integer, and any cast
+    // where exactly one side is a pointer, a block pointer or a reference.
+    const bool StaysBitvector =
+        (FromTy->isIntegralOrEnumerationType() &&
+         ToTy->isIntegralOrEnumerationType()) ||
+        FromTy->isAnyPointerType() != ToTy->isAnyPointerType() ||
+        FromTy->isBlockPointerType() != ToTy->isBlockPointerType() ||
+        FromTy->isReferenceType() != ToTy->isReferenceType();
+
+    if (StaysBitvector) {
+      // A boolean source becomes the bitvector 1 or 0 of the target width.
+      if (FromTy->isBooleanType()) {
+        assert(ToBitWidth > 0 && "BitWidth must be positive!");
+        SMTExprRef OneBV = mkBitvector(llvm::APSInt("1"), ToBitWidth);
+        SMTExprRef ZeroBV = mkBitvector(llvm::APSInt("0"), ToBitWidth);
+        return mkIte(Exp, OneBV, ZeroBV);
+      }
+
+      // Same width: the cast leaves the bitvector unchanged.
+      if (ToBitWidth == FromBitWidth)
+        return Exp;
+
+      // Narrowing keeps the low ToBitWidth bits.
+      if (ToBitWidth < FromBitWidth)
+        return mkExtract(ToBitWidth - 1, 0, Exp);
+
+      // Widening extends according to the signedness of the source type.
+      const unsigned ExtraBits = ToBitWidth - FromBitWidth;
+      if (FromTy->isSignedIntegerOrEnumerationType())
+        return mkSignExt(ExtraBits, Exp);
+      return mkZeroExt(ExtraBits, Exp);
+    }
+
+    const bool FromFloat = FromTy->isRealFloatingType();
+    const bool ToFloat = ToTy->isRealFloatingType();
+
+    // Floating point -> floating point: convert only when the width changes.
+    if (FromFloat && ToFloat) {
+      if (ToBitWidth == FromBitWidth)
+        return Exp;
+      return mkFPtoFP(Exp, getFloatSort(ToBitWidth));
+    }
+
+    // Integer -> floating point. NOTE(review): mkFPtoSBV/mkFPtoUBV are the
+    // hooks taking the floating-point target sort, so they are used for this
+    // direction despite names suggesting FP -> bitvector; confirm against the
+    // backend implementations.
+    if (FromTy->isIntegralOrEnumerationType() && ToFloat) {
+      SMTSortRef FloatSort = getFloatSort(ToBitWidth);
+      if (FromTy->isSignedIntegerOrEnumerationType())
+        return mkFPtoSBV(Exp, FloatSort);
+      return mkFPtoUBV(Exp, FloatSort);
+    }
+
+    // Floating point -> integer, through the width-taking mkSBVtoFP/mkUBVtoFP
+    // hooks (again named for the opposite direction).
+    if (FromFloat && ToTy->isIntegralOrEnumerationType()) {
+      if (ToTy->isSignedIntegerOrEnumerationType())
+        return mkSBVtoFP(Exp, ToBitWidth);
+      return mkUBVtoFP(Exp, ToBitWidth);
+    }
+
+    llvm_unreachable("Unsupported explicit type cast!");
+  }
+
+  // Callback for the doCast parameter on APSInt values: convert V to an
+  // APSInt of ToWidth bits whose signedness follows ToTy. FromTy and
+  // FromWidth are unused; presumably they exist only to match the callback
+  // signature.
+  llvm::APSInt castAPSInt(const llvm::APSInt &V, QualType ToTy,
+                          uint64_t ToWidth, QualType FromTy,
+                          uint64_t FromWidth) {
+    const bool ToUnsigned = !ToTy->isSignedIntegerOrEnumerationType();
+    return APSIntType(ToWidth, ToUnsigned).convert(V);
+  }
+
+  /// Return a boolean sort.
+  virtual SMTSortRef getBoolSort() = 0;
+
+  /// Return an appropriate bitvector sort for the given bitwidth.
+  virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0;
+
+  /// Return a floating-point sort of width 16.
+  virtual SMTSortRef getFloat16Sort() = 0;
+
+  /// Return a floating-point sort of width 32.
+  virtual SMTSortRef getFloat32Sort() = 0;
+
+  /// Return a floating-point sort of width 64.
+  virtual SMTSortRef getFloat64Sort() = 0;
+
+  /// Return a floating-point sort of width 128.
+  virtual SMTSortRef getFloat128Sort() = 0;
+
+  /// Return the sort of the given expression.
+  virtual SMTSortRef getSort(const SMTExprRef &AST) = 0;
+
+  /// Return a new SMTExprRef built from the given SMTExpr.
+  virtual SMTExprRef newExprRef(const SMTExpr &E) const = 0;
+
+  /// Given a constraint, add it to the solver.
+  virtual void addConstraint(const SMTExprRef &Exp) const = 0;
+
+  /// Create a bitvector addition operation.
+  virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector subtraction operation.
+  virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector multiplication operation.
+  virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector signed remainder operation (used for a signed BO_Rem;
+  /// remainder semantics, not modulus).
+  virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector unsigned remainder operation.
+  virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector signed division operation.
+  virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector unsigned division operation.
+  virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector shift left operation.
+  virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector arithmetic shift right operation.
+  virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector logical shift right operation.
+  virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector negation operation (used for UO_Minus).
+  virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0;
+
+  /// Create a bitvector bitwise-not operation (used for UO_Not).
+  virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0;
+
+  /// Create a bitvector xor operation.
+  virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector or operation.
+  virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector and operation.
+  virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector unsigned less-than operation.
+  virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector signed less-than operation.
+  virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector unsigned greater-than operation.
+  virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector signed greater-than operation.
+  virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector unsigned less-than-or-equal operation.
+  virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector signed less-than-or-equal operation.
+  virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector unsigned greater-than-or-equal operation.
+  virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a bitvector signed greater-than-or-equal operation.
+  virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a boolean not operation (used for UO_LNot).
+  virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0;
+
+  /// Create an equality operation (used for BO_EQ on bitvectors).
+  virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a boolean and operation.
+  virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a boolean or operation.
+  virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create an if-then-else expression: T when Cond holds, otherwise F.
+  virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
+                           const SMTExprRef &F) = 0;
+
+  /// Sign-extend the bitvector Exp by i additional bits.
+  virtual SMTExprRef mkSignExt(unsigned i, const SMTExprRef &Exp) = 0;
+
+  /// Zero-extend the bitvector Exp by i additional bits.
+  virtual SMTExprRef mkZeroExt(unsigned i, const SMTExprRef &Exp) = 0;
+
+  /// Extract bits High down to Low of the bitvector Exp (fromCast() passes
+  /// High = width - 1 and Low = 0 to truncate).
+  virtual SMTExprRef mkExtract(unsigned High, unsigned Low,
+                               const SMTExprRef &Exp) = 0;
+
+  /// Create a bitvector concatenation of LHS and RHS.
+  virtual SMTExprRef mkConcat(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a floating-point negation (used for UO_Minus on floats).
+  virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0;
+
+  /// Create a test for whether the floating-point Exp is infinite.
+  virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0;
+
+  /// Create a test for whether the floating-point Exp is NaN.
+  virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0;
+
+  /// Create a test for whether the floating-point Exp is normal.
+  virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0;
+
+  /// Create a test for whether the floating-point Exp is zero.
+  virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0;
+
+  /// Create a floating-point multiplication.
+  virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a floating-point division.
+  virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a floating-point remainder (used for BO_Rem on floats).
+  virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a floating-point addition.
+  virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a floating-point subtraction.
+  virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a floating-point less-than comparison.
+  virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a floating-point greater-than comparison.
+  virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a floating-point less-than-or-equal comparison.
+  virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a floating-point greater-than-or-equal comparison.
+  virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
+
+  /// Create a floating-point equality comparison.
+  virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS,
+                               const SMTExprRef &RHS) = 0;
+
+  /// Convert the floating-point From to the floating-point sort To.
+  virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0;
+
+  /// Used by fromCast() for signed-integer -> floating-point casts; To is the
+  /// target floating-point sort. NOTE(review): the name reads as the opposite
+  /// direction (FP -> signed bitvector); confirm the backends implement
+  /// int -> FP here and consider renaming.
+  virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From,
+                               const SMTSortRef &To) = 0;
+
+  /// Used by fromCast() for unsigned-integer -> floating-point casts; see the
+  /// naming note on mkFPtoSBV().
+  virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From,
+                               const SMTSortRef &To) = 0;
+
+  /// Used by fromCast() for floating-point -> signed-integer casts; ToWidth is
+  /// the target bitvector width. NOTE(review): the name reads as the opposite
+  /// direction (signed bitvector -> FP); confirm and consider renaming.
+  virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From, unsigned ToWidth) = 0;
+
+  /// Used by fromCast() for floating-point -> unsigned-integer casts; see the
+  /// naming note on mkSBVtoFP().
+  virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From, unsigned ToWidth) = 0;
+
+  /// Create a symbolic constant called Name with the given sort.
+  virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0;
+
+  /// Return an appropriate floating-point rounding mode.
+  virtual SMTExprRef getFloatRoundingMode() = 0;
+
+  /// Return the value of the bitvector Exp as an APSInt.
+  /// NOTE(review): presumably Exp must be a bitvector numeral — confirm.
+  virtual const llvm::APSInt getBitvector(const SMTExprRef &Exp) = 0;
+
+  /// Return the value of the boolean Exp.
+  /// NOTE(review): presumably Exp must be a boolean literal — confirm.
+  virtual bool getBoolean(const SMTExprRef &Exp) = 0;
+
+  /// Construct an SMTExprRef from a boolean.
+  virtual SMTExprRef mkBoolean(const bool b) = 0;
+
+  /// Construct an SMTExprRef from a finite APFloat.
+  virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0;
+
+  /// Construct a bitvector SMTExprRef of BitWidth bits from an APSInt.
+  virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0;
+
+  /// Convenience overload that uses Int's own bit width.
+  /// NOTE(review): a subclass overriding the two-argument overload hides this
+  /// one unless it adds `using SMTSolver::mkBitvector;`.
+  SMTExprRef mkBitvector(const llvm::APSInt Int) {
+    return mkBitvector(Int, Int.getBitWidth());
+  }
+
+  /// Given an expression, extract its value in the current model into Int.
+  /// NOTE(review): the bool result presumably reports success — confirm.
+  virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0;
+
+  /// Given an expression, extract its value in the current model into Float.
+  /// NOTE(review): the bool result presumably reports success — confirm.
+  virtual bool getInterpretation(const SMTExprRef &Exp,
+                                 llvm::APFloat &Float) = 0;
+
+  /// Construct an SMTExprRef from a boolean.
+  /// NOTE(review): overlaps with mkBoolean(); confirm both are needed.
+  virtual SMTExprRef fromBoolean(const bool Bool) = 0;
+
+  /// Construct an SMTExprRef from a finite APFloat.
+  virtual SMTExprRef fromAPFloat(const llvm::APFloat &Float) = 0;
+
+  /// Construct an SMTExprRef from an APSInt.
+  virtual SMTExprRef fromAPSInt(const llvm::APSInt &Int) = 0;
+
+  /// Construct a bitvector SMTExprRef of BitWidth bits from an integer given
+  /// as a string (presumably a decimal numeral such as "0" or "1").
+  virtual SMTExprRef fromInt(const char *Int, uint64_t BitWidth) = 0;
+
+  /// Construct an SMTExprRef for the symbol identified by ID, of type Ty and
+  /// width BitWidth.
+  virtual SMTExprRef fromData(const SymbolID ID, const QualType &Ty,
+                              uint64_t BitWidth) = 0;
+
+  /// Check if the constraints are satisfiable.
+  virtual ConditionTruthVal check() const = 0;
+
+  /// Push the current solver state.
+  virtual void push() = 0;
+
+  /// Pop NumStates previously pushed solver states (one by default).
+  virtual void pop(unsigned NumStates = 1) = 0;
+
+  /// Reset the solver and remove all constraints.
+  /// NOTE(review): declared const although it discards solver state.
+  virtual void reset() const = 0;
+
+  /// Print a textual representation of the solver to OS; dump() forwards to
+  /// this with llvm::errs().
+  virtual void print(raw_ostream &OS) const = 0;
+};
+
+/// Shared, reference-counted handle to a solver.
+using SMTSolverRef = std::shared_ptr<SMTSolver>;
+
+} // namespace ento
+} // namespace clang
+
+#endif // LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSOLVER_H
Index: cfe/trunk/lib/StaticAnalyzer/Core/Z3ConstraintManager.cpp =================================================================== --- cfe/trunk/lib/StaticAnalyzer/Core/Z3ConstraintManager.cpp +++ cfe/trunk/lib/StaticAnalyzer/Core/Z3ConstraintManager.cpp @@ -13,6 +13,7 @@ #include "clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h" #include "clang/StaticAnalyzer/Core/PathSensitive/SMTContext.h" #include "clang/StaticAnalyzer/Core/PathSensitive/SMTExpr.h" +#include "clang/StaticAnalyzer/Core/PathSensitive/SMTSolver.h" #include "clang/StaticAnalyzer/Core/PathSensitive/SMTSort.h" #include "clang/Config/config.h" @@ -93,11 +94,11 @@ Z3_sort Sort; +public: Z3Sort(Z3Context &C, Z3_sort ZS) : SMTSort(), Context(C), Sort(ZS) { Z3_inc_ref(Context.Context, reinterpret_cast(Sort)); } -public: /// Override implicit copy constructor for correct reference counting. Z3Sort(const Z3Sort &Copy) : SMTSort(), Context(Copy.Context), Sort(Copy.Sort) { @@ -163,6 +164,10 @@ } }; // end class Z3Sort +static const Z3Sort &toZ3Sort(const SMTSort &S) { + return static_cast(S); +} + class Z3Expr : public SMTExpr { friend class Z3Solver; @@ -170,11 +175,11 @@ Z3_ast AST; +public: Z3Expr(Z3Context &C, Z3_ast ZA) : SMTExpr(), Context(C), AST(ZA) { Z3_inc_ref(Context.Context, AST); } -public: /// Override implicit copy constructor for correct reference counting.
Z3Expr(const Z3Expr &Copy) : SMTExpr(), Context(Copy.Context), AST(Copy.AST) { Z3_inc_ref(Context.Context, AST); @@ -228,6 +233,10 @@ } }; // end class Z3Expr +static const Z3Expr &toZ3Expr(const SMTExpr &E) { + return static_cast(E); +} + class Z3Model { friend class Z3Solver; @@ -304,25 +313,27 @@ llvm::APFloat::semanticsSizeInBits(RHS)); } -class Z3Solver { +class Z3Solver : public SMTSolver { friend class Z3ConstraintManager; Z3Context Context; Z3_solver Solver; - Z3Solver() : Solver(Z3_mk_simple_solver(Context.Context)) { + Z3Solver() : SMTSolver(), Solver(Z3_mk_simple_solver(Context.Context)) { Z3_solver_inc_ref(Context.Context, Solver); } public: /// Override implicit copy constructor for correct reference counting. - Z3Solver(const Z3Solver &Copy) : Context(Copy.Context), Solver(Copy.Solver) { + Z3Solver(const Z3Solver &Copy) + : SMTSolver(), Context(Copy.Context), Solver(Copy.Solver) { Z3_solver_inc_ref(Context.Context, Solver); } /// Provide move constructor - Z3Solver(Z3Solver &&Move) : Context(Move.Context), Solver(nullptr) { + Z3Solver(Z3Solver &&Move) + : SMTSolver(), Context(Move.Context), Solver(nullptr) { *this = std::move(Move); } @@ -342,470 +353,460 @@ Z3_solver_dec_ref(Context.Context, Solver); } - /// Given a constraint, add it to the solver - void addConstraint(const Z3Expr &Exp) { - Z3_solver_assert(Context.Context, Solver, Exp.AST); - } - - // Return a boolean sort. - Z3Sort getBoolSort() { - return Z3Sort(Context, Z3_mk_bool_sort(Context.Context)); - } - - // Return an appropriate bitvector sort for the given bitwidth. - Z3Sort getBitvectorSort(unsigned BitWidth) { - return Z3Sort(Context, Z3_mk_bv_sort(Context.Context, BitWidth)); - } - - // Return an appropriate floating-point sort for the given bitwidth. 
- Z3Sort getFloatSort(unsigned BitWidth) { - Z3_sort Sort; - - switch (BitWidth) { - default: - llvm_unreachable("Unsupported floating-point bitwidth!"); - break; - case 16: - Sort = Z3_mk_fpa_sort_16(Context.Context); - break; - case 32: - Sort = Z3_mk_fpa_sort_32(Context.Context); - break; - case 64: - Sort = Z3_mk_fpa_sort_64(Context.Context); - break; - case 128: - Sort = Z3_mk_fpa_sort_128(Context.Context); - break; - } - return Z3Sort(Context, Sort); + void addConstraint(const SMTExprRef &Exp) const override { + Z3_solver_assert(Context.Context, Solver, toZ3Expr(*Exp).AST); } - // Return an appropriate sort, given a QualType - Z3Sort MkSort(const QualType &Ty, unsigned BitWidth) { - if (Ty->isBooleanType()) - return getBoolSort(); + SMTSortRef getBoolSort() override { + return std::make_shared(Context, Z3_mk_bool_sort(Context.Context)); + } - if (Ty->isRealFloatingType()) - return getFloatSort(BitWidth); + SMTSortRef getBitvectorSort(unsigned BitWidth) override { + return std::make_shared(Context, + Z3_mk_bv_sort(Context.Context, BitWidth)); + } - return getBitvectorSort(BitWidth); + SMTSortRef getSort(const SMTExprRef &Exp) override { + return std::make_shared( + Context, Z3_get_sort(Context.Context, toZ3Expr(*Exp).AST)); } - // Return an appropriate sort for the given AST. 
- Z3Sort getSort(Z3_ast AST) { - return Z3Sort(Context, Z3_get_sort(Context.Context, AST)); + SMTSortRef getFloat16Sort() override { + return std::make_shared(Context, + Z3_mk_fpa_sort_16(Context.Context)); } - /// Given a program state, construct the logical conjunction and add it to - /// the solver - void addStateConstraints(ProgramStateRef State) { - // TODO: Don't add all the constraints, only the relevant ones - ConstraintZ3Ty CZ = State->get(); - ConstraintZ3Ty::iterator I = CZ.begin(), IE = CZ.end(); - - // Construct the logical AND of all the constraints - if (I != IE) { - std::vector ASTs; + SMTSortRef getFloat32Sort() override { + return std::make_shared(Context, + Z3_mk_fpa_sort_32(Context.Context)); + } - while (I != IE) - ASTs.push_back(I++->second.AST); + SMTSortRef getFloat64Sort() override { + return std::make_shared(Context, + Z3_mk_fpa_sort_64(Context.Context)); + } - Z3Expr Conj = fromNBinOp(BO_LAnd, ASTs); - addConstraint(Conj); - } + SMTSortRef getFloat128Sort() override { + return std::make_shared(Context, + Z3_mk_fpa_sort_128(Context.Context)); } - // Return an appropriate floating-point rounding mode. - Z3Expr getFloatRoundingMode() { - // TODO: Don't assume nearest ties to even rounding mode - return Z3Expr(Context, Z3_mk_fpa_rne(Context.Context)); + SMTExprRef newExprRef(const SMTExpr &E) const override { + return std::make_shared(toZ3Expr(E)); } - /// Construct a Z3Expr from a unary operator, given a Z3_context. 
- Z3Expr fromUnOp(const UnaryOperator::Opcode Op, const Z3Expr &Exp) { - Z3_ast AST; + SMTExprRef mkBVNeg(const SMTExprRef &Exp) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvneg(Context.Context, toZ3Expr(*Exp).AST))); + } - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; + SMTExprRef mkBVNot(const SMTExprRef &Exp) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvnot(Context.Context, toZ3Expr(*Exp).AST))); + } - case UO_Minus: - AST = Z3_mk_bvneg(Context.Context, Exp.AST); - break; + SMTExprRef mkNot(const SMTExprRef &Exp) override { + return newExprRef( + Z3Expr(Context, Z3_mk_not(Context.Context, toZ3Expr(*Exp).AST))); + } - case UO_Not: - AST = Z3_mk_bvnot(Context.Context, Exp.AST); - break; + SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvadd(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - case UO_LNot: - AST = Z3_mk_not(Context.Context, Exp.AST); - break; - } + SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvsub(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - return Z3Expr(Context, AST); + SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvmul(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); } - /// Construct a Z3Expr from a floating-point unary operator, given a - /// Z3_context. 
- Z3Expr fromFloatUnOp(const UnaryOperator::Opcode Op, const Z3Expr &Exp) { - Z3_ast AST; + SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvsrem(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; + SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvurem(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - case UO_Minus: - AST = Z3_mk_fpa_neg(Context.Context, Exp.AST); - break; + SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvsdiv(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - case UO_LNot: - return fromUnOp(Op, Exp); - } + SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvudiv(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - return Z3Expr(Context, AST); + SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvshl(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); } - /// Construct a Z3Expr from a n-ary binary operator. 
- Z3Expr fromNBinOp(const BinaryOperator::Opcode Op, - const std::vector &ASTs) { - Z3_ast AST; - - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; - - case BO_LAnd: - AST = Z3_mk_and(Context.Context, ASTs.size(), ASTs.data()); - break; - - case BO_LOr: - AST = Z3_mk_or(Context.Context, ASTs.size(), ASTs.data()); - break; - } + SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvashr(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - return Z3Expr(Context, AST); + SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvlshr(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); } - /// Construct a Z3Expr from a binary operator, given a Z3_context. - Z3Expr fromBinOp(const Z3Expr &LHS, const BinaryOperator::Opcode Op, - const Z3Expr &RHS, bool isSigned) { - Z3_ast AST; - - assert(getSort(LHS.AST) == getSort(RHS.AST) && - "AST's must have the same sort!"); - - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; - - // Multiplicative operators - case BO_Mul: - AST = Z3_mk_bvmul(Context.Context, LHS.AST, RHS.AST); - break; - case BO_Div: - AST = isSigned ? Z3_mk_bvsdiv(Context.Context, LHS.AST, RHS.AST) - : Z3_mk_bvudiv(Context.Context, LHS.AST, RHS.AST); - break; - case BO_Rem: - AST = isSigned ? Z3_mk_bvsrem(Context.Context, LHS.AST, RHS.AST) - : Z3_mk_bvurem(Context.Context, LHS.AST, RHS.AST); - break; - - // Additive operators - case BO_Add: - AST = Z3_mk_bvadd(Context.Context, LHS.AST, RHS.AST); - break; - case BO_Sub: - AST = Z3_mk_bvsub(Context.Context, LHS.AST, RHS.AST); - break; - - // Bitwise shift operators - case BO_Shl: - AST = Z3_mk_bvshl(Context.Context, LHS.AST, RHS.AST); - break; - case BO_Shr: - AST = isSigned ? 
Z3_mk_bvashr(Context.Context, LHS.AST, RHS.AST) - : Z3_mk_bvlshr(Context.Context, LHS.AST, RHS.AST); - break; - - // Relational operators - case BO_LT: - AST = isSigned ? Z3_mk_bvslt(Context.Context, LHS.AST, RHS.AST) - : Z3_mk_bvult(Context.Context, LHS.AST, RHS.AST); - break; - case BO_GT: - AST = isSigned ? Z3_mk_bvsgt(Context.Context, LHS.AST, RHS.AST) - : Z3_mk_bvugt(Context.Context, LHS.AST, RHS.AST); - break; - case BO_LE: - AST = isSigned ? Z3_mk_bvsle(Context.Context, LHS.AST, RHS.AST) - : Z3_mk_bvule(Context.Context, LHS.AST, RHS.AST); - break; - case BO_GE: - AST = isSigned ? Z3_mk_bvsge(Context.Context, LHS.AST, RHS.AST) - : Z3_mk_bvuge(Context.Context, LHS.AST, RHS.AST); - break; - - // Equality operators - case BO_EQ: - AST = Z3_mk_eq(Context.Context, LHS.AST, RHS.AST); - break; - case BO_NE: - return fromUnOp(UO_LNot, fromBinOp(LHS, BO_EQ, RHS, isSigned)); - break; - - // Bitwise operators - case BO_And: - AST = Z3_mk_bvand(Context.Context, LHS.AST, RHS.AST); - break; - case BO_Xor: - AST = Z3_mk_bvxor(Context.Context, LHS.AST, RHS.AST); - break; - case BO_Or: - AST = Z3_mk_bvor(Context.Context, LHS.AST, RHS.AST); - break; - - // Logical operators - case BO_LAnd: - case BO_LOr: { - std::vector Args = {LHS.AST, RHS.AST}; - return fromNBinOp(Op, Args); - } - } + SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvxor(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - return Z3Expr(Context, AST); + SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvor(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); } - /// Construct a Z3Expr from a special floating-point binary operator, given - /// a Z3_context. 
- Z3Expr fromFloatSpecialBinOp(const Z3Expr &LHS, - const BinaryOperator::Opcode Op, - const llvm::APFloat::fltCategory &RHS) { - Z3_ast AST; - - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; - - // Equality operators - case BO_EQ: - switch (RHS) { - case llvm::APFloat::fcInfinity: - AST = Z3_mk_fpa_is_infinite(Context.Context, LHS.AST); - break; - case llvm::APFloat::fcNaN: - AST = Z3_mk_fpa_is_nan(Context.Context, LHS.AST); - break; - case llvm::APFloat::fcNormal: - AST = Z3_mk_fpa_is_normal(Context.Context, LHS.AST); - break; - case llvm::APFloat::fcZero: - AST = Z3_mk_fpa_is_zero(Context.Context, LHS.AST); - break; - } - break; - case BO_NE: - return fromFloatUnOp(UO_LNot, fromFloatSpecialBinOp(LHS, BO_EQ, RHS)); - break; - } + SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvand(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - return Z3Expr(Context, AST); + SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvult(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); } - /// Construct a Z3Expr from a floating-point binary operator, given a - /// Z3_context. 
- Z3Expr fromFloatBinOp(const Z3Expr &LHS, const BinaryOperator::Opcode Op, - const Z3Expr &RHS) { - Z3_ast AST; - - assert(getSort(LHS.AST) == getSort(RHS.AST) && - "AST's must have the same sort!"); - - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; - - // Multiplicative operators - case BO_Mul: { - Z3Expr RoundingMode = getFloatRoundingMode(); - AST = Z3_mk_fpa_mul(Context.Context, RoundingMode.AST, LHS.AST, RHS.AST); - break; - } - case BO_Div: { - Z3Expr RoundingMode = getFloatRoundingMode(); - AST = Z3_mk_fpa_div(Context.Context, RoundingMode.AST, LHS.AST, RHS.AST); - break; - } - case BO_Rem: - AST = Z3_mk_fpa_rem(Context.Context, LHS.AST, RHS.AST); - break; - - // Additive operators - case BO_Add: { - Z3Expr RoundingMode = getFloatRoundingMode(); - AST = Z3_mk_fpa_add(Context.Context, RoundingMode.AST, LHS.AST, RHS.AST); - break; - } - case BO_Sub: { - Z3Expr RoundingMode = getFloatRoundingMode(); - AST = Z3_mk_fpa_sub(Context.Context, RoundingMode.AST, LHS.AST, RHS.AST); - break; - } + SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvslt(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - // Relational operators - case BO_LT: - AST = Z3_mk_fpa_lt(Context.Context, LHS.AST, RHS.AST); - break; - case BO_GT: - AST = Z3_mk_fpa_gt(Context.Context, LHS.AST, RHS.AST); - break; - case BO_LE: - AST = Z3_mk_fpa_leq(Context.Context, LHS.AST, RHS.AST); - break; - case BO_GE: - AST = Z3_mk_fpa_geq(Context.Context, LHS.AST, RHS.AST); - break; - - // Equality operators - case BO_EQ: - AST = Z3_mk_fpa_eq(Context.Context, LHS.AST, RHS.AST); - break; - case BO_NE: - return fromFloatUnOp(UO_LNot, fromFloatBinOp(LHS, BO_EQ, RHS)); - break; - - // Logical operators - case BO_LAnd: - case BO_LOr: - return fromBinOp(LHS, Op, RHS, false); - } + SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, 
Z3_mk_bvugt(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - return Z3Expr(Context, AST); + SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvsgt(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); } - /// Construct a Z3Expr from a SymbolData, given a Z3_context. - Z3Expr fromData(const SymbolID ID, const QualType &Ty, uint64_t BitWidth) { - llvm::Twine Name = "$" + llvm::Twine(ID); + SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvule(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - Z3Sort Sort = MkSort(Ty, BitWidth); + SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvsle(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } - Z3_symbol Symbol = Z3_mk_string_symbol(Context.Context, Name.str().c_str()); - Z3_ast AST = Z3_mk_const(Context.Context, Symbol, Sort.Sort); - return Z3Expr(Context, AST); - } - - /// Construct a Z3Expr from a SymbolCast, given a Z3_context. 
- Z3Expr fromCast(const Z3Expr &Exp, QualType ToTy, uint64_t ToBitWidth, - QualType FromTy, uint64_t FromBitWidth) { - Z3_ast AST; - - if ((FromTy->isIntegralOrEnumerationType() && - ToTy->isIntegralOrEnumerationType()) || - (FromTy->isAnyPointerType() ^ ToTy->isAnyPointerType()) || - (FromTy->isBlockPointerType() ^ ToTy->isBlockPointerType()) || - (FromTy->isReferenceType() ^ ToTy->isReferenceType())) { - // Special case: Z3 boolean type is distinct from bitvector type, so - // must use if-then-else expression instead of direct cast - if (FromTy->isBooleanType()) { - assert(ToBitWidth > 0 && "BitWidth must be positive!"); - Z3Expr Zero = fromInt("0", ToBitWidth); - Z3Expr One = fromInt("1", ToBitWidth); - AST = Z3_mk_ite(Context.Context, Exp.AST, One.AST, Zero.AST); - } else if (ToBitWidth > FromBitWidth) { - AST = FromTy->isSignedIntegerOrEnumerationType() - ? Z3_mk_sign_ext(Context.Context, ToBitWidth - FromBitWidth, - Exp.AST) - : Z3_mk_zero_ext(Context.Context, ToBitWidth - FromBitWidth, - Exp.AST); - } else if (ToBitWidth < FromBitWidth) { - AST = Z3_mk_extract(Context.Context, ToBitWidth - 1, 0, Exp.AST); - } else { - // Both are bitvectors with the same width, ignore the type cast - return Exp; - } - } else if (FromTy->isRealFloatingType() && ToTy->isRealFloatingType()) { - if (ToBitWidth != FromBitWidth) { - Z3Expr RoundingMode = getFloatRoundingMode(); - Z3Sort Sort = getFloatSort(ToBitWidth); - AST = Z3_mk_fpa_to_fp_float(Context.Context, RoundingMode.AST, Exp.AST, - Sort.Sort); - } else { - return Exp; - } - } else if (FromTy->isIntegralOrEnumerationType() && - ToTy->isRealFloatingType()) { - Z3Expr RoundingMode = getFloatRoundingMode(); - Z3Sort Sort = getFloatSort(ToBitWidth); - AST = FromTy->isSignedIntegerOrEnumerationType() - ? 
Z3_mk_fpa_to_fp_signed(Context.Context, RoundingMode.AST, - Exp.AST, Sort.Sort) - : Z3_mk_fpa_to_fp_unsigned(Context.Context, RoundingMode.AST, - Exp.AST, Sort.Sort); - } else if (FromTy->isRealFloatingType() && - ToTy->isIntegralOrEnumerationType()) { - Z3Expr RoundingMode = getFloatRoundingMode(); - AST = ToTy->isSignedIntegerOrEnumerationType() - ? Z3_mk_fpa_to_sbv(Context.Context, RoundingMode.AST, Exp.AST, - ToBitWidth) - : Z3_mk_fpa_to_ubv(Context.Context, RoundingMode.AST, Exp.AST, - ToBitWidth); - } else { - llvm_unreachable("Unsupported explicit type cast!"); - } + SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvuge(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } + + SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_bvsge(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } + + SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + Z3_ast Args[2] = {toZ3Expr(*LHS).AST, toZ3Expr(*RHS).AST}; + return newExprRef(Z3Expr(Context, Z3_mk_and(Context.Context, 2, Args))); + } + + SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + Z3_ast Args[2] = {toZ3Expr(*LHS).AST, toZ3Expr(*RHS).AST}; + return newExprRef(Z3Expr(Context, Z3_mk_or(Context.Context, 2, Args))); + } + + SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_eq(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } + + SMTExprRef mkFPNeg(const SMTExprRef &Exp) override { + return newExprRef( + Z3Expr(Context, Z3_mk_fpa_neg(Context.Context, toZ3Expr(*Exp).AST))); + } + + SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) override { + return newExprRef(Z3Expr( + Context, Z3_mk_fpa_is_infinite(Context.Context, toZ3Expr(*Exp).AST))); + } + + SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) override { + return 
newExprRef( + Z3Expr(Context, Z3_mk_fpa_is_nan(Context.Context, toZ3Expr(*Exp).AST))); + } + + SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) override { + return newExprRef(Z3Expr( + Context, Z3_mk_fpa_is_normal(Context.Context, toZ3Expr(*Exp).AST))); + } + + SMTExprRef mkFPIsZero(const SMTExprRef &Exp) override { + return newExprRef(Z3Expr( + Context, Z3_mk_fpa_is_zero(Context.Context, toZ3Expr(*Exp).AST))); + } + + SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + SMTExprRef RoundingMode = getFloatRoundingMode(); + return newExprRef( + Z3Expr(Context, + Z3_mk_fpa_mul(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST, toZ3Expr(*RoundingMode).AST))); + } + + SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + SMTExprRef RoundingMode = getFloatRoundingMode(); + return newExprRef( + Z3Expr(Context, + Z3_mk_fpa_div(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST, toZ3Expr(*RoundingMode).AST))); + } + + SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_fpa_rem(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } + + SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + SMTExprRef RoundingMode = getFloatRoundingMode(); + return newExprRef( + Z3Expr(Context, + Z3_mk_fpa_add(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST, toZ3Expr(*RoundingMode).AST))); + } + + SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + SMTExprRef RoundingMode = getFloatRoundingMode(); + return newExprRef( + Z3Expr(Context, + Z3_mk_fpa_sub(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST, toZ3Expr(*RoundingMode).AST))); + } - return Z3Expr(Context, AST); + SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_fpa_lt(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } + + SMTExprRef mkFPGt(const 
SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_fpa_gt(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } + + SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_fpa_leq(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } + + SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_fpa_geq(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } + + SMTExprRef mkFPEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_fpa_eq(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } + + SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T, + const SMTExprRef &F) override { + return newExprRef( + Z3Expr(Context, Z3_mk_ite(Context.Context, toZ3Expr(*Cond).AST, + toZ3Expr(*T).AST, toZ3Expr(*F).AST))); + } + + SMTExprRef mkSignExt(unsigned i, const SMTExprRef &Exp) override { + return newExprRef(Z3Expr( + Context, Z3_mk_sign_ext(Context.Context, i, toZ3Expr(*Exp).AST))); + } + + SMTExprRef mkZeroExt(unsigned i, const SMTExprRef &Exp) override { + return newExprRef(Z3Expr( + Context, Z3_mk_zero_ext(Context.Context, i, toZ3Expr(*Exp).AST))); + } + + SMTExprRef mkExtract(unsigned High, unsigned Low, + const SMTExprRef &Exp) override { + return newExprRef(Z3Expr(Context, Z3_mk_extract(Context.Context, High, Low, + toZ3Expr(*Exp).AST))); + } + + SMTExprRef mkConcat(const SMTExprRef &LHS, const SMTExprRef &RHS) override { + return newExprRef( + Z3Expr(Context, Z3_mk_concat(Context.Context, toZ3Expr(*LHS).AST, + toZ3Expr(*RHS).AST))); + } + + SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) override { + SMTExprRef RoundingMode = getFloatRoundingMode(); + return newExprRef(Z3Expr( + Context, + Z3_mk_fpa_to_fp_float(Context.Context, toZ3Expr(*RoundingMode).AST, + 
toZ3Expr(*From).AST, toZ3Sort(*To).Sort))); + } + + SMTExprRef mkFPtoSBV(const SMTExprRef &From, const SMTSortRef &To) override { + SMTExprRef RoundingMode = getFloatRoundingMode(); + return newExprRef(Z3Expr( + Context, + Z3_mk_fpa_to_fp_signed(Context.Context, toZ3Expr(*RoundingMode).AST, + toZ3Expr(*From).AST, toZ3Sort(*To).Sort))); + } + + SMTExprRef mkFPtoUBV(const SMTExprRef &From, const SMTSortRef &To) override { + SMTExprRef RoundingMode = getFloatRoundingMode(); + return newExprRef(Z3Expr( + Context, + Z3_mk_fpa_to_fp_unsigned(Context.Context, toZ3Expr(*RoundingMode).AST, + toZ3Expr(*From).AST, toZ3Sort(*To).Sort))); + } + + SMTExprRef mkSBVtoFP(const SMTExprRef &From, unsigned ToWidth) override { + SMTExprRef RoundingMode = getFloatRoundingMode(); + return newExprRef(Z3Expr( + Context, Z3_mk_fpa_to_sbv(Context.Context, toZ3Expr(*RoundingMode).AST, + toZ3Expr(*From).AST, ToWidth))); + } + + SMTExprRef mkUBVtoFP(const SMTExprRef &From, unsigned ToWidth) override { + SMTExprRef RoundingMode = getFloatRoundingMode(); + return newExprRef(Z3Expr( + Context, Z3_mk_fpa_to_ubv(Context.Context, toZ3Expr(*RoundingMode).AST, + toZ3Expr(*From).AST, ToWidth))); + } + + SMTExprRef mkBoolean(const bool b) override { + return newExprRef(Z3Expr(Context, b ? 
Z3_mk_true(Context.Context) + : Z3_mk_false(Context.Context))); + } + + SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) override { + const SMTSortRef Sort = getBitvectorSort(BitWidth); + return newExprRef( + Z3Expr(Context, Z3_mk_numeral(Context.Context, Int.toString(10).c_str(), + toZ3Sort(*Sort).Sort))); + } + + SMTExprRef mkFloat(const llvm::APFloat Float) override { + SMTSortRef Sort = + getFloatSort(llvm::APFloat::semanticsSizeInBits(Float.getSemantics())); + + llvm::APSInt Int = llvm::APSInt(Float.bitcastToAPInt(), false); + SMTExprRef Z3Int = mkBitvector(Int, Int.getBitWidth()); + return newExprRef(Z3Expr( + Context, Z3_mk_fpa_to_fp_bv(Context.Context, toZ3Expr(*Z3Int).AST, + toZ3Sort(*Sort).Sort))); + } + + SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) override { + return newExprRef( + Z3Expr(Context, Z3_mk_const(Context.Context, + Z3_mk_string_symbol(Context.Context, Name), + toZ3Sort(*Sort).Sort))); + } + + const llvm::APSInt getBitvector(const SMTExprRef &Exp) override { + // FIXME: this returns a string and the bitWidth is overridden + return llvm::APSInt( + Z3_get_numeral_string(Context.Context, toZ3Expr(*Exp).AST)); + } + + bool getBoolean(const SMTExprRef &Exp) override { + return Z3_get_bool_value(Context.Context, toZ3Expr(*Exp).AST) == Z3_L_TRUE; + } + + // Return an appropriate floating-point rounding mode. + SMTExprRef getFloatRoundingMode() override { + // TODO: Don't assume nearest ties to even rounding mode + return newExprRef(Z3Expr(Context, Z3_mk_fpa_rne(Context.Context))); + } + + /// Construct a Z3Expr from a SymbolData, given a Z3_context. + SMTExprRef fromData(const SymbolID ID, const QualType &Ty, + uint64_t BitWidth) override { + llvm::Twine Name = "$" + llvm::Twine(ID); + return mkSymbol(Name.str().c_str(), mkSort(Ty, BitWidth)); } /// Construct a Z3Expr from a boolean, given a Z3_context. - Z3Expr fromBoolean(const bool Bool) { + SMTExprRef fromBoolean(const bool Bool) override { Z3_ast AST = Bool ? 
Z3_mk_true(Context.Context) : Z3_mk_false(Context.Context); - return Z3Expr(Context, AST); + return newExprRef(Z3Expr(Context, AST)); } /// Construct a Z3Expr from a finite APFloat, given a Z3_context. - Z3Expr fromAPFloat(const llvm::APFloat &Float) { - Z3_ast AST; - Z3Sort Sort = + SMTExprRef fromAPFloat(const llvm::APFloat &Float) override { + SMTSortRef Sort = getFloatSort(llvm::APFloat::semanticsSizeInBits(Float.getSemantics())); llvm::APSInt Int = llvm::APSInt(Float.bitcastToAPInt(), false); - Z3Expr Z3Int = fromAPSInt(Int); - AST = Z3_mk_fpa_to_fp_bv(Context.Context, Z3Int.AST, Sort.Sort); - return Z3Expr(Context, AST); + SMTExprRef Z3Int = fromAPSInt(Int); + return newExprRef(Z3Expr( + Context, Z3_mk_fpa_to_fp_bv(Context.Context, toZ3Expr(*Z3Int).AST, + toZ3Sort(*Sort).Sort))); } /// Construct a Z3Expr from an APSInt, given a Z3_context. - Z3Expr fromAPSInt(const llvm::APSInt &Int) { - Z3Sort Sort = getBitvectorSort(Int.getBitWidth()); - Z3_ast AST = - Z3_mk_numeral(Context.Context, Int.toString(10).c_str(), Sort.Sort); - return Z3Expr(Context, AST); + SMTExprRef fromAPSInt(const llvm::APSInt &Int) override { + SMTSortRef Sort = getBitvectorSort(Int.getBitWidth()); + Z3_ast AST = Z3_mk_numeral(Context.Context, Int.toString(10).c_str(), + toZ3Sort(*Sort).Sort); + return newExprRef(Z3Expr(Context, AST)); } /// Construct a Z3Expr from an integer, given a Z3_context. 
- Z3Expr fromInt(const char *Int, uint64_t BitWidth) { - Z3Sort Sort = getBitvectorSort(BitWidth); - Z3_ast AST = Z3_mk_numeral(Context.Context, Int, Sort.Sort); - return Z3Expr(Context, AST); + SMTExprRef fromInt(const char *Int, uint64_t BitWidth) override { + SMTSortRef Sort = getBitvectorSort(BitWidth); + Z3_ast AST = Z3_mk_numeral(Context.Context, Int, toZ3Sort(*Sort).Sort); + return newExprRef(Z3Expr(Context, AST)); } - /// Construct an APFloat from a Z3Expr, given the AST representation - bool toAPFloat(const Z3Sort &Sort, const Z3_ast &AST, llvm::APFloat &Float, - bool useSemantics = true) { - assert(Sort.isFloatSort() && "Unsupported sort to floating-point!"); + bool toAPFloat(const SMTSortRef &Sort, const SMTExprRef &AST, + llvm::APFloat &Float, bool useSemantics) { + assert(Sort->isFloatSort() && "Unsupported sort to floating-point!"); - llvm::APSInt Int(Sort.getFloatSortSize(), true); + llvm::APSInt Int(Sort->getFloatSortSize(), true); const llvm::fltSemantics &Semantics = - getFloatSemantics(Sort.getFloatSortSize()); - Z3Sort BVSort = getBitvectorSort(Sort.getFloatSortSize()); + getFloatSemantics(Sort->getFloatSortSize()); + SMTSortRef BVSort = getBitvectorSort(Sort->getFloatSortSize()); if (!toAPSInt(BVSort, AST, Int, true)) { return false; } @@ -819,11 +820,10 @@ return true; } - /// Construct an APSInt from a Z3Expr, given the AST representation - bool toAPSInt(const Z3Sort &Sort, const Z3_ast &AST, llvm::APSInt &Int, - bool useSemantics = true) { - if (Sort.isBitvectorSort()) { - if (useSemantics && Int.getBitWidth() != Sort.getBitvectorSortSize()) { + bool toAPSInt(const SMTSortRef &Sort, const SMTExprRef &AST, + llvm::APSInt &Int, bool useSemantics) { + if (Sort->isBitvectorSort()) { + if (useSemantics && Int.getBitWidth() != Sort->getBitvectorSortSize()) { assert(false && "Bitvector types don't match!"); return false; } @@ -832,15 +832,14 @@ // Force cast because Z3 defines __uint64 to be a unsigned long long // type, which isn't compatible with 
a unsigned long type, even if they // are the same size. - Z3_get_numeral_uint64(Context.Context, AST, + Z3_get_numeral_uint64(Context.Context, toZ3Expr(*AST).AST, reinterpret_cast<__uint64 *>(&Value[0])); - if (Sort.getBitvectorSortSize() <= 64) { + if (Sort->getBitvectorSortSize() <= 64) { Int = llvm::APSInt(llvm::APInt(Int.getBitWidth(), Value[0]), Int.isUnsigned()); - } else if (Sort.getBitvectorSortSize() == 128) { - Z3Expr ASTHigh = - Z3Expr(Context, Z3_mk_extract(Context.Context, 127, 64, AST)); - Z3_get_numeral_uint64(Context.Context, AST, + } else if (Sort->getBitvectorSortSize() == 128) { + SMTExprRef ASTHigh = mkExtract(127, 64, AST); + Z3_get_numeral_uint64(Context.Context, toZ3Expr(*AST).AST, reinterpret_cast<__uint64 *>(&Value[1])); Int = llvm::APSInt(llvm::APInt(Int.getBitWidth(), Value), Int.isUnsigned()); @@ -851,17 +850,14 @@ return true; } - if (Sort.isBooleanSort()) { + if (Sort->isBooleanSort()) { if (useSemantics && Int.getBitWidth() < 1) { assert(false && "Boolean type doesn't match!"); return false; } - Int = llvm::APSInt( - llvm::APInt(Int.getBitWidth(), - Z3_get_bool_value(Context.Context, AST) == Z3_L_TRUE ? 1 - : 0), - Int.isUnsigned()); + Int = llvm::APSInt(llvm::APInt(Int.getBitWidth(), getBoolean(AST)), + Int.isUnsigned()); return true; } @@ -870,50 +866,53 @@ /// Given an expression and a model, extract the value of this operand in /// the model. 
- bool getInterpretation(const Z3Model Model, const Z3Expr &Exp, - llvm::APSInt &Int) { - Z3_func_decl Func = - Z3_get_app_decl(Context.Context, Z3_to_app(Context.Context, Exp.AST)); + bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) override { + Z3Model Model = getModel(); + Z3_func_decl Func = Z3_get_app_decl( + Context.Context, Z3_to_app(Context.Context, toZ3Expr(*Exp).AST)); if (Z3_model_has_interp(Context.Context, Model.Model, Func) != Z3_L_TRUE) return false; - Z3_ast Assign = - Z3_model_get_const_interp(Context.Context, Model.Model, Func); - Z3Sort Sort = getSort(Assign); + SMTExprRef Assign = newExprRef( + Z3Expr(Context, + Z3_model_get_const_interp(Context.Context, Model.Model, Func))); + SMTSortRef Sort = getSort(Assign); return toAPSInt(Sort, Assign, Int, true); } /// Given an expression and a model, extract the value of this operand in /// the model. - bool getInterpretation(const Z3Model Model, const Z3Expr &Exp, - llvm::APFloat &Float) { - Z3_func_decl Func = - Z3_get_app_decl(Context.Context, Z3_to_app(Context.Context, Exp.AST)); + bool getInterpretation(const SMTExprRef &Exp, llvm::APFloat &Float) override { + Z3Model Model = getModel(); + Z3_func_decl Func = Z3_get_app_decl( + Context.Context, Z3_to_app(Context.Context, toZ3Expr(*Exp).AST)); if (Z3_model_has_interp(Context.Context, Model.Model, Func) != Z3_L_TRUE) return false; - Z3_ast Assign = - Z3_model_get_const_interp(Context.Context, Model.Model, Func); - Z3Sort Sort = getSort(Assign); + SMTExprRef Assign = newExprRef( + Z3Expr(Context, + Z3_model_get_const_interp(Context.Context, Model.Model, Func))); + SMTSortRef Sort = getSort(Assign); return toAPFloat(Sort, Assign, Float, true); } - // Callback function for doCast parameter on APSInt type. 
- llvm::APSInt castAPSInt(const llvm::APSInt &V, QualType ToTy, - uint64_t ToWidth, QualType FromTy, - uint64_t FromWidth) { - APSIntType TargetType(ToWidth, !ToTy->isSignedIntegerOrEnumerationType()); - return TargetType.convert(V); - } - /// Check if the constraints are satisfiable - Z3_lbool check() { return Z3_solver_check(Context.Context, Solver); } + ConditionTruthVal check() const override { + Z3_lbool res = Z3_solver_check(Context.Context, Solver); + if (res == Z3_L_TRUE) + return true; + + if (res == Z3_L_FALSE) + return false; + + return ConditionTruthVal(); + } /// Push the current solver state - void push() { return Z3_solver_push(Context.Context, Solver); } + void push() override { return Z3_solver_push(Context.Context, Solver); } /// Pop the previous solver state - void pop(unsigned NumStates = 1) { + void pop(unsigned NumStates = 1) override { assert(Z3_solver_get_num_scopes(Context.Context, Solver) >= NumStates); return Z3_solver_pop(Context.Context, Solver, NumStates); } @@ -925,13 +924,11 @@ } /// Reset the solver and remove all constraints. - void reset() { Z3_solver_reset(Context.Context, Solver); } + void reset() const override { Z3_solver_reset(Context.Context, Solver); } - void print(raw_ostream &OS) const { + void print(raw_ostream &OS) const override { OS << Z3_solver_to_string(Context.Context, Solver); } - - LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); } }; // end class Z3Solver class Z3ConstraintManager : public SMTConstraintManager { @@ -988,54 +985,60 @@ // Internal implementation. //===------------------------------------------------------------------===// + /// Given a program state, construct the logical conjunction and add it to + /// the solver + void addStateConstraints(ProgramStateRef State) const; + // Check whether a new model is satisfiable, and update the program state. 
ProgramStateRef assumeZ3Expr(ProgramStateRef State, SymbolRef Sym, - const Z3Expr &Exp); + const SMTExprRef &Exp); // Generate and check a Z3 model, using the given constraint. - Z3_lbool checkZ3Model(ProgramStateRef State, const Z3Expr &Exp) const; + ConditionTruthVal checkZ3Model(ProgramStateRef State, + const SMTExprRef &Exp) const; // Generate a Z3Expr that represents the given symbolic expression. // Sets the hasComparison parameter if the expression has a comparison // operator. // Sets the RetTy parameter to the final return type after promotions and // casts. - Z3Expr getZ3Expr(SymbolRef Sym, QualType *RetTy = nullptr, - bool *hasComparison = nullptr) const; + SMTExprRef getZ3Expr(SymbolRef Sym, QualType *RetTy = nullptr, + bool *hasComparison = nullptr) const; // Generate a Z3Expr that takes the logical not of an expression. - Z3Expr getZ3NotExpr(const Z3Expr &Exp) const; + SMTExprRef getZ3NotExpr(const SMTExprRef &Exp) const; // Generate a Z3Expr that compares the expression to zero. - Z3Expr getZ3ZeroExpr(const Z3Expr &Exp, QualType RetTy, - bool Assumption) const; + SMTExprRef getZ3ZeroExpr(const SMTExprRef &Exp, QualType RetTy, + bool Assumption) const; // Recursive implementation to unpack and generate symbolic expression. // Sets the hasComparison and RetTy parameters. See getZ3Expr(). - Z3Expr getZ3SymExpr(SymbolRef Sym, QualType *RetTy, - bool *hasComparison) const; + SMTExprRef getZ3SymExpr(SymbolRef Sym, QualType *RetTy, + bool *hasComparison) const; // Wrapper to generate Z3Expr from SymbolData. - Z3Expr getZ3DataExpr(const SymbolID ID, QualType Ty) const; + SMTExprRef getZ3DataExpr(const SymbolID ID, QualType Ty) const; // Wrapper to generate Z3Expr from SymbolCast. - Z3Expr getZ3CastExpr(const Z3Expr &Exp, QualType FromTy, QualType Ty) const; + SMTExprRef getZ3CastExpr(const SMTExprRef &Exp, QualType FromTy, + QualType Ty) const; // Wrapper to generate Z3Expr from BinarySymExpr. // Sets the hasComparison and RetTy parameters. See getZ3Expr(). 
- Z3Expr getZ3SymBinExpr(const BinarySymExpr *BSE, bool *hasComparison, - QualType *RetTy) const; + SMTExprRef getZ3SymBinExpr(const BinarySymExpr *BSE, bool *hasComparison, + QualType *RetTy) const; // Wrapper to generate Z3Expr from unpacked binary symbolic expression. // Sets the RetTy parameter. See getZ3Expr(). - Z3Expr getZ3BinExpr(const Z3Expr &LHS, QualType LTy, - BinaryOperator::Opcode Op, const Z3Expr &RHS, - QualType RTy, QualType *RetTy) const; + SMTExprRef getZ3BinExpr(const SMTExprRef &LHS, QualType LTy, + BinaryOperator::Opcode Op, const SMTExprRef &RHS, + QualType RTy, QualType *RetTy) const; // Wrapper to generate Z3Expr from a range. If From == To, an equality will // be created instead. - Z3Expr getZ3RangeExpr(SymbolRef Sym, const llvm::APSInt &From, - const llvm::APSInt &To, bool InRange); + SMTExprRef getZ3RangeExpr(SymbolRef Sym, const llvm::APSInt &From, + const llvm::APSInt &To, bool InRange); //===------------------------------------------------------------------===// // Helper functions. @@ -1051,33 +1054,49 @@ // Perform implicit type conversion on binary symbolic expressions. // May modify all input parameters. // TODO: Refactor to use built-in conversion functions - void doTypeConversion(Z3Expr &LHS, Z3Expr &RHS, QualType <y, + void doTypeConversion(SMTExprRef &LHS, SMTExprRef &RHS, QualType <y, QualType &RTy) const; // Perform implicit integer type conversion. // May modify all input parameters. // TODO: Refactor to use Sema::handleIntegerConversion() - template + template void doIntTypeConversion(T &LHS, QualType <y, T &RHS, QualType &RTy) const; // Perform implicit floating-point type conversion. // May modify all input parameters. 
// TODO: Refactor to use Sema::handleFloatConversion() - template + template void doFloatTypeConversion(T &LHS, QualType <y, T &RHS, QualType &RTy) const; }; // end class Z3ConstraintManager } // end anonymous namespace +void Z3ConstraintManager::addStateConstraints(ProgramStateRef State) const { + // TODO: Don't add all the constraints, only the relevant ones + ConstraintZ3Ty CZ = State->get(); + ConstraintZ3Ty::iterator I = CZ.begin(), IE = CZ.end(); + + // Construct the logical AND of all the constraints + if (I != IE) { + std::vector ASTs; + + while (I != IE) + ASTs.push_back(Solver.newExprRef(Z3Expr(I++->second))); + + Solver.addConstraint(Solver.fromNBinOp(BO_LAnd, ASTs)); + } +} + ProgramStateRef Z3ConstraintManager::assumeSym(ProgramStateRef State, SymbolRef Sym, bool Assumption) { QualType RetTy; bool hasComparison; - Z3Expr Exp = getZ3Expr(Sym, &RetTy, &hasComparison); + SMTExprRef Exp = getZ3Expr(Sym, &RetTy, &hasComparison); // Create zero comparison for implicit boolean cast, with reversed assumption if (!hasComparison && !RetTy->isBooleanType()) return assumeZ3Expr(State, Sym, getZ3ZeroExpr(Exp, RetTy, !Assumption)); @@ -1145,27 +1164,29 @@ SymbolRef Sym) { QualType RetTy; // The expression may be casted, so we cannot call getZ3DataExpr() directly - Z3Expr VarExp = getZ3Expr(Sym, &RetTy); - Z3Expr Exp = getZ3ZeroExpr(VarExp, RetTy, true); + SMTExprRef VarExp = getZ3Expr(Sym, &RetTy); + SMTExprRef Exp = getZ3ZeroExpr(VarExp, RetTy, true); + // Negate the constraint - Z3Expr NotExp = getZ3ZeroExpr(VarExp, RetTy, false); + SMTExprRef NotExp = getZ3ZeroExpr(VarExp, RetTy, false); Solver.reset(); - Solver.addStateConstraints(State); + addStateConstraints(State); Solver.push(); Solver.addConstraint(Exp); - Z3_lbool isSat = Solver.check(); + ConditionTruthVal isSat = Solver.check(); Solver.pop(); Solver.addConstraint(NotExp); - Z3_lbool isNotSat = Solver.check(); + ConditionTruthVal isNotSat = Solver.check(); // Zero is the only possible solution - if (isSat 
== Z3_L_TRUE && isNotSat == Z3_L_FALSE) + if (isSat.isConstrainedTrue() && isNotSat.isConstrainedFalse()) return true; + // Zero is not a solution - else if (isSat == Z3_L_FALSE && isNotSat == Z3_L_TRUE) + if (isSat.isConstrainedFalse() && isNotSat.isConstrainedTrue()) return false; // Zero may be a solution @@ -1183,29 +1204,31 @@ llvm::APSInt Value(Ctx.getTypeSize(Ty), !Ty->isSignedIntegerOrEnumerationType()); - Z3Expr Exp = getZ3DataExpr(SD->getSymbolID(), Ty); + SMTExprRef Exp = getZ3DataExpr(SD->getSymbolID(), Ty); Solver.reset(); - Solver.addStateConstraints(State); + addStateConstraints(State); // Constraints are unsatisfiable - if (Solver.check() != Z3_L_TRUE) + ConditionTruthVal isSat = Solver.check(); + if (!isSat.isConstrainedTrue()) return nullptr; - Z3Model Model = Solver.getModel(); // Model does not assign interpretation - if (!Solver.getInterpretation(Model, Exp, Value)) + if (!Solver.getInterpretation(Exp, Value)) return nullptr; // A value has been obtained, check if it is the only value - Z3Expr NotExp = Solver.fromBinOp( + SMTExprRef NotExp = Solver.fromBinOp( Exp, BO_NE, Ty->isBooleanType() ? 
Solver.fromBoolean(Value.getBoolValue()) : Solver.fromAPSInt(Value), false); Solver.addConstraint(NotExp); - if (Solver.check() == Z3_L_TRUE) + + ConditionTruthVal isNotSat = Solver.check(); + if (isNotSat.isConstrainedTrue()) return nullptr; // This is the only solution, store it @@ -1244,8 +1267,8 @@ QualType LTy, RTy; std::tie(ConvertedLHS, LTy) = fixAPSInt(*LHS); std::tie(ConvertedRHS, RTy) = fixAPSInt(*RHS); - doIntTypeConversion(ConvertedLHS, LTy, - ConvertedRHS, RTy); + doIntTypeConversion( + ConvertedLHS, LTy, ConvertedRHS, RTy); return BVF.evalAPSInt(BSE->getOpcode(), ConvertedLHS, ConvertedRHS); } @@ -1272,9 +1295,9 @@ for (const auto &I : CR) { SymbolRef Sym = I.first; - Z3Expr Constraints = Solver.fromBoolean(false); + SMTExprRef Constraints = Solver.fromBoolean(false); for (const auto &Range : I.second) { - Z3Expr SymRange = + SMTExprRef SymRange = getZ3RangeExpr(Sym, Range.From(), Range.To(), /*InRange=*/true); // FIXME: the last argument (isSigned) is not used when generating the @@ -1287,10 +1310,7 @@ } clang::ento::ConditionTruthVal Z3ConstraintManager::isModelFeasible() { - if (Solver.check() == Z3_L_FALSE) - return false; - - return ConditionTruthVal(); + return Solver.check(); } LLVM_DUMP_METHOD void Z3ConstraintManager::dump() const { Solver.dump(); } @@ -1301,24 +1321,25 @@ ProgramStateRef Z3ConstraintManager::assumeZ3Expr(ProgramStateRef State, SymbolRef Sym, - const Z3Expr &Exp) { + const SMTExprRef &Exp) { // Check the model, avoid simplifying AST to save time - if (checkZ3Model(State, Exp) == Z3_L_TRUE) - return State->add(std::make_pair(Sym, Exp)); + if (checkZ3Model(State, Exp).isConstrainedTrue()) + return State->add(std::make_pair(Sym, toZ3Expr(*Exp))); return nullptr; } -Z3_lbool Z3ConstraintManager::checkZ3Model(ProgramStateRef State, - const Z3Expr &Exp) const { +ConditionTruthVal +Z3ConstraintManager::checkZ3Model(ProgramStateRef State, + const SMTExprRef &Exp) const { Solver.reset(); Solver.addConstraint(Exp); - 
Solver.addStateConstraints(State); + addStateConstraints(State); return Solver.check(); } -Z3Expr Z3ConstraintManager::getZ3Expr(SymbolRef Sym, QualType *RetTy, - bool *hasComparison) const { +SMTExprRef Z3ConstraintManager::getZ3Expr(SymbolRef Sym, QualType *RetTy, + bool *hasComparison) const { if (hasComparison) { *hasComparison = false; } @@ -1326,12 +1347,13 @@ return getZ3SymExpr(Sym, RetTy, hasComparison); } -Z3Expr Z3ConstraintManager::getZ3NotExpr(const Z3Expr &Exp) const { +SMTExprRef Z3ConstraintManager::getZ3NotExpr(const SMTExprRef &Exp) const { return Solver.fromUnOp(UO_LNot, Exp); } -Z3Expr Z3ConstraintManager::getZ3ZeroExpr(const Z3Expr &Exp, QualType Ty, - bool Assumption) const { +SMTExprRef Z3ConstraintManager::getZ3ZeroExpr(const SMTExprRef &Exp, + QualType Ty, + bool Assumption) const { ASTContext &Ctx = getBasicVals().getContext(); if (Ty->isRealFloatingType()) { llvm::APFloat Zero = llvm::APFloat::getZero(Ctx.getFloatTypeSemantics(Ty)); @@ -1350,8 +1372,8 @@ llvm_unreachable("Unsupported type for zero value!"); } -Z3Expr Z3ConstraintManager::getZ3SymExpr(SymbolRef Sym, QualType *RetTy, - bool *hasComparison) const { +SMTExprRef Z3ConstraintManager::getZ3SymExpr(SymbolRef Sym, QualType *RetTy, + bool *hasComparison) const { if (const SymbolData *SD = dyn_cast(Sym)) { if (RetTy) *RetTy = Sym->getType(); @@ -1362,7 +1384,7 @@ *RetTy = Sym->getType(); QualType FromTy; - Z3Expr Exp = getZ3SymExpr(SC->getOperand(), &FromTy, hasComparison); + SMTExprRef Exp = getZ3SymExpr(SC->getOperand(), &FromTy, hasComparison); // Casting an expression with a comparison invalidates it. Note that this // must occur after the recursive call above. // e.g. 
(signed char) (x > 0) @@ -1370,7 +1392,7 @@ *hasComparison = false; return getZ3CastExpr(Exp, FromTy, Sym->getType()); } else if (const BinarySymExpr *BSE = dyn_cast(Sym)) { - Z3Expr Exp = getZ3SymBinExpr(BSE, hasComparison, RetTy); + SMTExprRef Exp = getZ3SymBinExpr(BSE, hasComparison, RetTy); // Set the hasComparison parameter, in post-order traversal order. if (hasComparison) *hasComparison = BinaryOperator::isComparisonOp(BSE->getOpcode()); @@ -1380,52 +1402,52 @@ llvm_unreachable("Unsupported SymbolRef type!"); } -Z3Expr Z3ConstraintManager::getZ3DataExpr(const SymbolID ID, - QualType Ty) const { +SMTExprRef Z3ConstraintManager::getZ3DataExpr(const SymbolID ID, + QualType Ty) const { ASTContext &Ctx = getBasicVals().getContext(); return Solver.fromData(ID, Ty, Ctx.getTypeSize(Ty)); } -Z3Expr Z3ConstraintManager::getZ3CastExpr(const Z3Expr &Exp, QualType FromTy, - QualType ToTy) const { +SMTExprRef Z3ConstraintManager::getZ3CastExpr(const SMTExprRef &Exp, + QualType FromTy, + QualType ToTy) const { ASTContext &Ctx = getBasicVals().getContext(); return Solver.fromCast(Exp, ToTy, Ctx.getTypeSize(ToTy), FromTy, Ctx.getTypeSize(FromTy)); } -Z3Expr Z3ConstraintManager::getZ3SymBinExpr(const BinarySymExpr *BSE, - bool *hasComparison, - QualType *RetTy) const { +SMTExprRef Z3ConstraintManager::getZ3SymBinExpr(const BinarySymExpr *BSE, + bool *hasComparison, + QualType *RetTy) const { QualType LTy, RTy; BinaryOperator::Opcode Op = BSE->getOpcode(); if (const SymIntExpr *SIE = dyn_cast(BSE)) { - Z3Expr LHS = getZ3SymExpr(SIE->getLHS(), <y, hasComparison); + SMTExprRef LHS = getZ3SymExpr(SIE->getLHS(), <y, hasComparison); llvm::APSInt NewRInt; std::tie(NewRInt, RTy) = fixAPSInt(SIE->getRHS()); - Z3Expr RHS = Solver.fromAPSInt(NewRInt); + SMTExprRef RHS = Solver.fromAPSInt(NewRInt); return getZ3BinExpr(LHS, LTy, Op, RHS, RTy, RetTy); } else if (const IntSymExpr *ISE = dyn_cast(BSE)) { llvm::APSInt NewLInt; std::tie(NewLInt, LTy) = fixAPSInt(ISE->getLHS()); - Z3Expr LHS = 
Solver.fromAPSInt(NewLInt); - Z3Expr RHS = getZ3SymExpr(ISE->getRHS(), &RTy, hasComparison); + SMTExprRef LHS = Solver.fromAPSInt(NewLInt); + SMTExprRef RHS = getZ3SymExpr(ISE->getRHS(), &RTy, hasComparison); return getZ3BinExpr(LHS, LTy, Op, RHS, RTy, RetTy); } else if (const SymSymExpr *SSM = dyn_cast(BSE)) { - Z3Expr LHS = getZ3SymExpr(SSM->getLHS(), <y, hasComparison); - Z3Expr RHS = getZ3SymExpr(SSM->getRHS(), &RTy, hasComparison); + SMTExprRef LHS = getZ3SymExpr(SSM->getLHS(), <y, hasComparison); + SMTExprRef RHS = getZ3SymExpr(SSM->getRHS(), &RTy, hasComparison); return getZ3BinExpr(LHS, LTy, Op, RHS, RTy, RetTy); } else { llvm_unreachable("Unsupported BinarySymExpr type!"); } } -Z3Expr Z3ConstraintManager::getZ3BinExpr(const Z3Expr &LHS, QualType LTy, - BinaryOperator::Opcode Op, - const Z3Expr &RHS, QualType RTy, - QualType *RetTy) const { - Z3Expr NewLHS = LHS; - Z3Expr NewRHS = RHS; +SMTExprRef Z3ConstraintManager::getZ3BinExpr( + const SMTExprRef &LHS, QualType LTy, BinaryOperator::Opcode Op, + const SMTExprRef &RHS, QualType RTy, QualType *RetTy) const { + SMTExprRef NewLHS = LHS; + SMTExprRef NewRHS = RHS; doTypeConversion(NewLHS, NewRHS, LTy, RTy); // Update the return type parameter if the output type has changed. 
if (RetTy) { @@ -1452,19 +1474,19 @@ LTy->isSignedIntegerOrEnumerationType()); } -Z3Expr Z3ConstraintManager::getZ3RangeExpr(SymbolRef Sym, - const llvm::APSInt &From, - const llvm::APSInt &To, - bool InRange) { +SMTExprRef Z3ConstraintManager::getZ3RangeExpr(SymbolRef Sym, + const llvm::APSInt &From, + const llvm::APSInt &To, + bool InRange) { // Convert lower bound QualType FromTy; llvm::APSInt NewFromInt; std::tie(NewFromInt, FromTy) = fixAPSInt(From); - Z3Expr FromExp = Solver.fromAPSInt(NewFromInt); + SMTExprRef FromExp = Solver.fromAPSInt(NewFromInt); // Convert symbol QualType SymTy; - Z3Expr Exp = getZ3Expr(Sym, &SymTy); + SMTExprRef Exp = getZ3Expr(Sym, &SymTy); // Construct single (in)equality if (From == To) @@ -1474,14 +1496,15 @@ QualType ToTy; llvm::APSInt NewToInt; std::tie(NewToInt, ToTy) = fixAPSInt(To); - Z3Expr ToExp = Solver.fromAPSInt(NewToInt); + SMTExprRef ToExp = Solver.fromAPSInt(NewToInt); assert(FromTy == ToTy && "Range values have different types!"); // Construct two (in)equalities, and a logical and/or - Z3Expr LHS = getZ3BinExpr(Exp, SymTy, InRange ? BO_GE : BO_LT, FromExp, - FromTy, /*RetTy=*/nullptr); - Z3Expr RHS = getZ3BinExpr(Exp, SymTy, InRange ? BO_LE : BO_GT, ToExp, ToTy, - /*RetTy=*/nullptr); + SMTExprRef LHS = getZ3BinExpr(Exp, SymTy, InRange ? BO_GE : BO_LT, FromExp, + FromTy, /*RetTy=*/nullptr); + SMTExprRef RHS = + getZ3BinExpr(Exp, SymTy, InRange ? BO_LE : BO_GT, ToExp, ToTy, + /*RetTy=*/nullptr); return Solver.fromBinOp(LHS, InRange ? 
BO_LAnd : BO_LOr, RHS, SymTy->isSignedIntegerOrEnumerationType()); @@ -1512,7 +1535,7 @@ return std::make_pair(NewInt, getAPSIntType(NewInt)); } -void Z3ConstraintManager::doTypeConversion(Z3Expr &LHS, Z3Expr &RHS, +void Z3ConstraintManager::doTypeConversion(SMTExprRef &LHS, SMTExprRef &RHS, QualType <y, QualType &RTy) const { ASTContext &Ctx = getBasicVals().getContext(); @@ -1520,12 +1543,13 @@ // Perform type conversion if (LTy->isIntegralOrEnumerationType() && RTy->isIntegralOrEnumerationType()) { - if (LTy->isArithmeticType() && RTy->isArithmeticType()) - return doIntTypeConversion(LHS, LTy, RHS, - RTy); + if (LTy->isArithmeticType() && RTy->isArithmeticType()) { + doIntTypeConversion(LHS, LTy, RHS, RTy); + return; + } } else if (LTy->isRealFloatingType() || RTy->isRealFloatingType()) { - return doFloatTypeConversion(LHS, LTy, RHS, - RTy); + doFloatTypeConversion(LHS, LTy, RHS, RTy); + return; } else if ((LTy->isAnyPointerType() || RTy->isAnyPointerType()) || (LTy->isBlockPointerType() || RTy->isBlockPointerType()) || (LTy->isReferenceType() || RTy->isReferenceType())) { @@ -1576,8 +1600,8 @@ // TODO: Refine behavior for invalid type casts } -template +template void Z3ConstraintManager::doIntTypeConversion(T &LHS, QualType <y, T &RHS, QualType &RTy) const { ASTContext &Ctx = getBasicVals().getContext(); @@ -1654,8 +1678,8 @@ } } -template +template void Z3ConstraintManager::doFloatTypeConversion(T &LHS, QualType <y, T &RHS, QualType &RTy) const { ASTContext &Ctx = getBasicVals().getContext();