Index: include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h
===================================================================
--- include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h
+++ include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h
@@ -16,27 +16,309 @@
 #define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTCONSTRAINTMANAGER_H
 
 #include "clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTSolver.h"
 
 namespace clang {
 namespace ento {
 
 class SMTConstraintManager : public clang::ento::SimpleConstraintManager {
+protected:
+  // Non-owning reference to the solver backend; the caller must keep the
+  // solver alive for as long as this constraint manager exists.
+  SMTSolver &Solver;
 
 public:
-  SMTConstraintManager(clang::ento::SubEngine *SE, clang::ento::SValBuilder &SB)
-      : SimpleConstraintManager(SE, SB) {}
+  // The solver is taken by reference: binding the reference member to a
+  // by-value parameter would leave it dangling once the constructor returns
+  // (and would slice a derived solver down to the SMTSolver base).
+  SMTConstraintManager(clang::ento::SubEngine *SE, clang::ento::SValBuilder &SB,
+                       SMTSolver &S)
+      : SimpleConstraintManager(SE, SB), Solver(S) {}
+
   virtual ~SMTConstraintManager() = default;
 
-  /// Converts the ranged constraints of a set of symbols to SMT
-  ///
-  /// \param CR The set of constraints.
-  virtual void addRangeConstraints(clang::ento::ConstraintRangeTy CR) = 0;
+  /// Converts a ranged constraints to SMT
+  void addRangeConstraints(clang::ento::ConstraintRangeTy CR);
 
   /// Checks if the added constraints are satisfiable
-  virtual clang::ento::ConditionTruthVal isModelFeasible() = 0;
+  clang::ento::ConditionTruthVal isModelFeasible();
+
+  /// Construct a SMTExpr from a unary operator.
+  SMTExpr fromUnOp(const UnaryOperator::Opcode Op, const SMTExpr &Exp) const;
+
+  /// Construct a SMTExpr from a floating-point unary operator.
+  SMTExpr fromFloatUnOp(const UnaryOperator::Opcode Op,
+                        const SMTExpr &Exp) const;
+
+  /// Construct a SMTExpr from a n-ary binary operator.
+  SMTExpr fromNBinOp(const BinaryOperator::Opcode Op,
+                     const std::vector<SMTExpr> &ASTs) const;
+
+  /// Construct a SMTExpr from a binary operator, given a SMT_context.
+  SMTExpr fromBinOp(const SMTExpr &LHS, const BinaryOperator::Opcode Op,
+                    const SMTExpr &RHS, bool isSigned) const;
+
+  /// Construct a SMTExpr from a special floating-point binary operator, given
+  /// a SMT_context.
+  SMTExpr fromFloatSpecialBinOp(const SMTExpr &LHS,
+                                const BinaryOperator::Opcode Op,
+                                const llvm::APFloat::fltCategory &RHS) const;
+
+  /// Construct a SMTExpr from a floating-point binary operator, given a
+  /// SMT_context.
+  SMTExpr fromFloatBinOp(const SMTExpr &LHS, const BinaryOperator::Opcode Op,
+                         const SMTExpr &RHS) const;
+
+  /// Construct a SMTExpr from a SymbolCast, given a SMT_context.
+  SMTExpr fromCast(const SMTExpr &Exp, QualType ToTy, uint64_t ToBitWidth,
+                   QualType FromTy, uint64_t FromBitWidth) const;
+
+  /// Construct a solver symbol for the symbol \p ID, using the sort that
+  /// matches \p Ty and \p BitWidth.
+  SMTExpr fromData(const SymbolID ID, const QualType &Ty,
+                   uint64_t BitWidth) const;
+
+  /// Construct an APFloat from a SMTExpr, given the AST representation
+  bool toAPFloat(const SMTSort &Sort, const SMTExpr &AST, llvm::APFloat &Float,
+                 bool useSemantics = true) const;
+
+  /// Construct an APSInt from a SMTExpr, given the AST representation
+  bool toAPSInt(const SMTSort &Sort, const SMTExpr &AST, llvm::APSInt &Int,
+                bool useSemantics = true) const;
+
+  // Callback function for doCast parameter on APSInt type.
+  llvm::APSInt castAPSInt(const llvm::APSInt &V, QualType ToTy,
+                          uint64_t ToWidth, QualType FromTy,
+                          uint64_t FromWidth) const;
+
+  ConditionTruthVal checkNull(ProgramStateRef State, SymbolRef Sym) override;
+
+  const llvm::APSInt *getSymVal(ProgramStateRef State,
+                                SymbolRef Sym) const override;
+
+  ProgramStateRef removeDeadBindings(ProgramStateRef St,
+                                     SymbolReaper &SymReaper) override;
+
+  void print(ProgramStateRef St, raw_ostream &Out, const char *nl,
+             const char *sep) override;
 
   /// Dumps SMT formula
-  LLVM_DUMP_METHOD virtual void dump() const = 0;
+  LLVM_DUMP_METHOD void dump() const { Solver.dump(); }
+
+  //===------------------------------------------------------------------===//
+  // Implementation for interface from SimpleConstraintManager.
+  //===------------------------------------------------------------------===//
+
+  ProgramStateRef assumeSymUnsupported(ProgramStateRef State, SymbolRef Sym,
+                                       bool Assumption) override;
+
+  ProgramStateRef assumeSym(ProgramStateRef state, SymbolRef Sym,
+                            bool Assumption) override;
+
+  ProgramStateRef assumeSymInclusiveRange(ProgramStateRef State, SymbolRef Sym,
+                                          const llvm::APSInt &From,
+                                          const llvm::APSInt &To,
+                                          bool InRange) override;
+
+  //===------------------------------------------------------------------===//
+  // Implementation for interface from ConstraintManager.
+  //===------------------------------------------------------------------===//
+
+  bool canReasonAbout(SVal X) const override;
+
+private:
+  //===------------------------------------------------------------------===//
+  // Internal implementation.
+  //===------------------------------------------------------------------===//
+
+  /// Given a program state, construct the logical conjunction and add it to
+  /// the solver
+  void addStateConstraints(ProgramStateRef State) const;
+
+  // Check whether a new model is satisfiable, and update the program state.
+  ProgramStateRef assumeSMTExpr(ProgramStateRef State, SymbolRef Sym,
+                                const SMTExpr &Exp);
+
+  // Generate and check a Z3 model, using the given constraint.
+  bool checkModel(ProgramStateRef State, const SMTExpr &Exp);
+
+  // Generate a SMTExpr that represents the given symbolic expression.
+  // Sets the hasComparison parameter if the expression has a comparison
+  // operator.
+  // Sets the RetTy parameter to the final return type after promotions and
+  // casts.
+  SMTExpr getSMTExpr(SymbolRef Sym, QualType *RetTy = nullptr,
+                     bool *hasComparison = nullptr) const;
+
+  // Generate a SMTExpr that takes the logical not of an expression.
+  SMTExpr getSMTNotExpr(const SMTExpr &Exp) const;
+
+  // Generate a SMTExpr that compares the expression to zero.
+  SMTExpr getSMTZeroExpr(const SMTExpr &Exp, QualType RetTy,
+                         bool Assumption) const;
+
+  // Recursive implementation to unpack and generate symbolic expression.
+  // Sets the hasComparison and RetTy parameters. See getSMTExpr().
+  SMTExpr getSMTSymExpr(SymbolRef Sym, QualType *RetTy,
+                        bool *hasComparison) const;
+
+  // Wrapper to generate SMTExpr from SymbolData.
+  SMTExpr getSMTDataExpr(const SymbolID ID, QualType Ty) const;
+
+  // Wrapper to generate SMTExpr from SymbolCast.
+  SMTExpr getSMTCastExpr(const SMTExpr &Exp, QualType FromTy,
+                         QualType Ty) const;
+
+  // Wrapper to generate SMTExpr from BinarySymExpr.
+  // Sets the hasComparison and RetTy parameters. See getSMTExpr().
+  SMTExpr getSMTSymBinExpr(const BinarySymExpr *BSE, bool *hasComparison,
+                           QualType *RetTy) const;
+
+  // Wrapper to generate SMTExpr from unpacked binary symbolic expression.
+  // Sets the RetTy parameter. See getSMTExpr().
+  SMTExpr getSMTBinExpr(const SMTExpr &LHS, QualType LTy,
+                        BinaryOperator::Opcode Op, const SMTExpr &RHS,
+                        QualType RTy, QualType *RetTy) const;
+
+  // Wrapper to generate SMTExpr from a range. If From == To, an equality will
+  // be created instead.
+  SMTExpr getSMTRangeExpr(SymbolRef Sym, const llvm::APSInt &From,
+                          const llvm::APSInt &To, bool InRange);
+
+  //===------------------------------------------------------------------===//
+  // Helper functions.
+  //===------------------------------------------------------------------===//
+
+  // Recover the QualType of an APSInt.
+  // TODO: Refactor to put elsewhere
+  QualType getAPSIntType(const llvm::APSInt &Int) const;
+
+  // Get the QualTy for the input APSInt, and fix it if it has a bitwidth of 1.
+  std::pair<llvm::APSInt, QualType> fixAPSInt(const llvm::APSInt &Int) const;
+
+  // Perform implicit type conversion on binary symbolic expressions.
+  // May modify all input parameters.
+  // TODO: Refactor to use built-in conversion functions
+  void doTypeConversion(SMTExpr &LHS, SMTExpr &RHS, QualType &LTy,
+                        QualType &RTy) const;
+
+  // Perform implicit integer type conversion.
+  // May modify all input parameters.
+  // doCast is the member function used to convert an operand of type T; its
+  // arguments are (value, ToTy, ToBitWidth, FromTy, FromBitWidth), matching
+  // fromCast() and castAPSInt().
+  // TODO: Refactor to use Sema::handleIntegerConversion()
+  template <typename T,
+            T (SMTConstraintManager::*doCast)(const T &, QualType, uint64_t,
+                                              QualType, uint64_t) const>
+  void doIntTypeConversion(T &LHS, QualType &LTy, T &RHS, QualType &RTy) const {
+    // Validate the inputs before they are used to query the type sizes.
+    assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!");
+
+    ASTContext &Ctx = getBasicVals().getContext();
+    uint64_t LBitWidth = Ctx.getTypeSize(LTy);
+    uint64_t RBitWidth = Ctx.getTypeSize(RTy);
+
+    // Always perform integer promotion before checking type equality.
+    // Otherwise, e.g. (bool) a + (bool) b could trigger a backend assertion
+    if (LTy->isPromotableIntegerType()) {
+      QualType NewTy = Ctx.getPromotedIntegerType(LTy);
+      uint64_t NewBitWidth = Ctx.getTypeSize(NewTy);
+      LHS = (this->*doCast)(LHS, NewTy, NewBitWidth, LTy, LBitWidth);
+      LTy = NewTy;
+      LBitWidth = NewBitWidth;
+    }
+    if (RTy->isPromotableIntegerType()) {
+      QualType NewTy = Ctx.getPromotedIntegerType(RTy);
+      uint64_t NewBitWidth = Ctx.getTypeSize(NewTy);
+      RHS = (this->*doCast)(RHS, NewTy, NewBitWidth, RTy, RBitWidth);
+      RTy = NewTy;
+      RBitWidth = NewBitWidth;
+    }
+
+    if (LTy == RTy)
+      return;
+
+    // Perform integer type conversion
+    // Note: Safe to skip updating bitwidth because this must terminate
+    bool isLSignedTy = LTy->isSignedIntegerOrEnumerationType();
+    bool isRSignedTy = RTy->isSignedIntegerOrEnumerationType();
+
+    int order = Ctx.getIntegerTypeOrder(LTy, RTy);
+    if (isLSignedTy == isRSignedTy) {
+      // Same signedness; use the higher-ranked type
+      if (order == 1) {
+        RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
+        RTy = LTy;
+      } else {
+        LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
+        LTy = RTy;
+      }
+    } else if (order != (isLSignedTy ? 1 : -1)) {
+      // The unsigned type has greater than or equal rank to the
+      // signed type, so use the unsigned type
+      if (isRSignedTy) {
+        RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
+        RTy = LTy;
+      } else {
+        LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
+        LTy = RTy;
+      }
+    } else if (LBitWidth != RBitWidth) {
+      // The two types are different widths; if we are here, that
+      // means the signed type is larger than the unsigned type, so
+      // use the signed type.
+      if (isLSignedTy) {
+        RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
+        RTy = LTy;
+      } else {
+        LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
+        LTy = RTy;
+      }
+    } else {
+      // The signed type is higher-ranked than the unsigned type,
+      // but isn't actually any bigger (like unsigned int and long
+      // on most 32-bit systems). Use the unsigned type corresponding
+      // to the signed type.
+      // Both operands must end up as NewTy. The widths are equal in this
+      // branch (LBitWidth == RBitWidth), so only the target type matters.
+      QualType NewTy =
+          Ctx.getCorrespondingUnsignedType(isLSignedTy ? LTy : RTy);
+      RHS = (this->*doCast)(RHS, NewTy, RBitWidth, RTy, RBitWidth);
+      RTy = NewTy;
+      LHS = (this->*doCast)(LHS, NewTy, LBitWidth, LTy, LBitWidth);
+      LTy = NewTy;
+    }
+  }
+
+  // Perform implicit floating-point type conversion.
+  // May modify all input parameters.
+  // doCast has the same meaning as in doIntTypeConversion().
+  // TODO: Refactor to use Sema::handleFloatConversion()
+  template <typename T,
+            T (SMTConstraintManager::*doCast)(const T &, QualType, uint64_t,
+                                              QualType, uint64_t) const>
+  void doFloatTypeConversion(T &LHS, QualType &LTy, T &RHS,
+                             QualType &RTy) const {
+    ASTContext &Ctx = getBasicVals().getContext();
+
+    uint64_t LBitWidth = Ctx.getTypeSize(LTy);
+    uint64_t RBitWidth = Ctx.getTypeSize(RTy);
+
+    // Perform float-point type promotion
+    if (!LTy->isRealFloatingType()) {
+      LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
+      LTy = RTy;
+      LBitWidth = RBitWidth;
+    }
+    if (!RTy->isRealFloatingType()) {
+      RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
+      RTy = LTy;
+      RBitWidth = LBitWidth;
+    }
+
+    if (LTy == RTy)
+      return;
+
+    // If we have two real floating types, convert the smaller operand to the
+    // bigger result
+    // Note: Safe to skip updating bitwidth because this must terminate
+    int order = Ctx.getFloatingTypeOrder(LTy, RTy);
+    if (order > 0) {
+      RHS = (this->*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth);
+      RTy = LTy;
+    } else {
+      // order <= 0: the RHS type ranks at least as high as the LHS type
+      // (e.g. float op double), so the LHS is the operand to convert.
+      LHS = (this->*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth);
+      LTy = RTy;
+    }
+  }
 }; // end class SMTConstraintManager
 
 } // namespace ento
Index: 
include/clang/StaticAnalyzer/Core/PathSensitive/SMTExpr.h
===================================================================
--- /dev/null
+++ include/clang/StaticAnalyzer/Core/PathSensitive/SMTExpr.h
@@ -0,0 +1,77 @@
+//== SMTExpr.h --------------------------------------------------*- C++ -*--==//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a SMT generic Expr API, which will be the base class
+// for every SMT solver expr specific class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTEXPR_H
+#define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTEXPR_H
+
+#include "clang/StaticAnalyzer/Core/PathSensitive/ProgramStateTrait.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTContext.h"
+
+namespace clang {
+namespace ento {
+
+/// Generic SMT expression. The base implementation carries no solver data:
+/// it profiles to a fixed tag, never compares equal, and prints nothing.
+class SMTExpr {
+public:
+  SMTExpr() = default;
+  virtual ~SMTExpr() = default;
+
+  /// Orders two expressions by comparing their folding-set profiles.
+  bool operator<(const SMTExpr &Other) const {
+    llvm::FoldingSetNodeID ID1, ID2;
+    Profile(ID1);
+    Other.Profile(ID2);
+    return ID1 < ID2;
+  }
+
+  /// Adds the identity of this expression to \p ID. The base class only
+  /// contributes the address of a function-local tag.
+  virtual void Profile(llvm::FoldingSetNodeID &ID) const {
+    static int Tag = 0;
+    ID.AddPointer(&Tag);
+  }
+
+  /// Equality is delegated to the virtual equal_to() hook.
+  friend bool operator==(SMTExpr const &LHS, SMTExpr const &RHS) {
+    return LHS.equal_to(RHS);
+  }
+
+  virtual void print(raw_ostream &OS) const {}
+
+protected:
+  virtual bool equal_to(SMTExpr const &other) const { return false; }
+};
+
+/// Expression wrapper that stores a solver-specific handle of type \p E
+/// together with a reference to the SMTContext it was created with.
+template <typename E> class SMTSolverExpr : public SMTExpr {
+public:
+  SMTSolverExpr(SMTContext &C, E Expr) : SMTExpr(), Context(C), AST(Expr) {}
+  virtual ~SMTSolverExpr() = default;
+
+  /// Returns a copy of the wrapped solver handle.
+  E getExpr() const { return AST; }
+
+  void Profile(llvm::FoldingSetNodeID &ID) const override {
+    static int Tag = 0;
+    ID.AddPointer(&Tag);
+  }
+
+  // Concrete solver expressions must provide a real comparison.
+  bool equal_to(SMTExpr const &other) const override = 0;
+
+  void print(raw_ostream &OS) const override {}
+
+  LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); }
+
+protected:
+  SMTContext &Context;
+  E AST;
+};
+
+} // namespace ento
+} // namespace clang
+
+#endif
Index: include/clang/StaticAnalyzer/Core/PathSensitive/SMTSolver.h
===================================================================
--- /dev/null
+++ include/clang/StaticAnalyzer/Core/PathSensitive/SMTSolver.h
@@ -0,0 +1,410 @@
+//== SMTSolver.h ------------------------------------------------*- C++ -*--==//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a SMT generic Solver API, which will be the base class
+// for every SMT solver specific class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSOLVER_H
+#define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSOLVER_H
+
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTContext.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTExpr.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTSort.h"
+
+namespace clang {
+namespace ento {
+
+class SMTSolver {
+public:
+  SMTSolver() = default;
+  virtual ~SMTSolver() = default;
+
+  LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); }
+
+  // Return an appropriate sort, given a QualType
+  SMTSort mkSort(const QualType &Ty, unsigned BitWidth) {
+    if (Ty->isBooleanType())
+      return getBoolSort();
+
+    if (Ty->isRealFloatingType())
+      return getFloatSort(BitWidth);
+
+    return getBitvectorSort(BitWidth);
+  }
+
+  // Return a boolean sort.
+  virtual SMTSort getBoolSort() {
+    // Like every virtual stub below, the base implementation aborts through
+    // report_fatal_error() unless a derived solver overrides it.
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Return an appropriate bitvector sort for the given bitwidth.
+  virtual SMTSort getBitvectorSort(unsigned BitWidth) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Return an appropriate floating-point sort for the given bitwidth.
+  // Only widths 16, 32, 64 and 128 are dispatched; any other width reaches
+  // the llvm_unreachable below.
+  SMTSort getFloatSort(unsigned BitWidth) {
+    switch (BitWidth) {
+    case 16:
+      return getFloat16Sort();
+    case 32:
+      return getFloat32Sort();
+    case 64:
+      return getFloat64Sort();
+    case 128:
+      return getFloat128Sort();
+    default:;
+    }
+    llvm_unreachable("Unsupported floating-point bitwidth!");
+  }
+
+  // Return a floating-point sort of width 16
+  virtual SMTSort getFloat16Sort() {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Return a floating-point sort of width 32
+  virtual SMTSort getFloat32Sort() {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Return a floating-point sort of width 64
+  virtual SMTSort getFloat64Sort() {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Return a floating-point sort of width 128
+  virtual SMTSort getFloat128Sort() {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Return an appropriate sort for the given AST.
+  virtual SMTSort getSort(SMTExpr AST) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Given a constraint, add it to the solver
+  virtual void addConstraint(const SMTExpr &Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector addition operation
+  virtual SMTExpr mkBVAdd(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector subtraction operation
+  virtual SMTExpr mkBVSub(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector multiplication operation
+  virtual SMTExpr mkBVMul(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector signed modulus operation
+  virtual SMTExpr mkBVSRem(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector unsigned modulus operation
+  virtual SMTExpr mkBVURem(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector signed division operation
+  virtual SMTExpr mkBVSDiv(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector unsigned division operation
+  virtual SMTExpr mkBVUDiv(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector logical shift left operation
+  virtual SMTExpr mkBVShl(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector arithmetic shift right operation
+  virtual SMTExpr mkBVAshr(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector logical shift right operation
+  virtual SMTExpr mkBVLshr(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector negation operation
+  virtual SMTExpr mkBVNeg(SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector not operation
+  virtual SMTExpr mkBVNot(SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector xor operation
+  virtual SMTExpr mkBVXor(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector or operation
+  virtual SMTExpr mkBVOr(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector and operation
+  virtual SMTExpr mkBVAnd(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector unsigned less-than operation
+  virtual SMTExpr mkBVUlt(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector signed less-than operation
+  virtual SMTExpr mkBVSlt(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector unsigned greater-than operation
+  virtual SMTExpr mkBVUgt(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector signed greater-than operation
+  virtual SMTExpr mkBVSgt(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector unsigned less-equal-than operation
+  virtual SMTExpr mkBVUle(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector signed less-equal-than operation
+  virtual SMTExpr mkBVSle(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector unsigned greater-equal-than operation
+  virtual SMTExpr mkBVUge(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector signed greater-equal-than operation
+  virtual SMTExpr mkBVSge(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a boolean not operation
+  virtual SMTExpr mkNot(SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a bitvector equality operation
+  virtual SMTExpr mkEqual(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a boolean and operation (the constraint manager uses it for
+  /// BO_LAnd)
+  virtual SMTExpr mkAnd(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a boolean or operation (the constraint manager uses it for
+  /// BO_LOr)
+  virtual SMTExpr mkOr(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // NOTE(review): presumably an if-then-else term selecting T or F on Cond;
+  // no caller is visible in this part of the patch to confirm.
+  virtual SMTExpr mkIte(SMTExpr Cond, SMTExpr T, SMTExpr F) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // NOTE(review): presumably sign-extends Exp by i bits; confirm against the
+  // backend implementation.
+  virtual SMTExpr mkSignExt(unsigned i, SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // NOTE(review): presumably zero-extends Exp by i bits; confirm against the
+  // backend implementation.
+  virtual SMTExpr mkZeroExt(unsigned i, SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // NOTE(review): presumably extracts bits High..Low of Exp; whether the
+  // bounds are inclusive is not shown here.
+  virtual SMTExpr mkExtract(unsigned High, unsigned Low, SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // NOTE(review): presumably concatenates two bitvectors; operand order is
+  // not shown here.
+  virtual SMTExpr mkConcat(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a floating-point negation operation
+  virtual SMTExpr mkFPNeg(SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Floating-point classification predicates. The constraint manager maps a
+  // comparison against an APFloat category (fcInfinity, fcNaN, fcNormal,
+  // fcZero) onto the matching predicate below.
+  virtual SMTExpr mkFPIsInfinite(SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPIsNaN(SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPIsNormal(SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPIsZero(SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Floating-point arithmetic: the constraint manager uses these for
+  // BO_Mul, BO_Div, BO_Rem, BO_Add and BO_Sub on floating-point operands.
+  virtual SMTExpr mkFPMul(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPDiv(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPRem(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPAdd(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPSub(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Floating-point comparisons (less-than, greater-than, less-equal,
+  // greater-equal, equal).
+  virtual SMTExpr mkFPLt(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPGt(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPLe(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPGe(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPEqual(SMTExpr LHS, SMTExpr RHS) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Conversions between floating-point and bitvector expressions.
+  // NOTE(review): the names suggest FP->FP, FP->signed/unsigned bitvector and
+  // signed/unsigned bitvector->FP; rounding behaviour is not shown here.
+  virtual SMTExpr mkFPtoFP(SMTExpr From, SMTSort To) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPtoSBV(SMTExpr From, SMTSort To) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkFPtoUBV(SMTExpr From, SMTSort To) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkSBVtoFP(SMTExpr From, unsigned ToWidth) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  virtual SMTExpr mkUBVtoFP(SMTExpr From, unsigned ToWidth) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Create a symbol with the given name and sort
+  virtual SMTExpr mkSymbol(const char *Name, SMTSort Sort) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Return an appropriate floating-point rounding mode.
+  virtual SMTExpr getFloatRoundingMode() {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // NOTE(review): presumably reads back the constant value of a bitvector
+  // expression; confirm against the backend implementation.
+  virtual const llvm::APSInt getBitvector(SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // NOTE(review): presumably reads back the constant value of a boolean
+  // expression; confirm against the backend implementation.
+  virtual bool getBoolean(SMTExpr Exp) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Construct a SMTExpr from a boolean.
+  virtual SMTExpr mkBoolean(const bool b) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Construct a SMTExpr from a finite APFloat.
+  virtual SMTExpr mkFloat(const llvm::APFloat Float) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Construct a SMTExpr from an APSInt.
+  virtual SMTExpr mkBitvector(const llvm::APSInt Int, unsigned BitWidth) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  // Convenience overload: forwards to the virtual two-argument overload
+  // using the APSInt's own bit width.
+  SMTExpr mkBitvector(const llvm::APSInt Int) {
+    return mkBitvector(Int, Int.getBitWidth());
+  }
+
+  /// Given an expression, extract the value of this operand in the model.
+  virtual bool getInterpretation(const SMTExpr &Exp, llvm::APSInt &Int) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Given an expression extract the value of this operand in the model.
+  virtual bool getInterpretation(const SMTExpr &Exp, llvm::APFloat &Float) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Construct a SMTExpr from a SymbolData, given a SMT_context.
+  virtual SMTExpr fromData(const SymbolID ID, const QualType &Ty,
+                           uint64_t BitWidth) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Check if the constraints are satisfiable
+  virtual bool check() const {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Push the current solver state
+  virtual void push() {
+    // Pass gen_crash_diag = false like every other stub in this class, so an
+    // unimplemented backend method is not reported as a compiler crash.
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Pop the previous solver state
+  virtual void pop(unsigned NumStates = 1) {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Reset the solver and remove all constraints.
+  virtual void reset() const {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+
+  /// Print the solver state to \p OS. dump() forwards here with llvm::errs().
+  virtual void print(raw_ostream &OS) const {
+    llvm::report_fatal_error("Unimplemented SMT method", false);
+  }
+};
+
+} // namespace ento
+} // namespace clang
+
+#endif
Index: include/clang/StaticAnalyzer/Core/PathSensitive/SMTSort.h
===================================================================
--- /dev/null
+++ include/clang/StaticAnalyzer/Core/PathSensitive/SMTSort.h
@@ -0,0 +1,89 @@
+//== SMTSort.h --------------------------------------------------*- C++ -*--==//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a SMT generic Sort API, which will be the base class
+// for every SMT solver sort specific class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSORT_H
+#define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_SMTSORT_H
+
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTContext.h"
+
+namespace clang {
+namespace ento {
+
+/// Generic SMT sort. The public queries forward to protected *Impl hooks;
+/// the base hooks answer "false" / 0 and equal_to() never matches.
+class SMTSort {
+public:
+  SMTSort() = default;
+  virtual ~SMTSort() = default;
+
+  virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); }
+  virtual bool isFloatSort() const { return isFloatSortImpl(); }
+  virtual bool isBooleanSort() const { return isBooleanSortImpl(); }
+
+  /// Returns the bitvector width. Asserts that this is a bitvector sort and
+  /// that the reported size is non-zero.
+  virtual unsigned getBitvectorSortSize() const {
+    assert(isBitvectorSort() && "Not a bitvector sort!");
+    unsigned Size = getBitvectorSortSizeImpl();
+    assert(Size && "Size is zero!");
+    return Size;
+  }
+
+  /// Returns the floating-point width. Asserts that this is a floating-point
+  /// sort and that the reported size is non-zero.
+  virtual unsigned getFloatSortSize() const {
+    assert(isFloatSort() && "Not a floating-point sort!");
+    unsigned Size = getFloatSortSizeImpl();
+    assert(Size && "Size is zero!");
+    return Size;
+  }
+
+  /// Equality is delegated to the virtual equal_to() hook.
+  friend bool operator==(SMTSort const &LHS, SMTSort const &RHS) {
+    return LHS.equal_to(RHS);
+  }
+
+protected:
+  virtual bool equal_to(SMTSort const &other) const { return false; }
+
+  virtual bool isBitvectorSortImpl() const { return false; }
+  virtual bool isFloatSortImpl() const { return false; }
+  virtual bool isBooleanSortImpl() const { return false; }
+
+  virtual unsigned getBitvectorSortSizeImpl() const { return 0; }
+  virtual unsigned getFloatSortSizeImpl() const { return 0; }
+};
+
+/// Sort wrapper that stores a solver-specific handle of type \p SSort
+/// together with a reference to the SMTContext it was created with.
+template <typename SSort> class SMTSolverSort : public SMTSort {
+public:
+  SMTSolverSort(SMTContext &C, SSort S) : SMTSort(), Context(C), Sort(S) {}
+  virtual ~SMTSolverSort() = default;
+
+  /// Returns a copy of the wrapped solver handle.
+  SSort getSort() const { return Sort; }
+
+  // Concrete solver sorts must answer every query themselves.
+  bool isBitvectorSortImpl() const override = 0;
+  bool isFloatSortImpl() const override = 0;
+  bool isBooleanSortImpl() const override = 0;
+
+  unsigned getBitvectorSortSizeImpl() const override = 0;
+  unsigned getFloatSortSizeImpl() const override = 0;
+
+  bool equal_to(SMTSort const &other) const override = 0;
+
+  virtual void print(raw_ostream &OS) const = 0;
+
+  LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); }
+
+protected:
+  SMTContext &Context;
+  SSort Sort;
+};
+
+} // namespace ento
+} // namespace clang
+
+#endif
Index: lib/StaticAnalyzer/Core/CMakeLists.txt
===================================================================
--- lib/StaticAnalyzer/Core/CMakeLists.txt
+++ lib/StaticAnalyzer/Core/CMakeLists.txt
@@ -48,6 +48,7 @@
   SVals.cpp
   SimpleConstraintManager.cpp
   SimpleSValBuilder.cpp
+  SMTConstraintManager.cpp
   Store.cpp
   SubEngine.cpp
   SymbolManager.cpp
Index: lib/StaticAnalyzer/Core/SMTConstraintManager.cpp
===================================================================
--- /dev/null
+++ lib/StaticAnalyzer/Core/SMTConstraintManager.cpp
@@ -0,0 +1,841 @@
+//== SMTConstraintManager.cpp -----------------------------------*- C++ -*--==//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h"
+#include "clang/Basic/TargetInfo.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/ExprEngine.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/ProgramState.h"
+
+using namespace clang;
+using namespace ento;
+
+// NOTE(review): the template arguments were lost in extraction; the element
+// type is reconstructed as a (symbol, expression) pair - confirm against the
+// original patch.
+typedef llvm::ImmutableSet<std::pair<SymbolRef, SMTExpr>> ConstraintSMTTy;
+REGISTER_TRAIT_WITH_PROGRAMSTATE(ConstraintSMT, ConstraintSMTTy)
+
+// Returns false when the solver reports the constraints unsatisfiable, and an
+// unknown truth value otherwise.
+clang::ento::ConditionTruthVal SMTConstraintManager::isModelFeasible() {
+  if (!Solver.check())
+    return false;
+
+  return ConditionTruthVal();
+}
+
+// Creates the solver symbol "$<ID>" with the sort matching Ty and BitWidth.
+SMTExpr SMTConstraintManager::fromData(const SymbolID ID, const QualType &Ty,
+                                       uint64_t BitWidth) const {
+  // Materialize the name right away: a named llvm::Twine keeps pointers to
+  // temporaries that are destroyed at the end of its declaration statement.
+  const std::string Name = ("$" + llvm::Twine(ID)).str();
+  return Solver.mkSymbol(Name.c_str(), Solver.mkSort(Ty, BitWidth));
+}
+
+// Maps integer/boolean unary operators onto solver operations.
+SMTExpr SMTConstraintManager::fromUnOp(const UnaryOperator::Opcode Op,
+                                       const SMTExpr &Exp) const {
+  switch (Op) {
+  case UO_Minus:
+    return Solver.mkBVNeg(Exp);
+
+  case UO_Not:
+    return Solver.mkBVNot(Exp);
+
+  case UO_LNot:
+    return Solver.mkNot(Exp);
+
+  default:;
+  }
+  llvm_unreachable("Unimplemented opcode");
+}
+
+// Maps floating-point unary operators onto solver operations; logical not is
+// shared with the integer path.
+SMTExpr SMTConstraintManager::fromFloatUnOp(const UnaryOperator::Opcode Op,
+                                            const SMTExpr &Exp) const {
+  switch (Op) {
+  case UO_Minus:
+    return Solver.mkFPNeg(Exp);
+
+  case UO_LNot:
+    return fromUnOp(Op, Exp);
+
+  default:;
+  }
+  llvm_unreachable("Unimplemented opcode");
+}
+
+// Folds all operands with logical and / logical or. ASTs must be non-empty;
+// only BO_LAnd and BO_LOr are supported.
+SMTExpr
+SMTConstraintManager::fromNBinOp(const BinaryOperator::Opcode Op,
+                                 const std::vector<SMTExpr> &ASTs) const {
+  assert(!ASTs.empty());
+
+  if (Op != BO_LAnd && Op != BO_LOr)
+    llvm_unreachable("Unimplemented opcode");
+
+  // The first operand seeds the result, so folding starts at the second one;
+  // starting at index 0 would combine the first operand with itself.
+  SMTExpr res = ASTs.front();
+  for (std::size_t i = 1; i < ASTs.size(); ++i)
+    res = (Op == BO_LAnd) ? Solver.mkAnd(res, ASTs[i])
+                          : Solver.mkOr(res, ASTs[i]);
+  return res;
+}
+
+// Maps integer binary operators onto solver operations. isSigned selects the
+// signed or unsigned variant where the two differ.
+SMTExpr SMTConstraintManager::fromBinOp(const SMTExpr &LHS,
+                                        const BinaryOperator::Opcode Op,
+                                        const SMTExpr &RHS,
+                                        bool isSigned) const {
+  assert(Solver.getSort(LHS) == Solver.getSort(RHS) &&
+         "AST's must have the same sort!");
+
+  switch (Op) {
+  // Multiplicative operators
+  case BO_Mul:
+    return Solver.mkBVMul(LHS, RHS);
+
+  case BO_Div:
+    return isSigned ? Solver.mkBVSDiv(LHS, RHS) : Solver.mkBVUDiv(LHS, RHS);
+
+  case BO_Rem:
+    return isSigned ? Solver.mkBVSRem(LHS, RHS) : Solver.mkBVURem(LHS, RHS);
+
+  // Additive operators
+  case BO_Add:
+    return Solver.mkBVAdd(LHS, RHS);
+
+  case BO_Sub:
+    return Solver.mkBVSub(LHS, RHS);
+
+  // Bitwise shift operators
+  case BO_Shl:
+    return Solver.mkBVShl(LHS, RHS);
+
+  case BO_Shr:
+    return isSigned ? Solver.mkBVAshr(LHS, RHS) : Solver.mkBVLshr(LHS, RHS);
+
+  // Relational operators
+  case BO_LT:
+    return isSigned ? Solver.mkBVSlt(LHS, RHS) : Solver.mkBVUlt(LHS, RHS);
+
+  case BO_GT:
+    return isSigned ? Solver.mkBVSgt(LHS, RHS) : Solver.mkBVUgt(LHS, RHS);
+
+  case BO_LE:
+    return isSigned ? Solver.mkBVSle(LHS, RHS) : Solver.mkBVUle(LHS, RHS);
+
+  case BO_GE:
+    return isSigned ? Solver.mkBVSge(LHS, RHS) : Solver.mkBVUge(LHS, RHS);
+
+  // Equality operators
+  case BO_EQ:
+    return Solver.mkEqual(LHS, RHS);
+
+  case BO_NE:
+    return fromUnOp(UO_LNot, fromBinOp(LHS, BO_EQ, RHS, isSigned));
+
+  // Bitwise operators
+  case BO_And:
+    return Solver.mkBVAnd(LHS, RHS);
+
+  case BO_Xor:
+    return Solver.mkBVXor(LHS, RHS);
+
+  case BO_Or:
+    return Solver.mkBVOr(LHS, RHS);
+
+  // Logical operators
+  case BO_LAnd:
+    return Solver.mkAnd(LHS, RHS);
+
+  case BO_LOr:
+    return Solver.mkOr(LHS, RHS);
+
+  default:;
+  }
+  llvm_unreachable("Unimplemented opcode");
+}
+
+// Maps (in)equality against a floating-point category onto the matching
+// classification predicate.
+SMTExpr SMTConstraintManager::fromFloatSpecialBinOp(
+    const SMTExpr &LHS, const BinaryOperator::Opcode Op,
+    const llvm::APFloat::fltCategory &RHS) const {
+  switch (Op) {
+  // Equality operators
+  case BO_EQ:
+    switch (RHS) {
+    case llvm::APFloat::fcInfinity:
+      return Solver.mkFPIsInfinite(LHS);
+
+    case llvm::APFloat::fcNaN:
+      return Solver.mkFPIsNaN(LHS);
+
+    case llvm::APFloat::fcNormal:
+      return Solver.mkFPIsNormal(LHS);
+
+    case llvm::APFloat::fcZero:
+      return Solver.mkFPIsZero(LHS);
+    }
+    break;
+
+  case BO_NE:
+    return fromFloatUnOp(UO_LNot, fromFloatSpecialBinOp(LHS, BO_EQ, RHS));
+
+  default:;
+  }
+
+  llvm_unreachable("Unimplemented opcode");
+}
+
+SMTExpr SMTConstraintManager::fromFloatBinOp(const SMTExpr &LHS,
+                                             const BinaryOperator::Opcode Op,
+                                             const SMTExpr &RHS) const {
+  assert(Solver.getSort(LHS) == Solver.getSort(RHS) &&
+         "AST's must have the same sort!");
+
+  switch (Op) {
+  // Multiplicative operators
+  case BO_Mul:
+    return Solver.mkFPMul(LHS, RHS);
+
+  case BO_Div:
+    return Solver.mkFPDiv(LHS, RHS);
+
+  case BO_Rem:
+    return Solver.mkFPRem(LHS, RHS);
+
+  // Additive operators
+  case BO_Add:
+    return Solver.mkFPAdd(LHS, RHS);
+
+  case BO_Sub:
+    return Solver.mkFPSub(LHS, RHS);
+
+  // Relational operators
+  case BO_LT:
+    return Solver.mkFPLt(LHS, RHS);
+
+  case BO_GT:
+    return Solver.mkFPGt(LHS, RHS);
+
+  case BO_LE:
+    return Solver.mkFPLe(LHS, RHS);
+
+  case BO_GE:
+    return
Solver.mkFPGe(LHS, RHS); + + // Equality operators + case BO_EQ: + return Solver.mkFPEqual(LHS, RHS); + + case BO_NE: + return fromFloatUnOp(UO_LNot, fromFloatBinOp(LHS, BO_EQ, RHS)); + + // Logical operators + case BO_LAnd: + case BO_LOr: + return fromBinOp(LHS, Op, RHS, false); + + default:; + } + + llvm_unreachable("Unimplemented opcode"); +} + +llvm::APSInt SMTConstraintManager::castAPSInt(const llvm::APSInt &V, + QualType ToTy, uint64_t ToWidth, + QualType FromTy, + uint64_t FromWidth) const { + APSIntType TargetType(ToWidth, !ToTy->isSignedIntegerOrEnumerationType()); + return TargetType.convert(V); +} + +SMTExpr SMTConstraintManager::fromCast(const SMTExpr &Exp, QualType ToTy, + uint64_t ToBitWidth, QualType FromTy, + uint64_t FromBitWidth) const { + if ((FromTy->isIntegralOrEnumerationType() && + ToTy->isIntegralOrEnumerationType()) || + (FromTy->isAnyPointerType() ^ ToTy->isAnyPointerType()) || + (FromTy->isBlockPointerType() ^ ToTy->isBlockPointerType()) || + (FromTy->isReferenceType() ^ ToTy->isReferenceType())) { + + if (FromTy->isBooleanType()) { + assert(ToBitWidth > 0 && "BitWidth must be positive!"); + return Solver.mkIte(Exp, + Solver.mkBitvector(llvm::APSInt("1"), ToBitWidth), + Solver.mkBitvector(llvm::APSInt("0"), ToBitWidth)); + } + + if (ToBitWidth > FromBitWidth) + return FromTy->isSignedIntegerOrEnumerationType() + ? 
Solver.mkSignExt(ToBitWidth - FromBitWidth, Exp) + : Solver.mkZeroExt(ToBitWidth - FromBitWidth, Exp); + + if (ToBitWidth < FromBitWidth) + return Solver.mkExtract(ToBitWidth - 1, 0, Exp); + + // Both are bitvectors with the same width, ignore the type cast + return Exp; + } + + if (FromTy->isRealFloatingType() && ToTy->isRealFloatingType()) { + if (ToBitWidth != FromBitWidth) + return Solver.mkFPtoFP(Exp, Solver.getFloatSort(ToBitWidth)); + + return Exp; + } + + if (FromTy->isIntegralOrEnumerationType() && ToTy->isRealFloatingType()) { + SMTSort Sort = Solver.getFloatSort(ToBitWidth); + return FromTy->isSignedIntegerOrEnumerationType() + ? Solver.mkFPtoSBV(Exp, Sort) + : Solver.mkFPtoUBV(Exp, Sort); + } + + if (FromTy->isRealFloatingType() && ToTy->isIntegralOrEnumerationType()) + return ToTy->isSignedIntegerOrEnumerationType() + ? Solver.mkSBVtoFP(Exp, ToBitWidth) + : Solver.mkUBVtoFP(Exp, ToBitWidth); + + llvm_unreachable("Unsupported explicit type cast!"); +} + +ProgramStateRef +SMTConstraintManager::assumeSymUnsupported(ProgramStateRef State, SymbolRef Sym, + bool Assumption) { + // Skip anything that is unsupported + return State; +} + +void SMTConstraintManager::addStateConstraints(ProgramStateRef State) const { + // TODO: Don't add all the constraints, only the relevant ones + ConstraintSMTTy CZ = State->get(); + ConstraintSMTTy::iterator I = CZ.begin(), IE = CZ.end(); + + // Construct the logical AND of all the constraints + if (I != IE) { + std::vector ASTs; + + while (I != IE) + ASTs.push_back(I++->second); + + Solver.addConstraint(fromNBinOp(BO_LAnd, ASTs)); + } +} + +ProgramStateRef SMTConstraintManager::assumeSMTExpr(ProgramStateRef State, + SymbolRef Sym, + const SMTExpr &Exp) { + // Check the model, avoid simplifying AST to save time + if (checkModel(State, Exp)) + return State->add(std::make_pair(Sym, Exp)); + + return nullptr; +} + +bool SMTConstraintManager::checkModel(ProgramStateRef State, + const SMTExpr &Exp) { + Solver.reset(); + 
Solver.addConstraint(Exp); + addStateConstraints(State); + return Solver.check(); +} + +ProgramStateRef SMTConstraintManager::assumeSym(ProgramStateRef State, + SymbolRef Sym, + bool Assumption) { + QualType RetTy; + bool hasComparison; + + SMTExpr Exp = getSMTExpr(Sym, &RetTy, &hasComparison); + + // Create zero comparison for implicit boolean cast, with reversed assumption + if (!hasComparison && !RetTy->isBooleanType()) + return assumeSMTExpr(State, Sym, getSMTZeroExpr(Exp, RetTy, !Assumption)); + + return assumeSMTExpr(State, Sym, Assumption ? Exp : getSMTNotExpr(Exp)); +} + +ProgramStateRef SMTConstraintManager::assumeSymInclusiveRange( + ProgramStateRef State, SymbolRef Sym, const llvm::APSInt &From, + const llvm::APSInt &To, bool InRange) { + return assumeSMTExpr(State, Sym, getSMTRangeExpr(Sym, From, To, InRange)); +} + +bool SMTConstraintManager::canReasonAbout(SVal X) const { + const TargetInfo &TI = getBasicVals().getContext().getTargetInfo(); + + Optional SymVal = X.getAs(); + if (!SymVal) + return true; + + const SymExpr *Sym = SymVal->getSymbol(); + QualType Ty = Sym->getType(); + + // Complex types are not modeled + if (Ty->isComplexType() || Ty->isComplexIntegerType()) + return false; + + // Non-IEEE 754 floating-point types are not modeled + if ((Ty->isSpecificBuiltinType(BuiltinType::LongDouble) && + (&TI.getLongDoubleFormat() == &llvm::APFloat::x87DoubleExtended() || + &TI.getLongDoubleFormat() == &llvm::APFloat::PPCDoubleDouble()))) + return false; + + if (isa(Sym)) + return true; + + SValBuilder &SVB = getSValBuilder(); + + if (const SymbolCast *SC = dyn_cast(Sym)) + return canReasonAbout(SVB.makeSymbolVal(SC->getOperand())); + + if (const BinarySymExpr *BSE = dyn_cast(Sym)) { + if (const SymIntExpr *SIE = dyn_cast(BSE)) + return canReasonAbout(SVB.makeSymbolVal(SIE->getLHS())); + + if (const IntSymExpr *ISE = dyn_cast(BSE)) + return canReasonAbout(SVB.makeSymbolVal(ISE->getRHS())); + + if (const SymSymExpr *SSE = dyn_cast(BSE)) + return 
canReasonAbout(SVB.makeSymbolVal(SSE->getLHS())) && + canReasonAbout(SVB.makeSymbolVal(SSE->getRHS())); + } + + llvm_unreachable("Unsupported expression to reason about!"); +} + +ProgramStateRef +SMTConstraintManager::removeDeadBindings(ProgramStateRef State, + SymbolReaper &SymReaper) { + ConstraintSMTTy CZ = State->get(); + ConstraintSMTTy::Factory &CZFactory = State->get_context(); + + for (ConstraintSMTTy::iterator I = CZ.begin(), E = CZ.end(); I != E; ++I) { + if (SymReaper.maybeDead(I->first)) + CZ = CZFactory.remove(CZ, *I); + } + + return State->set(CZ); +} + +void SMTConstraintManager::print(ProgramStateRef St, raw_ostream &OS, + const char *nl, const char *sep) { + + ConstraintSMTTy CZ = St->get(); + + OS << nl << sep << "Constraints:"; + for (ConstraintSMTTy::iterator I = CZ.begin(), E = CZ.end(); I != E; ++I) { + OS << nl << ' ' << I->first << " : "; + I->second.print(OS); + } + OS << nl; +} + +ConditionTruthVal SMTConstraintManager::checkNull(ProgramStateRef State, + SymbolRef Sym) { + QualType RetTy; + + // The expression may be casted, so we cannot call getSMTDataExpr() directly + SMTExpr VarExp = getSMTExpr(Sym, &RetTy); + SMTExpr Exp = getSMTZeroExpr(VarExp, RetTy, true); + + // Negate the constraint + SMTExpr NotExp = getSMTZeroExpr(VarExp, RetTy, false); + + Solver.reset(); + addStateConstraints(State); + + Solver.push(); + Solver.addConstraint(Exp); + bool isSat = Solver.check(); + + Solver.pop(); + Solver.addConstraint(NotExp); + bool isNotSat = Solver.check(); + + // Zero is the only possible solution + if (isSat && !isNotSat) + return true; + + // Zero is not a solution + if (!isSat && isNotSat) + return false; + + // Zero may be a solution + return ConditionTruthVal(); +} + +const llvm::APSInt *SMTConstraintManager::getSymVal(ProgramStateRef State, + SymbolRef Sym) const { + BasicValueFactory &BVF = getBasicVals(); + ASTContext &Ctx = BVF.getContext(); + + if (const SymbolData *SD = dyn_cast(Sym)) { + QualType Ty = Sym->getType(); + 
assert(!Ty->isRealFloatingType()); + llvm::APSInt Value(Ctx.getTypeSize(Ty), + !Ty->isSignedIntegerOrEnumerationType()); + + SMTExpr Exp = getSMTDataExpr(SD->getSymbolID(), Ty); + + Solver.reset(); + addStateConstraints(State); + + // Constraints are unsatisfiable + if (Solver.check()) + return nullptr; + + // Model does not assign interpretation + if (!Solver.getInterpretation(Exp, Value)) + return nullptr; + + // A value has been obtained, check if it is the only value + SMTExpr NotExp = fromBinOp( + Exp, BO_NE, + Ty->isBooleanType() ? Solver.mkBoolean(Value.getBoolValue()) + : Solver.mkBitvector(Value, Value.getBitWidth()), + false); + + Solver.addConstraint(NotExp); + if (Solver.check()) + return nullptr; + + // This is the only solution, store it + return &BVF.getValue(Value); + } + + if (const SymbolCast *SC = dyn_cast(Sym)) { + SymbolRef CastSym = SC->getOperand(); + QualType CastTy = SC->getType(); + // Skip the void type + if (CastTy->isVoidType()) + return nullptr; + + const llvm::APSInt *Value; + if (!(Value = getSymVal(State, CastSym))) + return nullptr; + return &BVF.Convert(SC->getType(), *Value); + } + + if (const BinarySymExpr *BSE = dyn_cast(Sym)) { + const llvm::APSInt *LHS, *RHS; + if (const SymIntExpr *SIE = dyn_cast(BSE)) { + LHS = getSymVal(State, SIE->getLHS()); + RHS = &SIE->getRHS(); + } else if (const IntSymExpr *ISE = dyn_cast(BSE)) { + LHS = &ISE->getLHS(); + RHS = getSymVal(State, ISE->getRHS()); + } else if (const SymSymExpr *SSM = dyn_cast(BSE)) { + // Early termination to avoid expensive call + LHS = getSymVal(State, SSM->getLHS()); + RHS = LHS ? 
getSymVal(State, SSM->getRHS()) : nullptr; + } else { + llvm_unreachable("Unsupported binary expression to get symbol value!"); + } + + if (!LHS || !RHS) + return nullptr; + + llvm::APSInt ConvertedLHS, ConvertedRHS; + QualType LTy, RTy; + std::tie(ConvertedLHS, LTy) = fixAPSInt(*LHS); + std::tie(ConvertedRHS, RTy) = fixAPSInt(*RHS); + doIntTypeConversion( + ConvertedLHS, LTy, ConvertedRHS, RTy); + return BVF.evalAPSInt(BSE->getOpcode(), ConvertedLHS, ConvertedRHS); + } + + llvm_unreachable("Unsupported expression to get symbol value!"); +} + +void SMTConstraintManager::addRangeConstraints( + clang::ento::ConstraintRangeTy CR) { + for (const auto &I : CR) { + SymbolRef Sym = I.first; + + SMTExpr Constraints = Solver.mkBoolean(false); + for (const auto &Range : I.second) { + Constraints = Solver.mkOr(Constraints, + getSMTRangeExpr(Sym, Range.From(), Range.To(), + /*InRange=*/true)); + } + Solver.addConstraint(Constraints); + } +} + +SMTExpr SMTConstraintManager::getSMTExpr(SymbolRef Sym, QualType *RetTy, + bool *hasComparison) const { + if (hasComparison) { + *hasComparison = false; + } + + return getSMTSymExpr(Sym, RetTy, hasComparison); +} + +SMTExpr SMTConstraintManager::getSMTNotExpr(const SMTExpr &Exp) const { + return fromUnOp(UO_LNot, Exp); +} + +SMTExpr SMTConstraintManager::getSMTZeroExpr(const SMTExpr &Exp, QualType Ty, + bool Assumption) const { + ASTContext &Ctx = getBasicVals().getContext(); + if (Ty->isRealFloatingType()) { + llvm::APFloat Zero = llvm::APFloat::getZero(Ctx.getFloatTypeSemantics(Ty)); + return fromFloatBinOp(Exp, Assumption ? BO_EQ : BO_NE, + Solver.mkFloat(Zero)); + } + + if (Ty->isIntegralOrEnumerationType() || Ty->isAnyPointerType() || + Ty->isBlockPointerType() || Ty->isReferenceType()) { + bool isSigned = Ty->isSignedIntegerOrEnumerationType(); + + // Skip explicit comparison for boolean types + if (Ty->isBooleanType()) + return Assumption ? getSMTNotExpr(Exp) : Exp; + + return fromBinOp(Exp, Assumption ? 
BO_EQ : BO_NE, + Solver.mkBitvector(llvm::APSInt("0"), Ctx.getTypeSize(Ty)), + isSigned); + } + + llvm_unreachable("Unsupported type for zero value!"); +} + +SMTExpr SMTConstraintManager::getSMTSymExpr(SymbolRef Sym, QualType *RetTy, + bool *hasComparison) const { + if (const SymbolData *SD = dyn_cast(Sym)) { + if (RetTy) + *RetTy = Sym->getType(); + + return getSMTDataExpr(SD->getSymbolID(), Sym->getType()); + } + + if (const SymbolCast *SC = dyn_cast(Sym)) { + if (RetTy) + *RetTy = Sym->getType(); + + QualType FromTy; + SMTExpr Exp = getSMTSymExpr(SC->getOperand(), &FromTy, hasComparison); + // Casting an expression with a comparison invalidates it. Note that this + // must occur after the recursive call above. + // e.g. (signed char) (x > 0) + if (hasComparison) + *hasComparison = false; + return getSMTCastExpr(Exp, FromTy, Sym->getType()); + } + + if (const BinarySymExpr *BSE = dyn_cast(Sym)) { + SMTExpr Exp = getSMTSymBinExpr(BSE, hasComparison, RetTy); + // Set the hasComparison parameter, in post-order traversal order. 
+ if (hasComparison) + *hasComparison = BinaryOperator::isComparisonOp(BSE->getOpcode()); + return Exp; + } + + llvm_unreachable("Unsupported SymbolRef type!"); +} + +SMTExpr SMTConstraintManager::getSMTDataExpr(const SymbolID ID, + QualType Ty) const { + ASTContext &Ctx = getBasicVals().getContext(); + return fromData(ID, Ty, Ctx.getTypeSize(Ty)); +} + +SMTExpr SMTConstraintManager::getSMTCastExpr(const SMTExpr &Exp, + QualType FromTy, + QualType ToTy) const { + ASTContext &Ctx = getBasicVals().getContext(); + return fromCast(Exp, ToTy, Ctx.getTypeSize(ToTy), FromTy, + Ctx.getTypeSize(FromTy)); +} + +SMTExpr SMTConstraintManager::getSMTSymBinExpr(const BinarySymExpr *BSE, + bool *hasComparison, + QualType *RetTy) const { + QualType LTy, RTy; + BinaryOperator::Opcode Op = BSE->getOpcode(); + + if (const SymIntExpr *SIE = dyn_cast(BSE)) { + SMTExpr LHS = getSMTSymExpr(SIE->getLHS(), <y, hasComparison); + llvm::APSInt NewRInt; + std::tie(NewRInt, RTy) = fixAPSInt(SIE->getRHS()); + SMTExpr RHS = Solver.mkBitvector(NewRInt); + return getSMTBinExpr(LHS, LTy, Op, RHS, RTy, RetTy); + } + + if (const IntSymExpr *ISE = dyn_cast(BSE)) { + llvm::APSInt NewLInt; + std::tie(NewLInt, LTy) = fixAPSInt(ISE->getLHS()); + SMTExpr LHS = Solver.mkBitvector(NewLInt); + SMTExpr RHS = getSMTSymExpr(ISE->getRHS(), &RTy, hasComparison); + return getSMTBinExpr(LHS, LTy, Op, RHS, RTy, RetTy); + } + + if (const SymSymExpr *SSM = dyn_cast(BSE)) { + SMTExpr LHS = getSMTSymExpr(SSM->getLHS(), <y, hasComparison); + SMTExpr RHS = getSMTSymExpr(SSM->getRHS(), &RTy, hasComparison); + return getSMTBinExpr(LHS, LTy, Op, RHS, RTy, RetTy); + } + + llvm_unreachable("Unsupported BinarySymExpr type!"); +} + +SMTExpr SMTConstraintManager::getSMTBinExpr(const SMTExpr &LHS, QualType LTy, + BinaryOperator::Opcode Op, + const SMTExpr &RHS, QualType RTy, + QualType *RetTy) const { + SMTExpr NewLHS = LHS; + SMTExpr NewRHS = RHS; + doTypeConversion(NewLHS, NewRHS, LTy, RTy); + + // Update the return type parameter 
if the output type has changed. + if (RetTy) { + // A boolean result can be represented as an integer type in C/C++, but at + // this point we only care about the SMT type. Set it as a boolean type to + // avoid subsequent SMT errors. + if (BinaryOperator::isComparisonOp(Op) || BinaryOperator::isLogicalOp(Op)) { + ASTContext &Ctx = getBasicVals().getContext(); + *RetTy = Ctx.BoolTy; + } else { + *RetTy = LTy; + } + + // If the two operands are pointers and the operation is a subtraction, the + // result is of type ptrdiff_t, which is signed + if (LTy->isAnyPointerType() && RTy->isAnyPointerType() && Op == BO_Sub) { + *RetTy = getBasicVals().getContext().getPointerDiffType(); + } + } + + return LTy->isRealFloatingType() + ? fromFloatBinOp(NewLHS, Op, NewRHS) + : fromBinOp(NewLHS, Op, NewRHS, + LTy->isSignedIntegerOrEnumerationType()); +} + +SMTExpr SMTConstraintManager::getSMTRangeExpr(SymbolRef Sym, + const llvm::APSInt &From, + const llvm::APSInt &To, + bool InRange) { + // Convert lower bound + QualType FromTy; + llvm::APSInt NewFromInt; + std::tie(NewFromInt, FromTy) = fixAPSInt(From); + SMTExpr FromExp = Solver.mkBitvector(NewFromInt); + + // Convert symbol + QualType SymTy; + SMTExpr Exp = getSMTExpr(Sym, &SymTy); + + // Construct single (in)equality + if (From == To) + return getSMTBinExpr(Exp, SymTy, InRange ? BO_EQ : BO_NE, FromExp, FromTy, + /*RetTy=*/nullptr); + + QualType ToTy; + llvm::APSInt NewToInt; + std::tie(NewToInt, ToTy) = fixAPSInt(To); + SMTExpr ToExp = Solver.mkBitvector(NewToInt); + assert(FromTy == ToTy && "Range values have different types!"); + + // Construct two (in)equalities, and a logical and/or + SMTExpr LHS = getSMTBinExpr(Exp, SymTy, InRange ? BO_GE : BO_LT, FromExp, + FromTy, /*RetTy=*/nullptr); + SMTExpr RHS = getSMTBinExpr(Exp, SymTy, InRange ? BO_LE : BO_GT, ToExp, ToTy, + /*RetTy=*/nullptr); + + return fromBinOp(LHS, InRange ? 
BO_LAnd : BO_LOr, RHS, + SymTy->isSignedIntegerOrEnumerationType()); +} + +//===------------------------------------------------------------------===// +// Helper functions. +//===------------------------------------------------------------------===// + +QualType SMTConstraintManager::getAPSIntType(const llvm::APSInt &Int) const { + ASTContext &Ctx = getBasicVals().getContext(); + return Ctx.getIntTypeForBitwidth(Int.getBitWidth(), Int.isSigned()); +} + +std::pair +SMTConstraintManager::fixAPSInt(const llvm::APSInt &Int) const { + llvm::APSInt NewInt; + + // FIXME: This should be a cast from a 1-bit integer type to a boolean type, + // but the former is not available in Clang. Instead, extend the APSInt + // directly. + if (Int.getBitWidth() == 1 && getAPSIntType(Int).isNull()) { + ASTContext &Ctx = getBasicVals().getContext(); + NewInt = Int.extend(Ctx.getTypeSize(Ctx.BoolTy)); + } else + NewInt = Int; + + return std::make_pair(NewInt, getAPSIntType(NewInt)); +} + +void SMTConstraintManager::doTypeConversion(SMTExpr &LHS, SMTExpr &RHS, + QualType <y, + QualType &RTy) const { + ASTContext &Ctx = getBasicVals().getContext(); + + assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!"); + // Perform type conversion + if (LTy->isIntegralOrEnumerationType() && + RTy->isIntegralOrEnumerationType()) { + if (LTy->isArithmeticType() && RTy->isArithmeticType()) { + doIntTypeConversion(LHS, LTy, + RHS, RTy); + return; + } + } + + if (LTy->isRealFloatingType() || RTy->isRealFloatingType()) { + doFloatTypeConversion(LHS, LTy, + RHS, RTy); + return; + } + + if ((LTy->isAnyPointerType() || RTy->isAnyPointerType()) || + (LTy->isBlockPointerType() || RTy->isBlockPointerType()) || + (LTy->isReferenceType() || RTy->isReferenceType())) { + // TODO: Refactor to Sema::FindCompositePointerType(), and + // Sema::CheckCompareOperands(). + + uint64_t LBitWidth = Ctx.getTypeSize(LTy); + uint64_t RBitWidth = Ctx.getTypeSize(RTy); + + // Cast the non-pointer type to the pointer type. 
+ // TODO: Be more strict about this. + if ((LTy->isAnyPointerType() ^ RTy->isAnyPointerType()) || + (LTy->isBlockPointerType() ^ RTy->isBlockPointerType()) || + (LTy->isReferenceType() ^ RTy->isReferenceType())) { + if (LTy->isNullPtrType() || LTy->isBlockPointerType() || + LTy->isReferenceType()) { + LHS = fromCast(LHS, RTy, RBitWidth, LTy, LBitWidth); + LTy = RTy; + } else { + RHS = fromCast(RHS, LTy, LBitWidth, RTy, RBitWidth); + RTy = LTy; + } + } + + // Cast the void pointer type to the non-void pointer type. + // For void types, this assumes that the casted value is equal to the value + // of the original pointer, and does not account for alignment requirements. + if (LTy->isVoidPointerType() ^ RTy->isVoidPointerType()) { + assert((Ctx.getTypeSize(LTy) == Ctx.getTypeSize(RTy)) && + "Pointer types have different bitwidths!"); + if (RTy->isVoidPointerType()) + RTy = LTy; + else + LTy = RTy; + } + + if (LTy == RTy) + return; + } + + // Fallback: for the solver, assume that these types don't really matter + if ((LTy.getCanonicalType() == RTy.getCanonicalType()) || + (LTy->isObjCObjectPointerType() && RTy->isObjCObjectPointerType())) { + LTy = RTy; + return; + } + + // TODO: Refine behavior for invalid type casts +} Index: lib/StaticAnalyzer/Core/Z3ConstraintManager.cpp =================================================================== --- lib/StaticAnalyzer/Core/Z3ConstraintManager.cpp +++ lib/StaticAnalyzer/Core/Z3ConstraintManager.cpp @@ -12,6 +12,8 @@ #include "clang/StaticAnalyzer/Core/PathSensitive/ProgramState.h" #include "clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h" #include "clang/StaticAnalyzer/Core/PathSensitive/SMTContext.h" +#include "clang/StaticAnalyzer/Core/PathSensitive/SMTExpr.h" +#include "clang/StaticAnalyzer/Core/PathSensitive/SMTSort.h" #include "clang/Config/config.h" @@ -22,29 +24,12 @@ #include -// Forward declarations namespace { -class Z3Expr; -class ConstraintZ3 {}; -} // end anonymous namespace - -typedef 
llvm::ImmutableSet> ConstraintZ3Ty; - -// Expansion of REGISTER_TRAIT_WITH_PROGRAMSTATE(ConstraintZ3, Z3SetPair) -namespace clang { -namespace ento { -template <> -struct ProgramStateTrait - : public ProgramStatePartialTrait { - static void *GDMIndex() { - static int Index; - return &Index; - } -}; -} // end namespace ento -} // end namespace clang -namespace { +class Z3Expr; +class Z3Sort; +static const Z3Expr *toZ3Expr(const SMTExpr &E); +static const Z3Sort *toZ3Sort(const SMTSort &S); class Z3Config { friend class Z3Context; @@ -82,27 +67,27 @@ } }; // end class Z3Context -class Z3Sort { - friend class Z3Expr; - friend class Z3Solver; - - Z3Context &Context; +static Z3Context *toZ3Context(SMTContext &C) { + return static_cast(&C); +} - Z3_sort Sort; +class Z3Sort : public SMTSolverSort { + friend class Z3Solver; - Z3Sort(Z3Context &C, Z3_sort ZS) : Context(C), Sort(ZS) { - assert(C.getContext() != nullptr); - Z3_inc_ref(Context.getContext(), reinterpret_cast(Sort)); + Z3Sort(SMTContext &C, Z3_sort ZS) : SMTSolverSort(C, ZS) { + Z3_inc_ref(toZ3Context(Context)->getContext(), + reinterpret_cast(Sort)); } public: /// Override implicit copy constructor for correct reference counting. 
- Z3Sort(const Z3Sort &Copy) : Context(Copy.Context), Sort(Copy.Sort) { - Z3_inc_ref(Context.getContext(), reinterpret_cast(Sort)); + Z3Sort(const Z3Sort &Copy) : SMTSolverSort(Copy.Context, Copy.Sort) { + Z3_inc_ref(toZ3Context(Context)->getContext(), + reinterpret_cast(Sort)); } /// Provide move constructor - Z3Sort(Z3Sort &&Move) : Context(Move.Context), Sort(nullptr) { + Z3Sort(Z3Sort &&Move) : SMTSolverSort(Move.Context, nullptr) { *this = std::move(Move); } @@ -110,7 +95,8 @@ Z3Sort &operator=(Z3Sort &&Move) { if (this != &Move) { if (Sort) - Z3_dec_ref(Context.getContext(), reinterpret_cast(Sort)); + Z3_dec_ref(toZ3Context(Context)->getContext(), + reinterpret_cast(Sort)); Sort = Move.Sort; Move.Sort = nullptr; } @@ -119,77 +105,72 @@ ~Z3Sort() { if (Sort) - Z3_dec_ref(Context.getContext(), reinterpret_cast(Sort)); + Z3_dec_ref(toZ3Context(Context)->getContext(), + reinterpret_cast(Sort)); + } + + bool isBitvectorSortImpl() const override { + return (Z3_get_sort_kind(toZ3Context(Context)->getContext(), Sort) == + Z3_BV_SORT); + } + + bool isFloatSortImpl() const override { + return (Z3_get_sort_kind(toZ3Context(Context)->getContext(), Sort) == + Z3_FLOATING_POINT_SORT); } - Z3_sort_kind getSortKind() const { - return Z3_get_sort_kind(Context.getContext(), Sort); + bool isBooleanSortImpl() const override { + return (Z3_get_sort_kind(toZ3Context(Context)->getContext(), Sort) == + Z3_BOOL_SORT); } - unsigned getBitvectorSortSize() const { - assert(getSortKind() == Z3_BV_SORT && "Not a bitvector sort!"); - return Z3_get_bv_sort_size(Context.getContext(), Sort); + unsigned getBitvectorSortSizeImpl() const override { + return Z3_get_bv_sort_size(toZ3Context(Context)->getContext(), Sort); } - unsigned getFloatSortSize() const { - assert(getSortKind() == Z3_FLOATING_POINT_SORT && - "Not a floating-point sort!"); - return Z3_fpa_get_ebits(Context.getContext(), Sort) + - Z3_fpa_get_sbits(Context.getContext(), Sort); + unsigned getFloatSortSizeImpl() const override { + 
return Z3_fpa_get_ebits(toZ3Context(Context)->getContext(), Sort) + + Z3_fpa_get_sbits(toZ3Context(Context)->getContext(), Sort); } - bool operator==(const Z3Sort &Other) const { - return Z3_is_eq_sort(Context.getContext(), Sort, Other.Sort); + bool equal_to(SMTSort const &Other) const override { + return Z3_is_eq_sort(toZ3Context(Context)->getContext(), Sort, + static_cast(&Other)->Sort); } Z3Sort &operator=(const Z3Sort &Move) { - Z3_inc_ref(Context.getContext(), reinterpret_cast(Move.Sort)); - Z3_dec_ref(Context.getContext(), reinterpret_cast(Sort)); + Z3_inc_ref(toZ3Context(Context)->getContext(), + reinterpret_cast(Move.Sort)); + Z3_dec_ref(toZ3Context(Context)->getContext(), + reinterpret_cast(Sort)); Sort = Move.Sort; return *this; } - void print(raw_ostream &OS) const { - OS << Z3_sort_to_string(Context.getContext(), Sort); + void print(raw_ostream &OS) const override { + OS << Z3_sort_to_string(toZ3Context(Context)->getContext(), Sort); } - - LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); } }; // end class Z3Sort -class Z3Expr { - friend class Z3Model; - friend class Z3Solver; - - Z3Context &Context; - - Z3_ast AST; +static const Z3Sort *toZ3Sort(const SMTSort &S) { + return static_cast(&S); +} - Z3Expr(Z3Context &C, Z3_ast ZA) : Context(C), AST(ZA) { - assert(C.getContext() != nullptr); - Z3_inc_ref(Context.getContext(), AST); - } +class Z3Expr : public SMTSolverExpr { + friend class Z3Solver; - // Determine whether two float semantics are equivalent - static bool areEquivalent(const llvm::fltSemantics &LHS, - const llvm::fltSemantics &RHS) { - return (llvm::APFloat::semanticsPrecision(LHS) == - llvm::APFloat::semanticsPrecision(RHS)) && - (llvm::APFloat::semanticsMinExponent(LHS) == - llvm::APFloat::semanticsMinExponent(RHS)) && - (llvm::APFloat::semanticsMaxExponent(LHS) == - llvm::APFloat::semanticsMaxExponent(RHS)) && - (llvm::APFloat::semanticsSizeInBits(LHS) == - llvm::APFloat::semanticsSizeInBits(RHS)); + Z3Expr(SMTContext &C, Z3_ast ZA) 
: SMTSolverExpr(C, ZA) { + Z3_inc_ref(toZ3Context(Context)->getContext(), AST); } public: /// Override implicit copy constructor for correct reference counting. - Z3Expr(const Z3Expr &Copy) : Context(Copy.Context), AST(Copy.AST) { - Z3_inc_ref(Context.getContext(), AST); + Z3Expr(const Z3Expr &Copy) : SMTSolverExpr(Copy.Context, Copy.AST) { + Z3_inc_ref(toZ3Context(Context)->getContext(), AST); } /// Provide move constructor - Z3Expr(Z3Expr &&Move) : Context(Move.Context), AST(nullptr) { + Z3Expr(Z3Expr &&Move) : SMTSolverExpr(Move.Context, nullptr) { *this = std::move(Move); } @@ -197,7 +178,7 @@ Z3Expr &operator=(Z3Expr &&Move) { if (this != &Move) { if (AST) - Z3_dec_ref(Context.getContext(), AST); + Z3_dec_ref(toZ3Context(Context)->getContext(), AST); AST = Move.AST; Move.AST = nullptr; } @@ -206,77 +187,56 @@ ~Z3Expr() { if (AST) - Z3_dec_ref(Context.getContext(), AST); - } - - /// Get the corresponding IEEE floating-point type for a given bitwidth. - static const llvm::fltSemantics &getFloatSemantics(unsigned BitWidth) { - switch (BitWidth) { - default: - llvm_unreachable("Unsupported floating-point semantics!"); - break; - case 16: - return llvm::APFloat::IEEEhalf(); - case 32: - return llvm::APFloat::IEEEsingle(); - case 64: - return llvm::APFloat::IEEEdouble(); - case 128: - return llvm::APFloat::IEEEquad(); - } - } - - void Profile(llvm::FoldingSetNodeID &ID) const { - ID.AddInteger(Z3_get_ast_hash(Context.getContext(), AST)); + Z3_dec_ref(toZ3Context(Context)->getContext(), AST); } - bool operator<(const Z3Expr &Other) const { - llvm::FoldingSetNodeID ID1, ID2; - Profile(ID1); - Other.Profile(ID2); - return ID1 < ID2; + void Profile(llvm::FoldingSetNodeID &ID) const override { + ID.AddInteger(Z3_get_ast_hash(toZ3Context(Context)->getContext(), AST)); } /// Comparison of AST equality, not model equivalence. 
- bool operator==(const Z3Expr &Other) const { - assert(Z3_is_eq_sort(Context.getContext(), - Z3_get_sort(Context.getContext(), AST), - Z3_get_sort(Context.getContext(), Other.AST)) && + bool equal_to(SMTExpr const &Other) const override { + assert(Z3_is_eq_sort(toZ3Context(Context)->getContext(), + Z3_get_sort(toZ3Context(Context)->getContext(), AST), + Z3_get_sort(toZ3Context(Context)->getContext(), + toZ3Expr(Other)->getExpr())) && "AST's must have the same sort"); - return Z3_is_eq_ast(Context.getContext(), AST, Other.AST); + return Z3_is_eq_ast(toZ3Context(Context)->getContext(), AST, + toZ3Expr(Other)->getExpr()); } /// Override implicit move constructor for correct reference counting. Z3Expr &operator=(const Z3Expr &Move) { - Z3_inc_ref(Context.getContext(), Move.AST); - Z3_dec_ref(Context.getContext(), AST); + Z3_inc_ref(toZ3Context(Context)->getContext(), Move.AST); + Z3_dec_ref(toZ3Context(Context)->getContext(), AST); AST = Move.AST; return *this; } - void print(raw_ostream &OS) const { - OS << Z3_ast_to_string(Context.getContext(), AST); + void print(raw_ostream &OS) const override { + OS << Z3_ast_to_string(toZ3Context(Context)->getContext(), AST); } - - LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); } }; // end class Z3Expr +static const Z3Expr *toZ3Expr(const SMTExpr &E) { + return static_cast(&E); +} + class Z3Model { friend class Z3Solver; - Z3Context &Context; + SMTContext &Context; Z3_model Model; public: - Z3Model(Z3Context &C, Z3_model ZM) : Context(C), Model(ZM) { - assert(C.getContext() != nullptr); - Z3_model_inc_ref(Context.getContext(), Model); + Z3Model(SMTContext &C, Z3_model ZM) : Context(C), Model(ZM) { + Z3_model_inc_ref(toZ3Context(Context)->getContext(), Model); } /// Override implicit copy constructor for correct reference counting. 
Z3Model(const Z3Model &Copy) : Context(Copy.Context), Model(Copy.Model) { - Z3_model_inc_ref(Context.getContext(), Model); + Z3_model_inc_ref(toZ3Context(Context)->getContext(), Model); } /// Provide move constructor @@ -288,7 +248,7 @@ Z3Model &operator=(Z3Model &&Move) { if (this != &Move) { if (Model) - Z3_model_dec_ref(Context.getContext(), Model); + Z3_model_dec_ref(toZ3Context(Context)->getContext(), Model); Model = Move.Model; Move.Model = nullptr; } @@ -297,35 +257,65 @@ ~Z3Model() { if (Model) - Z3_model_dec_ref(Context.getContext(), Model); + Z3_model_dec_ref(toZ3Context(Context)->getContext(), Model); } void print(raw_ostream &OS) const { - OS << Z3_model_to_string(Context.getContext(), Model); + OS << Z3_model_to_string(toZ3Context(Context)->getContext(), Model); } LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); } }; // end class Z3Model -class Z3Solver { +/// Get the corresponding IEEE floating-point type for a given bitwidth. +static const llvm::fltSemantics &getFloatSemantics(unsigned BitWidth) { + switch (BitWidth) { + default: + llvm_unreachable("Unsupported floating-point semantics!"); + break; + case 16: + return llvm::APFloat::IEEEhalf(); + case 32: + return llvm::APFloat::IEEEsingle(); + case 64: + return llvm::APFloat::IEEEdouble(); + case 128: + return llvm::APFloat::IEEEquad(); + } +} + +// Determine whether two float semantics are equivalent +static bool areEquivalent(const llvm::fltSemantics &LHS, + const llvm::fltSemantics &RHS) { + return (llvm::APFloat::semanticsPrecision(LHS) == + llvm::APFloat::semanticsPrecision(RHS)) && + (llvm::APFloat::semanticsMinExponent(LHS) == + llvm::APFloat::semanticsMinExponent(RHS)) && + (llvm::APFloat::semanticsMaxExponent(LHS) == + llvm::APFloat::semanticsMaxExponent(RHS)) && + (llvm::APFloat::semanticsSizeInBits(LHS) == + llvm::APFloat::semanticsSizeInBits(RHS)); +} + +class Z3Solver : public SMTSolver { friend class Z3ConstraintManager; Z3Context Context; Z3_solver Solver; - Z3Solver() : 
Solver(Z3_mk_simple_solver(Context.getContext())) { - Z3_solver_inc_ref(Context.getContext(), Solver); + Z3Solver() : SMTSolver(), Solver(Z3_mk_simple_solver(Context.getContext())) { + Z3_solver_inc_ref(toZ3Context(Context)->getContext(), Solver); } public: /// Override implicit copy constructor for correct reference counting. - Z3Solver(const Z3Solver &Copy) : Context(Copy.Context), Solver(Copy.Solver) { - Z3_solver_inc_ref(Context.getContext(), Solver); + Z3Solver(const Z3Solver &Copy) : SMTSolver(), Solver(Copy.Solver) { + Z3_solver_inc_ref(toZ3Context(Context)->getContext(), Solver); } /// Provide move constructor - Z3Solver(Z3Solver &&Move) : Context(Move.Context), Solver(nullptr) { + Z3Solver(Z3Solver &&Move) : SMTSolver(), Solver(nullptr) { *this = std::move(Move); } @@ -333,7 +323,7 @@ Z3Solver &operator=(Z3Solver &&Move) { if (this != &Move) { if (Solver) - Z3_solver_dec_ref(Context.getContext(), Solver); + Z3_solver_dec_ref(toZ3Context(Context)->getContext(), Solver); Solver = Move.Solver; Move.Solver = nullptr; } @@ -342,486 +332,440 @@ ~Z3Solver() { if (Solver) - Z3_solver_dec_ref(Context.getContext(), Solver); + Z3_solver_dec_ref(toZ3Context(Context)->getContext(), Solver); } - /// Given a constraint, add it to the solver - void addConstraint(const Z3Expr &Exp) { - Z3_solver_assert(Context.getContext(), Solver, Exp.AST); + /// Get a model from the solver. Caller should check the model is + /// satisfiable. + Z3Model getModel() { + return Z3Model(Context, Z3_solver_get_model( + toZ3Context(Context)->getContext(), Solver)); } - // Return a boolean sort. - Z3Sort getBoolSort() { - return Z3Sort(Context, Z3_mk_bool_sort(Context.getContext())); + void addConstraint(const SMTExpr &Exp) override { + Z3_solver_assert(toZ3Context(Context)->getContext(), Solver, + toZ3Expr(Exp)->getExpr()); } - // Return an appropriate bitvector sort for the given bitwidth. 
- Z3Sort getBitvectorSort(unsigned BitWidth) { - return Z3Sort(Context, Z3_mk_bv_sort(Context.getContext(), BitWidth)); + SMTSort getBoolSort() override { + return Z3Sort(Context, Z3_mk_bool_sort(toZ3Context(Context)->getContext())); } - // Return an appropriate floating-point sort for the given bitwidth. - Z3Sort getFloatSort(unsigned BitWidth) { - Z3_sort Sort; + SMTSort getBitvectorSort(unsigned BitWidth) override { + return Z3Sort(Context, + Z3_mk_bv_sort(toZ3Context(Context)->getContext(), BitWidth)); + } - switch (BitWidth) { - default: - llvm_unreachable("Unsupported floating-point bitwidth!"); - break; - case 16: - Sort = Z3_mk_fpa_sort_16(Context.getContext()); - break; - case 32: - Sort = Z3_mk_fpa_sort_32(Context.getContext()); - break; - case 64: - Sort = Z3_mk_fpa_sort_64(Context.getContext()); - break; - case 128: - Sort = Z3_mk_fpa_sort_128(Context.getContext()); - break; - } - return Z3Sort(Context, Sort); + SMTSort getFloat16Sort() override { + return Z3Sort(Context, + Z3_mk_fpa_sort_16(toZ3Context(Context)->getContext())); } - // Return an appropriate sort, given a QualType - Z3Sort MkSort(const QualType &Ty, unsigned BitWidth) { - if (Ty->isBooleanType()) - return getBoolSort(); + SMTSort getFloat32Sort() override { + return Z3Sort(Context, + Z3_mk_fpa_sort_32(toZ3Context(Context)->getContext())); + } - if (Ty->isRealFloatingType()) - return getFloatSort(BitWidth); + SMTSort getFloat64Sort() override { + return Z3Sort(Context, + Z3_mk_fpa_sort_64(toZ3Context(Context)->getContext())); + } - return getBitvectorSort(BitWidth); + SMTSort getFloat128Sort() override { + return Z3Sort(Context, + Z3_mk_fpa_sort_128(toZ3Context(Context)->getContext())); } - // Return an appropriate sort for the given AST. 
- Z3Sort getSort(Z3_ast AST) { - return Z3Sort(Context, Z3_get_sort(Context.getContext(), AST)); + SMTSort getSort(SMTExpr Exp) override { + return Z3Sort(Context, Z3_get_sort(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr())); } - /// Given a program state, construct the logical conjunction and add it to - /// the solver - void addStateConstraints(ProgramStateRef State) { - // TODO: Don't add all the constraints, only the relevant ones - ConstraintZ3Ty CZ = State->get(); - ConstraintZ3Ty::iterator I = CZ.begin(), IE = CZ.end(); + SMTExpr getFloatRoundingMode() override { + // TODO: Don't assume nearest ties to even rounding mode + return Z3Expr(Context, Z3_mk_fpa_rne(toZ3Context(Context)->getContext())); + } - // Construct the logical AND of all the constraints - if (I != IE) { - std::vector ASTs; + SMTExpr mkBVNeg(SMTExpr Exp) override { + return Z3Expr(Context, Z3_mk_bvneg(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr())); + } - while (I != IE) - ASTs.push_back(I++->second.AST); + SMTExpr mkBVNot(SMTExpr Exp) override { + return Z3Expr(Context, Z3_mk_bvnot(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr())); + } - Z3Expr Conj = fromNBinOp(BO_LAnd, ASTs); - addConstraint(Conj); - } + SMTExpr mkNot(SMTExpr Exp) override { + return Z3Expr(Context, Z3_mk_not(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr())); } - // Return an appropriate floating-point rounding mode. - Z3Expr getFloatRoundingMode() { - // TODO: Don't assume nearest ties to even rounding mode - return Z3Expr(Context, Z3_mk_fpa_rne(Context.getContext())); + SMTExpr mkBVAdd(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvadd(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); } - /// Construct a Z3Expr from a unary operator, given a Z3_context. 
- Z3Expr fromUnOp(const UnaryOperator::Opcode Op, const Z3Expr &Exp) { - Z3_ast AST; + SMTExpr mkBVSub(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvsub(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; + SMTExpr mkBVMul(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvmul(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - case UO_Minus: - AST = Z3_mk_bvneg(Context.getContext(), Exp.AST); - break; + SMTExpr mkBVSRem(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvsrem(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - case UO_Not: - AST = Z3_mk_bvnot(Context.getContext(), Exp.AST); - break; + SMTExpr mkBVURem(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvurem(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - case UO_LNot: - AST = Z3_mk_not(Context.getContext(), Exp.AST); - break; - } + SMTExpr mkBVSDiv(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvsdiv(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - return Z3Expr(Context, AST); + SMTExpr mkBVUDiv(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvudiv(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); } - /// Construct a Z3Expr from a floating-point unary operator, given a - /// Z3_context. 
- Z3Expr fromFloatUnOp(const UnaryOperator::Opcode Op, const Z3Expr &Exp) { - Z3_ast AST; + SMTExpr mkBVShl(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvshl(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; + SMTExpr mkBVAshr(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvashr(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - case UO_Minus: - AST = Z3_mk_fpa_neg(Context.getContext(), Exp.AST); - break; + SMTExpr mkBVLshr(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvlshr(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - case UO_LNot: - return fromUnOp(Op, Exp); - } + SMTExpr mkBVXor(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvxor(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - return Z3Expr(Context, AST); + SMTExpr mkBVOr(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvor(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); } - /// Construct a Z3Expr from a n-ary binary operator. 
- Z3Expr fromNBinOp(const BinaryOperator::Opcode Op, - const std::vector &ASTs) { - Z3_ast AST; + SMTExpr mkBVAnd(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvand(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; + SMTExpr mkBVUlt(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvult(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - case BO_LAnd: - AST = Z3_mk_and(Context.getContext(), ASTs.size(), ASTs.data()); - break; + SMTExpr mkBVSlt(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvslt(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - case BO_LOr: - AST = Z3_mk_or(Context.getContext(), ASTs.size(), ASTs.data()); - break; - } + SMTExpr mkBVUgt(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvugt(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - return Z3Expr(Context, AST); - } - - /// Construct a Z3Expr from a binary operator, given a Z3_context. - Z3Expr fromBinOp(const Z3Expr &LHS, const BinaryOperator::Opcode Op, - const Z3Expr &RHS, bool isSigned) { - Z3_ast AST; - - assert(getSort(LHS.AST) == getSort(RHS.AST) && - "AST's must have the same sort!"); - - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; - - // Multiplicative operators - case BO_Mul: - AST = Z3_mk_bvmul(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_Div: - AST = isSigned ? Z3_mk_bvsdiv(Context.getContext(), LHS.AST, RHS.AST) - : Z3_mk_bvudiv(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_Rem: - AST = isSigned ? 
Z3_mk_bvsrem(Context.getContext(), LHS.AST, RHS.AST) - : Z3_mk_bvurem(Context.getContext(), LHS.AST, RHS.AST); - break; - - // Additive operators - case BO_Add: - AST = Z3_mk_bvadd(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_Sub: - AST = Z3_mk_bvsub(Context.getContext(), LHS.AST, RHS.AST); - break; - - // Bitwise shift operators - case BO_Shl: - AST = Z3_mk_bvshl(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_Shr: - AST = isSigned ? Z3_mk_bvashr(Context.getContext(), LHS.AST, RHS.AST) - : Z3_mk_bvlshr(Context.getContext(), LHS.AST, RHS.AST); - break; - - // Relational operators - case BO_LT: - AST = isSigned ? Z3_mk_bvslt(Context.getContext(), LHS.AST, RHS.AST) - : Z3_mk_bvult(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_GT: - AST = isSigned ? Z3_mk_bvsgt(Context.getContext(), LHS.AST, RHS.AST) - : Z3_mk_bvugt(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_LE: - AST = isSigned ? Z3_mk_bvsle(Context.getContext(), LHS.AST, RHS.AST) - : Z3_mk_bvule(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_GE: - AST = isSigned ? 
Z3_mk_bvsge(Context.getContext(), LHS.AST, RHS.AST) - : Z3_mk_bvuge(Context.getContext(), LHS.AST, RHS.AST); - break; - - // Equality operators - case BO_EQ: - AST = Z3_mk_eq(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_NE: - return fromUnOp(UO_LNot, fromBinOp(LHS, BO_EQ, RHS, isSigned)); - break; - - // Bitwise operators - case BO_And: - AST = Z3_mk_bvand(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_Xor: - AST = Z3_mk_bvxor(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_Or: - AST = Z3_mk_bvor(Context.getContext(), LHS.AST, RHS.AST); - break; - - // Logical operators - case BO_LAnd: - case BO_LOr: { - std::vector Args = {LHS.AST, RHS.AST}; - return fromNBinOp(Op, Args); - } - } + SMTExpr mkBVSgt(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvsgt(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - return Z3Expr(Context, AST); - } - - /// Construct a Z3Expr from a special floating-point binary operator, given - /// a Z3_context. 
- Z3Expr fromFloatSpecialBinOp(const Z3Expr &LHS, - const BinaryOperator::Opcode Op, - const llvm::APFloat::fltCategory &RHS) { - Z3_ast AST; - - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; - - // Equality operators - case BO_EQ: - switch (RHS) { - case llvm::APFloat::fcInfinity: - AST = Z3_mk_fpa_is_infinite(Context.getContext(), LHS.AST); - break; - case llvm::APFloat::fcNaN: - AST = Z3_mk_fpa_is_nan(Context.getContext(), LHS.AST); - break; - case llvm::APFloat::fcNormal: - AST = Z3_mk_fpa_is_normal(Context.getContext(), LHS.AST); - break; - case llvm::APFloat::fcZero: - AST = Z3_mk_fpa_is_zero(Context.getContext(), LHS.AST); - break; - } - break; - case BO_NE: - return fromFloatUnOp(UO_LNot, fromFloatSpecialBinOp(LHS, BO_EQ, RHS)); - break; - } + SMTExpr mkBVUle(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvule(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - return Z3Expr(Context, AST); + SMTExpr mkBVSle(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvsle(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); } - /// Construct a Z3Expr from a floating-point binary operator, given a - /// Z3_context. 
- Z3Expr fromFloatBinOp(const Z3Expr &LHS, const BinaryOperator::Opcode Op, - const Z3Expr &RHS) { - Z3_ast AST; + SMTExpr mkBVUge(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvuge(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - assert(getSort(LHS.AST) == getSort(RHS.AST) && - "AST's must have the same sort!"); + SMTExpr mkBVSge(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_bvsge(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - switch (Op) { - default: - llvm_unreachable("Unimplemented opcode"); - break; + SMTExpr mkAnd(SMTExpr LHS, SMTExpr RHS) override { + Z3_ast Args[2] = {toZ3Expr(LHS)->getExpr(), toZ3Expr(RHS)->getExpr()}; + return Z3Expr(Context, + Z3_mk_and(toZ3Context(Context)->getContext(), 2, Args)); + } - // Multiplicative operators - case BO_Mul: { - Z3Expr RoundingMode = getFloatRoundingMode(); - AST = Z3_mk_fpa_mul(Context.getContext(), RoundingMode.AST, LHS.AST, - RHS.AST); - break; - } - case BO_Div: { - Z3Expr RoundingMode = getFloatRoundingMode(); - AST = Z3_mk_fpa_div(Context.getContext(), RoundingMode.AST, LHS.AST, - RHS.AST); - break; - } - case BO_Rem: - AST = Z3_mk_fpa_rem(Context.getContext(), LHS.AST, RHS.AST); - break; - - // Additive operators - case BO_Add: { - Z3Expr RoundingMode = getFloatRoundingMode(); - AST = Z3_mk_fpa_add(Context.getContext(), RoundingMode.AST, LHS.AST, - RHS.AST); - break; - } - case BO_Sub: { - Z3Expr RoundingMode = getFloatRoundingMode(); - AST = Z3_mk_fpa_sub(Context.getContext(), RoundingMode.AST, LHS.AST, - RHS.AST); - break; - } + SMTExpr mkOr(SMTExpr LHS, SMTExpr RHS) override { + Z3_ast Args[2] = {toZ3Expr(LHS)->getExpr(), toZ3Expr(RHS)->getExpr()}; + return Z3Expr(Context, + Z3_mk_or(toZ3Context(Context)->getContext(), 2, Args)); + } - // Relational operators - case BO_LT: - AST = Z3_mk_fpa_lt(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_GT: - 
AST = Z3_mk_fpa_gt(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_LE: - AST = Z3_mk_fpa_leq(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_GE: - AST = Z3_mk_fpa_geq(Context.getContext(), LHS.AST, RHS.AST); - break; - - // Equality operators - case BO_EQ: - AST = Z3_mk_fpa_eq(Context.getContext(), LHS.AST, RHS.AST); - break; - case BO_NE: - return fromFloatUnOp(UO_LNot, fromFloatBinOp(LHS, BO_EQ, RHS)); - break; - - // Logical operators - case BO_LAnd: - case BO_LOr: - return fromBinOp(LHS, Op, RHS, false); - } + SMTExpr mkEqual(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, + Z3_mk_eq(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), toZ3Expr(RHS)->getExpr())); + } - return Z3Expr(Context, AST); - } - - /// Construct a Z3Expr from a SymbolData, given a Z3_context. - Z3Expr fromData(const SymbolID ID, const QualType &Ty, uint64_t BitWidth) { - llvm::Twine Name = "$" + llvm::Twine(ID); - - Z3Sort Sort = MkSort(Ty, BitWidth); - - Z3_symbol Symbol = - Z3_mk_string_symbol(Context.getContext(), Name.str().c_str()); - Z3_ast AST = Z3_mk_const(Context.getContext(), Symbol, Sort.Sort); - return Z3Expr(Context, AST); - } - - /// Construct a Z3Expr from a SymbolCast, given a Z3_context. 
- Z3Expr fromCast(const Z3Expr &Exp, QualType ToTy, uint64_t ToBitWidth, - QualType FromTy, uint64_t FromBitWidth) { - Z3_ast AST; - - if ((FromTy->isIntegralOrEnumerationType() && - ToTy->isIntegralOrEnumerationType()) || - (FromTy->isAnyPointerType() ^ ToTy->isAnyPointerType()) || - (FromTy->isBlockPointerType() ^ ToTy->isBlockPointerType()) || - (FromTy->isReferenceType() ^ ToTy->isReferenceType())) { - // Special case: Z3 boolean type is distinct from bitvector type, so - // must use if-then-else expression instead of direct cast - if (FromTy->isBooleanType()) { - assert(ToBitWidth > 0 && "BitWidth must be positive!"); - Z3Expr Zero = fromInt("0", ToBitWidth); - Z3Expr One = fromInt("1", ToBitWidth); - AST = Z3_mk_ite(Context.getContext(), Exp.AST, One.AST, Zero.AST); - } else if (ToBitWidth > FromBitWidth) { - AST = FromTy->isSignedIntegerOrEnumerationType() - ? Z3_mk_sign_ext(Context.getContext(), - ToBitWidth - FromBitWidth, Exp.AST) - : Z3_mk_zero_ext(Context.getContext(), - ToBitWidth - FromBitWidth, Exp.AST); - } else if (ToBitWidth < FromBitWidth) { - AST = Z3_mk_extract(Context.getContext(), ToBitWidth - 1, 0, Exp.AST); - } else { - // Both are bitvectors with the same width, ignore the type cast - return Exp; - } - } else if (FromTy->isRealFloatingType() && ToTy->isRealFloatingType()) { - if (ToBitWidth != FromBitWidth) { - Z3Expr RoundingMode = getFloatRoundingMode(); - Z3Sort Sort = getFloatSort(ToBitWidth); - AST = Z3_mk_fpa_to_fp_float(Context.getContext(), RoundingMode.AST, - Exp.AST, Sort.Sort); - } else { - return Exp; - } - } else if (FromTy->isIntegralOrEnumerationType() && - ToTy->isRealFloatingType()) { - Z3Expr RoundingMode = getFloatRoundingMode(); - Z3Sort Sort = getFloatSort(ToBitWidth); - AST = - FromTy->isSignedIntegerOrEnumerationType() - ? 
Z3_mk_fpa_to_fp_signed(Context.getContext(), RoundingMode.AST, - Exp.AST, Sort.Sort) - : Z3_mk_fpa_to_fp_unsigned(Context.getContext(), RoundingMode.AST, - Exp.AST, Sort.Sort); - } else if (FromTy->isRealFloatingType() && - ToTy->isIntegralOrEnumerationType()) { - Z3Expr RoundingMode = getFloatRoundingMode(); - AST = ToTy->isSignedIntegerOrEnumerationType() - ? Z3_mk_fpa_to_sbv(Context.getContext(), RoundingMode.AST, - Exp.AST, ToBitWidth) - : Z3_mk_fpa_to_ubv(Context.getContext(), RoundingMode.AST, - Exp.AST, ToBitWidth); - } else { - llvm_unreachable("Unsupported explicit type cast!"); - } + SMTExpr mkFPNeg(SMTExpr Exp) override { + return Z3Expr(Context, Z3_mk_fpa_neg(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr())); + } + + SMTExpr mkFPIsInfinite(SMTExpr Exp) override { + return Z3Expr(Context, + Z3_mk_fpa_is_infinite(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr())); + } + + SMTExpr mkFPIsNaN(SMTExpr Exp) override { + return Z3Expr(Context, Z3_mk_fpa_is_nan(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr())); + } + + SMTExpr mkFPIsNormal(SMTExpr Exp) override { + return Z3Expr(Context, + Z3_mk_fpa_is_normal(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr())); + } + + SMTExpr mkFPIsZero(SMTExpr Exp) override { + return Z3Expr(Context, Z3_mk_fpa_is_zero(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr())); + } + + SMTExpr mkFPMul(SMTExpr LHS, SMTExpr RHS) override { + SMTExpr RoundingMode = getFloatRoundingMode(); + return Z3Expr(Context, Z3_mk_fpa_mul(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr(), + toZ3Expr(RoundingMode)->getExpr())); + } + + SMTExpr mkFPDiv(SMTExpr LHS, SMTExpr RHS) override { + SMTExpr RoundingMode = getFloatRoundingMode(); + return Z3Expr(Context, Z3_mk_fpa_div(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr(), + toZ3Expr(RoundingMode)->getExpr())); + } + + SMTExpr mkFPRem(SMTExpr LHS, 
SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_fpa_rem(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } + + SMTExpr mkFPAdd(SMTExpr LHS, SMTExpr RHS) override { + SMTExpr RoundingMode = getFloatRoundingMode(); + return Z3Expr(Context, Z3_mk_fpa_add(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr(), + toZ3Expr(RoundingMode)->getExpr())); + } + + SMTExpr mkFPSub(SMTExpr LHS, SMTExpr RHS) override { + SMTExpr RoundingMode = getFloatRoundingMode(); + return Z3Expr(Context, Z3_mk_fpa_sub(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr(), + toZ3Expr(RoundingMode)->getExpr())); + } + + SMTExpr mkFPLt(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_fpa_lt(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } + + SMTExpr mkFPGt(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_fpa_gt(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } + + SMTExpr mkFPLe(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_fpa_leq(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } + + SMTExpr mkFPGe(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_fpa_geq(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } + + SMTExpr mkFPEqual(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_fpa_eq(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } - return Z3Expr(Context, AST); + SMTExpr mkIte(SMTExpr Cond, SMTExpr T, SMTExpr F) override { + return Z3Expr(Context, + Z3_mk_ite(toZ3Context(Context)->getContext(), + toZ3Expr(Cond)->getExpr(), toZ3Expr(T)->getExpr(), + toZ3Expr(F)->getExpr())); } - /// Construct a Z3Expr from a boolean, given a Z3_context. 
- Z3Expr fromBoolean(const bool Bool) { - Z3_ast AST = Bool ? Z3_mk_true(Context.getContext()) - : Z3_mk_false(Context.getContext()); - return Z3Expr(Context, AST); + SMTExpr mkSignExt(unsigned i, SMTExpr Exp) override { + return Z3Expr(Context, Z3_mk_sign_ext(toZ3Context(Context)->getContext(), i, + toZ3Expr(Exp)->getExpr())); } - /// Construct a Z3Expr from a finite APFloat, given a Z3_context. - Z3Expr fromAPFloat(const llvm::APFloat &Float) { - Z3_ast AST; - Z3Sort Sort = + SMTExpr mkZeroExt(unsigned i, SMTExpr Exp) override { + return Z3Expr(Context, Z3_mk_zero_ext(toZ3Context(Context)->getContext(), i, + toZ3Expr(Exp)->getExpr())); + } + + SMTExpr mkExtract(unsigned High, unsigned Low, SMTExpr Exp) override { + return Z3Expr(Context, Z3_mk_extract(toZ3Context(Context)->getContext(), + High, Low, toZ3Expr(Exp)->getExpr())); + } + + SMTExpr mkConcat(SMTExpr LHS, SMTExpr RHS) override { + return Z3Expr(Context, Z3_mk_concat(toZ3Context(Context)->getContext(), + toZ3Expr(LHS)->getExpr(), + toZ3Expr(RHS)->getExpr())); + } + + SMTExpr mkFPtoFP(SMTExpr From, SMTSort To) override { + SMTExpr RoundingMode = getFloatRoundingMode(); + return Z3Expr(Context, + Z3_mk_fpa_to_fp_float(toZ3Context(Context)->getContext(), + toZ3Expr(RoundingMode)->getExpr(), + toZ3Expr(From)->getExpr(), + toZ3Sort(To)->getSort())); + } + + SMTExpr mkFPtoSBV(SMTExpr From, SMTSort To) override { + SMTExpr RoundingMode = getFloatRoundingMode(); + return Z3Expr(Context, + Z3_mk_fpa_to_fp_signed(toZ3Context(Context)->getContext(), + toZ3Expr(RoundingMode)->getExpr(), + toZ3Expr(From)->getExpr(), + toZ3Sort(To)->getSort())); + } + + SMTExpr mkFPtoUBV(SMTExpr From, SMTSort To) override { + SMTExpr RoundingMode = getFloatRoundingMode(); + return Z3Expr(Context, + Z3_mk_fpa_to_fp_unsigned(toZ3Context(Context)->getContext(), + toZ3Expr(RoundingMode)->getExpr(), + toZ3Expr(From)->getExpr(), + toZ3Sort(To)->getSort())); + } + + SMTExpr mkSBVtoFP(SMTExpr From, unsigned ToWidth) override { + SMTExpr 
RoundingMode = getFloatRoundingMode(); + return Z3Expr(Context, + Z3_mk_fpa_to_sbv(toZ3Context(Context)->getContext(), + toZ3Expr(RoundingMode)->getExpr(), + toZ3Expr(From)->getExpr(), ToWidth)); + } + + SMTExpr mkUBVtoFP(SMTExpr From, unsigned ToWidth) override { + SMTExpr RoundingMode = getFloatRoundingMode(); + return Z3Expr(Context, + Z3_mk_fpa_to_ubv(toZ3Context(Context)->getContext(), + toZ3Expr(RoundingMode)->getExpr(), + toZ3Expr(From)->getExpr(), ToWidth)); + } + + SMTExpr mkBoolean(const bool b) override { + return Z3Expr(Context, b ? Z3_mk_true(toZ3Context(Context)->getContext()) + : Z3_mk_false(toZ3Context(Context)->getContext())); + } + + SMTExpr mkBitvector(const llvm::APSInt Int, unsigned BitWidth) override { + SMTSort Sort = getBitvectorSort(BitWidth); + return Z3Expr(Context, Z3_mk_numeral(toZ3Context(Context)->getContext(), + Int.toString(10).c_str(), + toZ3Sort(Sort)->getSort())); + } + + SMTExpr mkFloat(const llvm::APFloat Float) override { + SMTSort Sort = getFloatSort(llvm::APFloat::semanticsSizeInBits(Float.getSemantics())); llvm::APSInt Int = llvm::APSInt(Float.bitcastToAPInt(), false); - Z3Expr Z3Int = fromAPSInt(Int); - AST = Z3_mk_fpa_to_fp_bv(Context.getContext(), Z3Int.AST, Sort.Sort); - return Z3Expr(Context, AST); + SMTExpr Z3Int = mkBitvector(Int, Int.getBitWidth()); + return Z3Expr(Context, + Z3_mk_fpa_to_fp_bv(toZ3Context(Context)->getContext(), + toZ3Expr(Z3Int)->getExpr(), + toZ3Sort(Sort)->getSort())); + } + + SMTExpr mkSymbol(const char *Name, SMTSort Sort) override { + return Z3Expr(Context, + Z3_mk_const(toZ3Context(Context)->getContext(), + Z3_mk_string_symbol( + toZ3Context(Context)->getContext(), Name), + toZ3Sort(Sort)->getSort())); } - /// Construct a Z3Expr from an APSInt, given a Z3_context. 
- Z3Expr fromAPSInt(const llvm::APSInt &Int) { - Z3Sort Sort = getBitvectorSort(Int.getBitWidth()); - Z3_ast AST = Z3_mk_numeral(Context.getContext(), Int.toString(10).c_str(), - Sort.Sort); - return Z3Expr(Context, AST); + const llvm::APSInt getBitvector(SMTExpr Exp) override { + return llvm::APSInt(Z3_get_numeral_string( + toZ3Context(Context)->getContext(), toZ3Expr(Exp)->getExpr())); } - /// Construct a Z3Expr from an integer, given a Z3_context. - Z3Expr fromInt(const char *Int, uint64_t BitWidth) { - Z3Sort Sort = getBitvectorSort(BitWidth); - Z3_ast AST = Z3_mk_numeral(Context.getContext(), Int, Sort.Sort); - return Z3Expr(Context, AST); + bool getBoolean(SMTExpr Exp) override { + return Z3_get_bool_value(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr()) == Z3_L_TRUE; } - /// Construct an APFloat from a Z3Expr, given the AST representation - bool toAPFloat(const Z3Sort &Sort, const Z3_ast &AST, llvm::APFloat &Float, - bool useSemantics = true) { - assert(Sort.getSortKind() == Z3_FLOATING_POINT_SORT && - "Unsupported sort to floating-point!"); + bool toAPFloat(const SMTSort &Sort, const SMTExpr &AST, llvm::APFloat &Float, + bool useSemantics) { + assert(Sort.isFloatSort() && "Unsupported sort to floating-point!"); llvm::APSInt Int(Sort.getFloatSortSize(), true); const llvm::fltSemantics &Semantics = - Z3Expr::getFloatSemantics(Sort.getFloatSortSize()); - Z3Sort BVSort = getBitvectorSort(Sort.getFloatSortSize()); + getFloatSemantics(Sort.getFloatSortSize()); + SMTSort BVSort = getBitvectorSort(Sort.getFloatSortSize()); if (!toAPSInt(BVSort, AST, Int, true)) { return false; } - if (useSemantics && - !Z3Expr::areEquivalent(Float.getSemantics(), Semantics)) { + if (useSemantics && !areEquivalent(Float.getSemantics(), Semantics)) { assert(false && "Floating-point types don't match!"); return false; } @@ -830,899 +774,103 @@ return true; } - /// Construct an APSInt from a Z3Expr, given the AST representation - bool toAPSInt(const Z3Sort &Sort, const 
Z3_ast &AST, llvm::APSInt &Int, - bool useSemantics = true) { - switch (Sort.getSortKind()) { - default: - llvm_unreachable("Unsupported sort to integer!"); - case Z3_BV_SORT: { + bool toAPSInt(const SMTSort &Sort, const SMTExpr &AST, llvm::APSInt &Int, + bool useSemantics) { + if (Sort.isBitvectorSort()) { if (useSemantics && Int.getBitWidth() != Sort.getBitvectorSortSize()) { assert(false && "Bitvector types don't match!"); return false; } - uint64_t Value[2]; - // Force cast because Z3 defines __uint64 to be a unsigned long long - // type, which isn't compatible with a unsigned long type, even if they - // are the same size. - Z3_get_numeral_uint64(Context.getContext(), AST, - reinterpret_cast<__uint64 *>(&Value[0])); - if (Sort.getBitvectorSortSize() <= 64) { - Int = llvm::APSInt(llvm::APInt(Int.getBitWidth(), Value[0]), - Int.isUnsigned()); - } else if (Sort.getBitvectorSortSize() == 128) { - Z3Expr ASTHigh = - Z3Expr(Context, Z3_mk_extract(Context.getContext(), 127, 64, AST)); - Z3_get_numeral_uint64(Context.getContext(), AST, - reinterpret_cast<__uint64 *>(&Value[1])); - Int = llvm::APSInt(llvm::APInt(Int.getBitWidth(), Value), - Int.isUnsigned()); - } else { - assert(false && "Bitwidth not supported!"); - return false; - } + Int = getBitvector(AST); return true; } - case Z3_BOOL_SORT: + + if (Sort.isBooleanSort()) { if (useSemantics && Int.getBitWidth() < 1) { assert(false && "Boolean type doesn't match!"); return false; } - Int = llvm::APSInt( - llvm::APInt(Int.getBitWidth(), - Z3_get_bool_value(Context.getContext(), AST) == Z3_L_TRUE - ? 1 - : 0), - Int.isUnsigned()); + + Int = llvm::APSInt(llvm::APInt(Int.getBitWidth(), getBoolean(AST)), + Int.isUnsigned()); return true; } + + llvm_unreachable("Unsupported sort to integer!"); } /// Given an expression and a model, extract the value of this operand in /// the model. 
- bool getInterpretation(const Z3Model Model, const Z3Expr &Exp, - llvm::APSInt &Int) { - Z3_func_decl Func = Z3_get_app_decl( - Context.getContext(), Z3_to_app(Context.getContext(), Exp.AST)); - if (Z3_model_has_interp(Context.getContext(), Model.Model, Func) != - Z3_L_TRUE) + bool getInterpretation(const SMTExpr &Exp, llvm::APSInt &Int) override { + Z3Model Model = getModel(); + Z3_func_decl Func = + Z3_get_app_decl(toZ3Context(Context)->getContext(), + Z3_to_app(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr())); + if (Z3_model_has_interp(toZ3Context(Context)->getContext(), Model.Model, + Func) != Z3_L_TRUE) return false; - Z3_ast Assign = - Z3_model_get_const_interp(Context.getContext(), Model.Model, Func); - Z3Sort Sort = getSort(Assign); + SMTExpr Assign = Z3Expr( + Context, Z3_model_get_const_interp(toZ3Context(Context)->getContext(), + Model.Model, Func)); + SMTSort Sort = getSort(Assign); return toAPSInt(Sort, Assign, Int, true); } /// Given an expression and a model, extract the value of this operand in /// the model. 
- bool getInterpretation(const Z3Model Model, const Z3Expr &Exp, - llvm::APFloat &Float) { - Z3_func_decl Func = Z3_get_app_decl( - Context.getContext(), Z3_to_app(Context.getContext(), Exp.AST)); - if (Z3_model_has_interp(Context.getContext(), Model.Model, Func) != - Z3_L_TRUE) + bool getInterpretation(const SMTExpr &Exp, llvm::APFloat &Float) override { + Z3Model Model = getModel(); + Z3_func_decl Func = + Z3_get_app_decl(toZ3Context(Context)->getContext(), + Z3_to_app(toZ3Context(Context)->getContext(), + toZ3Expr(Exp)->getExpr())); + if (Z3_model_has_interp(toZ3Context(Context)->getContext(), Model.Model, + Func) != Z3_L_TRUE) return false; - Z3_ast Assign = - Z3_model_get_const_interp(Context.getContext(), Model.Model, Func); - Z3Sort Sort = getSort(Assign); + SMTExpr Assign = Z3Expr( + Context, Z3_model_get_const_interp(toZ3Context(Context)->getContext(), + Model.Model, Func)); + SMTSort Sort = getSort(Assign); return toAPFloat(Sort, Assign, Float, true); } - // Callback function for doCast parameter on APSInt type. 
- llvm::APSInt castAPSInt(const llvm::APSInt &V, QualType ToTy, - uint64_t ToWidth, QualType FromTy, - uint64_t FromWidth) { - APSIntType TargetType(ToWidth, !ToTy->isSignedIntegerOrEnumerationType()); - return TargetType.convert(V); - } - /// Check if the constraints are satisfiable - Z3_lbool check() { return Z3_solver_check(Context.getContext(), Solver); } + bool check() const override { + return Z3_solver_check(Context.getContext(), Solver); + } /// Push the current solver state - void push() { return Z3_solver_push(Context.getContext(), Solver); } + void push() override { return Z3_solver_push(Context.getContext(), Solver); } /// Pop the previous solver state - void pop(unsigned NumStates = 1) { - assert(Z3_solver_get_num_scopes(Context.getContext(), Solver) >= NumStates); + void pop(unsigned NumStates = 1) override { + assert(Z3_solver_get_num_scopes(toZ3Context(Context)->getContext(), + Solver) >= NumStates); return Z3_solver_pop(Context.getContext(), Solver, NumStates); } - /// Get a model from the solver. Caller should check the model is - /// satisfiable. - Z3Model getModel() { - return Z3Model(Context, Z3_solver_get_model(Context.getContext(), Solver)); - } - /// Reset the solver and remove all constraints. - void reset() { Z3_solver_reset(Context.getContext(), Solver); } + void reset() const override { Z3_solver_reset(Context.getContext(), Solver); } - void print(raw_ostream &OS) const { + void print(raw_ostream &OS) const override { OS << Z3_solver_to_string(Context.getContext(), Solver); } - - LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); } }; // end class Z3Solver class Z3ConstraintManager : public SMTConstraintManager { - mutable Z3Solver Solver; + Z3Solver S; public: Z3ConstraintManager(SubEngine *SE, SValBuilder &SB) - : SMTConstraintManager(SE, SB) {} - - //===------------------------------------------------------------------===// - // Implementation for Refutation. 
- //===------------------------------------------------------------------===// - - void addRangeConstraints(clang::ento::ConstraintRangeTy CR) override; - - ConditionTruthVal isModelFeasible() override; - - LLVM_DUMP_METHOD void dump() const override; - - //===------------------------------------------------------------------===// - // Implementation for interface from ConstraintManager. - //===------------------------------------------------------------------===// - - bool canReasonAbout(SVal X) const override; - - ConditionTruthVal checkNull(ProgramStateRef State, SymbolRef Sym) override; - - const llvm::APSInt *getSymVal(ProgramStateRef State, - SymbolRef Sym) const override; - - ProgramStateRef removeDeadBindings(ProgramStateRef St, - SymbolReaper &SymReaper) override; - - void print(ProgramStateRef St, raw_ostream &Out, const char *nl, - const char *sep) override; - - //===------------------------------------------------------------------===// - // Implementation for interface from SimpleConstraintManager. - //===------------------------------------------------------------------===// - - ProgramStateRef assumeSym(ProgramStateRef state, SymbolRef Sym, - bool Assumption) override; - - ProgramStateRef assumeSymInclusiveRange(ProgramStateRef State, SymbolRef Sym, - const llvm::APSInt &From, - const llvm::APSInt &To, - bool InRange) override; - - ProgramStateRef assumeSymUnsupported(ProgramStateRef State, SymbolRef Sym, - bool Assumption) override; - -private: - //===------------------------------------------------------------------===// - // Internal implementation. - //===------------------------------------------------------------------===// - - // Check whether a new model is satisfiable, and update the program state. - ProgramStateRef assumeZ3Expr(ProgramStateRef State, SymbolRef Sym, - const Z3Expr &Exp); - - // Generate and check a Z3 model, using the given constraint. 
- Z3_lbool checkZ3Model(ProgramStateRef State, const Z3Expr &Exp) const; - - // Generate a Z3Expr that represents the given symbolic expression. - // Sets the hasComparison parameter if the expression has a comparison - // operator. - // Sets the RetTy parameter to the final return type after promotions and - // casts. - Z3Expr getZ3Expr(SymbolRef Sym, QualType *RetTy = nullptr, - bool *hasComparison = nullptr) const; - - // Generate a Z3Expr that takes the logical not of an expression. - Z3Expr getZ3NotExpr(const Z3Expr &Exp) const; - - // Generate a Z3Expr that compares the expression to zero. - Z3Expr getZ3ZeroExpr(const Z3Expr &Exp, QualType RetTy, - bool Assumption) const; - - // Recursive implementation to unpack and generate symbolic expression. - // Sets the hasComparison and RetTy parameters. See getZ3Expr(). - Z3Expr getZ3SymExpr(SymbolRef Sym, QualType *RetTy, - bool *hasComparison) const; - - // Wrapper to generate Z3Expr from SymbolData. - Z3Expr getZ3DataExpr(const SymbolID ID, QualType Ty) const; - - // Wrapper to generate Z3Expr from SymbolCast. - Z3Expr getZ3CastExpr(const Z3Expr &Exp, QualType FromTy, QualType Ty) const; - - // Wrapper to generate Z3Expr from BinarySymExpr. - // Sets the hasComparison and RetTy parameters. See getZ3Expr(). - Z3Expr getZ3SymBinExpr(const BinarySymExpr *BSE, bool *hasComparison, - QualType *RetTy) const; - - // Wrapper to generate Z3Expr from unpacked binary symbolic expression. - // Sets the RetTy parameter. See getZ3Expr(). - Z3Expr getZ3BinExpr(const Z3Expr &LHS, QualType LTy, - BinaryOperator::Opcode Op, const Z3Expr &RHS, - QualType RTy, QualType *RetTy) const; - - // Wrapper to generate Z3Expr from a range. If From == To, an equality will - // be created instead. - Z3Expr getZ3RangeExpr(SymbolRef Sym, const llvm::APSInt &From, - const llvm::APSInt &To, bool InRange); - - //===------------------------------------------------------------------===// - // Helper functions. 
- //===------------------------------------------------------------------===// - - // Recover the QualType of an APSInt. - // TODO: Refactor to put elsewhere - QualType getAPSIntType(const llvm::APSInt &Int) const; - - // Get the QualTy for the input APSInt, and fix it if it has a bitwidth of 1. - std::pair fixAPSInt(const llvm::APSInt &Int) const; - - // Perform implicit type conversion on binary symbolic expressions. - // May modify all input parameters. - // TODO: Refactor to use built-in conversion functions - void doTypeConversion(Z3Expr &LHS, Z3Expr &RHS, QualType <y, - QualType &RTy) const; - - // Perform implicit integer type conversion. - // May modify all input parameters. - // TODO: Refactor to use Sema::handleIntegerConversion() - template - void doIntTypeConversion(T &LHS, QualType <y, T &RHS, QualType &RTy) const; - - // Perform implicit floating-point type conversion. - // May modify all input parameters. - // TODO: Refactor to use Sema::handleFloatConversion() - template - void doFloatTypeConversion(T &LHS, QualType <y, T &RHS, - QualType &RTy) const; + : SMTConstraintManager(SE, SB, S) {} }; // end class Z3ConstraintManager } // end anonymous namespace -ProgramStateRef Z3ConstraintManager::assumeSym(ProgramStateRef State, - SymbolRef Sym, bool Assumption) { - QualType RetTy; - bool hasComparison; - - Z3Expr Exp = getZ3Expr(Sym, &RetTy, &hasComparison); - // Create zero comparison for implicit boolean cast, with reversed assumption - if (!hasComparison && !RetTy->isBooleanType()) - return assumeZ3Expr(State, Sym, getZ3ZeroExpr(Exp, RetTy, !Assumption)); - - return assumeZ3Expr(State, Sym, Assumption ? 
Exp : getZ3NotExpr(Exp)); -} - -ProgramStateRef Z3ConstraintManager::assumeSymInclusiveRange( - ProgramStateRef State, SymbolRef Sym, const llvm::APSInt &From, - const llvm::APSInt &To, bool InRange) { - return assumeZ3Expr(State, Sym, getZ3RangeExpr(Sym, From, To, InRange)); -} - -ProgramStateRef Z3ConstraintManager::assumeSymUnsupported(ProgramStateRef State, - SymbolRef Sym, - bool Assumption) { - // Skip anything that is unsupported - return State; -} - -bool Z3ConstraintManager::canReasonAbout(SVal X) const { - const TargetInfo &TI = getBasicVals().getContext().getTargetInfo(); - - Optional SymVal = X.getAs(); - if (!SymVal) - return true; - - const SymExpr *Sym = SymVal->getSymbol(); - QualType Ty = Sym->getType(); - - // Complex types are not modeled - if (Ty->isComplexType() || Ty->isComplexIntegerType()) - return false; - - // Non-IEEE 754 floating-point types are not modeled - if ((Ty->isSpecificBuiltinType(BuiltinType::LongDouble) && - (&TI.getLongDoubleFormat() == &llvm::APFloat::x87DoubleExtended() || - &TI.getLongDoubleFormat() == &llvm::APFloat::PPCDoubleDouble()))) - return false; - - if (isa(Sym)) - return true; - - SValBuilder &SVB = getSValBuilder(); - - if (const SymbolCast *SC = dyn_cast(Sym)) - return canReasonAbout(SVB.makeSymbolVal(SC->getOperand())); - - if (const BinarySymExpr *BSE = dyn_cast(Sym)) { - if (const SymIntExpr *SIE = dyn_cast(BSE)) - return canReasonAbout(SVB.makeSymbolVal(SIE->getLHS())); - - if (const IntSymExpr *ISE = dyn_cast(BSE)) - return canReasonAbout(SVB.makeSymbolVal(ISE->getRHS())); - - if (const SymSymExpr *SSE = dyn_cast(BSE)) - return canReasonAbout(SVB.makeSymbolVal(SSE->getLHS())) && - canReasonAbout(SVB.makeSymbolVal(SSE->getRHS())); - } - - llvm_unreachable("Unsupported expression to reason about!"); -} - -ConditionTruthVal Z3ConstraintManager::checkNull(ProgramStateRef State, - SymbolRef Sym) { - QualType RetTy; - // The expression may be casted, so we cannot call getZ3DataExpr() directly - Z3Expr VarExp = 
getZ3Expr(Sym, &RetTy); - Z3Expr Exp = getZ3ZeroExpr(VarExp, RetTy, true); - // Negate the constraint - Z3Expr NotExp = getZ3ZeroExpr(VarExp, RetTy, false); - - Solver.reset(); - Solver.addStateConstraints(State); - - Solver.push(); - Solver.addConstraint(Exp); - Z3_lbool isSat = Solver.check(); - - Solver.pop(); - Solver.addConstraint(NotExp); - Z3_lbool isNotSat = Solver.check(); - - // Zero is the only possible solution - if (isSat == Z3_L_TRUE && isNotSat == Z3_L_FALSE) - return true; - // Zero is not a solution - else if (isSat == Z3_L_FALSE && isNotSat == Z3_L_TRUE) - return false; - - // Zero may be a solution - return ConditionTruthVal(); -} - -const llvm::APSInt *Z3ConstraintManager::getSymVal(ProgramStateRef State, - SymbolRef Sym) const { - BasicValueFactory &BVF = getBasicVals(); - ASTContext &Ctx = BVF.getContext(); - - if (const SymbolData *SD = dyn_cast(Sym)) { - QualType Ty = Sym->getType(); - assert(!Ty->isRealFloatingType()); - llvm::APSInt Value(Ctx.getTypeSize(Ty), - !Ty->isSignedIntegerOrEnumerationType()); - - Z3Expr Exp = getZ3DataExpr(SD->getSymbolID(), Ty); - - Solver.reset(); - Solver.addStateConstraints(State); - - // Constraints are unsatisfiable - if (Solver.check() != Z3_L_TRUE) - return nullptr; - - Z3Model Model = Solver.getModel(); - // Model does not assign interpretation - if (!Solver.getInterpretation(Model, Exp, Value)) - return nullptr; - - // A value has been obtained, check if it is the only value - Z3Expr NotExp = Solver.fromBinOp( - Exp, BO_NE, - Ty->isBooleanType() ? 
Solver.fromBoolean(Value.getBoolValue()) - : Solver.fromAPSInt(Value), - false); - - Solver.addConstraint(NotExp); - if (Solver.check() == Z3_L_TRUE) - return nullptr; - - // This is the only solution, store it - return &BVF.getValue(Value); - } else if (const SymbolCast *SC = dyn_cast(Sym)) { - SymbolRef CastSym = SC->getOperand(); - QualType CastTy = SC->getType(); - // Skip the void type - if (CastTy->isVoidType()) - return nullptr; - - const llvm::APSInt *Value; - if (!(Value = getSymVal(State, CastSym))) - return nullptr; - return &BVF.Convert(SC->getType(), *Value); - } else if (const BinarySymExpr *BSE = dyn_cast(Sym)) { - const llvm::APSInt *LHS, *RHS; - if (const SymIntExpr *SIE = dyn_cast(BSE)) { - LHS = getSymVal(State, SIE->getLHS()); - RHS = &SIE->getRHS(); - } else if (const IntSymExpr *ISE = dyn_cast(BSE)) { - LHS = &ISE->getLHS(); - RHS = getSymVal(State, ISE->getRHS()); - } else if (const SymSymExpr *SSM = dyn_cast(BSE)) { - // Early termination to avoid expensive call - LHS = getSymVal(State, SSM->getLHS()); - RHS = LHS ? 
getSymVal(State, SSM->getRHS()) : nullptr; - } else { - llvm_unreachable("Unsupported binary expression to get symbol value!"); - } - - if (!LHS || !RHS) - return nullptr; - - llvm::APSInt ConvertedLHS, ConvertedRHS; - QualType LTy, RTy; - std::tie(ConvertedLHS, LTy) = fixAPSInt(*LHS); - std::tie(ConvertedRHS, RTy) = fixAPSInt(*RHS); - doIntTypeConversion(ConvertedLHS, LTy, - ConvertedRHS, RTy); - return BVF.evalAPSInt(BSE->getOpcode(), ConvertedLHS, ConvertedRHS); - } - - llvm_unreachable("Unsupported expression to get symbol value!"); -} - -ProgramStateRef -Z3ConstraintManager::removeDeadBindings(ProgramStateRef State, - SymbolReaper &SymReaper) { - ConstraintZ3Ty CZ = State->get(); - ConstraintZ3Ty::Factory &CZFactory = State->get_context(); - - for (ConstraintZ3Ty::iterator I = CZ.begin(), E = CZ.end(); I != E; ++I) { - if (SymReaper.maybeDead(I->first)) - CZ = CZFactory.remove(CZ, *I); - } - - return State->set(CZ); -} - -void Z3ConstraintManager::addRangeConstraints(ConstraintRangeTy CR) { - Solver.reset(); - - for (const auto &I : CR) { - SymbolRef Sym = I.first; - - Z3Expr Constraints = Solver.fromBoolean(false); - for (const auto &Range : I.second) { - Z3Expr SymRange = - getZ3RangeExpr(Sym, Range.From(), Range.To(), /*InRange=*/true); - - // FIXME: the last argument (isSigned) is not used when generating the - // or expression, as both arguments are booleans - Constraints = - Solver.fromBinOp(Constraints, BO_LOr, SymRange, /*IsSigned=*/true); - } - Solver.addConstraint(Constraints); - } -} - -clang::ento::ConditionTruthVal Z3ConstraintManager::isModelFeasible() { - if (Solver.check() == Z3_L_FALSE) - return false; - - return ConditionTruthVal(); -} - -LLVM_DUMP_METHOD void Z3ConstraintManager::dump() const { Solver.dump(); } - -//===------------------------------------------------------------------===// -// Internal implementation. 
-//===------------------------------------------------------------------===// - -ProgramStateRef Z3ConstraintManager::assumeZ3Expr(ProgramStateRef State, - SymbolRef Sym, - const Z3Expr &Exp) { - // Check the model, avoid simplifying AST to save time - if (checkZ3Model(State, Exp) == Z3_L_TRUE) - return State->add(std::make_pair(Sym, Exp)); - - return nullptr; -} - -Z3_lbool Z3ConstraintManager::checkZ3Model(ProgramStateRef State, - const Z3Expr &Exp) const { - Solver.reset(); - Solver.addConstraint(Exp); - Solver.addStateConstraints(State); - return Solver.check(); -} - -Z3Expr Z3ConstraintManager::getZ3Expr(SymbolRef Sym, QualType *RetTy, - bool *hasComparison) const { - if (hasComparison) { - *hasComparison = false; - } - - return getZ3SymExpr(Sym, RetTy, hasComparison); -} - -Z3Expr Z3ConstraintManager::getZ3NotExpr(const Z3Expr &Exp) const { - return Solver.fromUnOp(UO_LNot, Exp); -} - -Z3Expr Z3ConstraintManager::getZ3ZeroExpr(const Z3Expr &Exp, QualType Ty, - bool Assumption) const { - ASTContext &Ctx = getBasicVals().getContext(); - if (Ty->isRealFloatingType()) { - llvm::APFloat Zero = llvm::APFloat::getZero(Ctx.getFloatTypeSemantics(Ty)); - return Solver.fromFloatBinOp(Exp, Assumption ? BO_EQ : BO_NE, - Solver.fromAPFloat(Zero)); - } else if (Ty->isIntegralOrEnumerationType() || Ty->isAnyPointerType() || - Ty->isBlockPointerType() || Ty->isReferenceType()) { - bool isSigned = Ty->isSignedIntegerOrEnumerationType(); - // Skip explicit comparison for boolean types - if (Ty->isBooleanType()) - return Assumption ? getZ3NotExpr(Exp) : Exp; - return Solver.fromBinOp(Exp, Assumption ? 
BO_EQ : BO_NE, - Solver.fromInt("0", Ctx.getTypeSize(Ty)), isSigned); - } - - llvm_unreachable("Unsupported type for zero value!"); -} - -Z3Expr Z3ConstraintManager::getZ3SymExpr(SymbolRef Sym, QualType *RetTy, - bool *hasComparison) const { - if (const SymbolData *SD = dyn_cast(Sym)) { - if (RetTy) - *RetTy = Sym->getType(); - - return getZ3DataExpr(SD->getSymbolID(), Sym->getType()); - } else if (const SymbolCast *SC = dyn_cast(Sym)) { - if (RetTy) - *RetTy = Sym->getType(); - - QualType FromTy; - Z3Expr Exp = getZ3SymExpr(SC->getOperand(), &FromTy, hasComparison); - // Casting an expression with a comparison invalidates it. Note that this - // must occur after the recursive call above. - // e.g. (signed char) (x > 0) - if (hasComparison) - *hasComparison = false; - return getZ3CastExpr(Exp, FromTy, Sym->getType()); - } else if (const BinarySymExpr *BSE = dyn_cast(Sym)) { - Z3Expr Exp = getZ3SymBinExpr(BSE, hasComparison, RetTy); - // Set the hasComparison parameter, in post-order traversal order. 
- if (hasComparison) - *hasComparison = BinaryOperator::isComparisonOp(BSE->getOpcode()); - return Exp; - } - - llvm_unreachable("Unsupported SymbolRef type!"); -} - -Z3Expr Z3ConstraintManager::getZ3DataExpr(const SymbolID ID, - QualType Ty) const { - ASTContext &Ctx = getBasicVals().getContext(); - return Solver.fromData(ID, Ty, Ctx.getTypeSize(Ty)); -} - -Z3Expr Z3ConstraintManager::getZ3CastExpr(const Z3Expr &Exp, QualType FromTy, - QualType ToTy) const { - ASTContext &Ctx = getBasicVals().getContext(); - return Solver.fromCast(Exp, ToTy, Ctx.getTypeSize(ToTy), FromTy, - Ctx.getTypeSize(FromTy)); -} - -Z3Expr Z3ConstraintManager::getZ3SymBinExpr(const BinarySymExpr *BSE, - bool *hasComparison, - QualType *RetTy) const { - QualType LTy, RTy; - BinaryOperator::Opcode Op = BSE->getOpcode(); - - if (const SymIntExpr *SIE = dyn_cast(BSE)) { - Z3Expr LHS = getZ3SymExpr(SIE->getLHS(), <y, hasComparison); - llvm::APSInt NewRInt; - std::tie(NewRInt, RTy) = fixAPSInt(SIE->getRHS()); - Z3Expr RHS = Solver.fromAPSInt(NewRInt); - return getZ3BinExpr(LHS, LTy, Op, RHS, RTy, RetTy); - } else if (const IntSymExpr *ISE = dyn_cast(BSE)) { - llvm::APSInt NewLInt; - std::tie(NewLInt, LTy) = fixAPSInt(ISE->getLHS()); - Z3Expr LHS = Solver.fromAPSInt(NewLInt); - Z3Expr RHS = getZ3SymExpr(ISE->getRHS(), &RTy, hasComparison); - return getZ3BinExpr(LHS, LTy, Op, RHS, RTy, RetTy); - } else if (const SymSymExpr *SSM = dyn_cast(BSE)) { - Z3Expr LHS = getZ3SymExpr(SSM->getLHS(), <y, hasComparison); - Z3Expr RHS = getZ3SymExpr(SSM->getRHS(), &RTy, hasComparison); - return getZ3BinExpr(LHS, LTy, Op, RHS, RTy, RetTy); - } else { - llvm_unreachable("Unsupported BinarySymExpr type!"); - } -} - -Z3Expr Z3ConstraintManager::getZ3BinExpr(const Z3Expr &LHS, QualType LTy, - BinaryOperator::Opcode Op, - const Z3Expr &RHS, QualType RTy, - QualType *RetTy) const { - Z3Expr NewLHS = LHS; - Z3Expr NewRHS = RHS; - doTypeConversion(NewLHS, NewRHS, LTy, RTy); - // Update the return type parameter if the 
output type has changed. - if (RetTy) { - // A boolean result can be represented as an integer type in C/C++, but at - // this point we only care about the Z3 type. Set it as a boolean type to - // avoid subsequent Z3 errors. - if (BinaryOperator::isComparisonOp(Op) || BinaryOperator::isLogicalOp(Op)) { - ASTContext &Ctx = getBasicVals().getContext(); - *RetTy = Ctx.BoolTy; - } else { - *RetTy = LTy; - } - - // If the two operands are pointers and the operation is a subtraction, the - // result is of type ptrdiff_t, which is signed - if (LTy->isAnyPointerType() && RTy->isAnyPointerType() && Op == BO_Sub) { - *RetTy = getBasicVals().getContext().getPointerDiffType(); - } - } - - return LTy->isRealFloatingType() - ? Solver.fromFloatBinOp(NewLHS, Op, NewRHS) - : Solver.fromBinOp(NewLHS, Op, NewRHS, - LTy->isSignedIntegerOrEnumerationType()); -} - -Z3Expr Z3ConstraintManager::getZ3RangeExpr(SymbolRef Sym, - const llvm::APSInt &From, - const llvm::APSInt &To, - bool InRange) { - // Convert lower bound - QualType FromTy; - llvm::APSInt NewFromInt; - std::tie(NewFromInt, FromTy) = fixAPSInt(From); - Z3Expr FromExp = Solver.fromAPSInt(NewFromInt); - - // Convert symbol - QualType SymTy; - Z3Expr Exp = getZ3Expr(Sym, &SymTy); - - // Construct single (in)equality - if (From == To) - return getZ3BinExpr(Exp, SymTy, InRange ? BO_EQ : BO_NE, FromExp, FromTy, - /*RetTy=*/nullptr); - - QualType ToTy; - llvm::APSInt NewToInt; - std::tie(NewToInt, ToTy) = fixAPSInt(To); - Z3Expr ToExp = Solver.fromAPSInt(NewToInt); - assert(FromTy == ToTy && "Range values have different types!"); - - // Construct two (in)equalities, and a logical and/or - Z3Expr LHS = getZ3BinExpr(Exp, SymTy, InRange ? BO_GE : BO_LT, FromExp, - FromTy, /*RetTy=*/nullptr); - Z3Expr RHS = getZ3BinExpr(Exp, SymTy, InRange ? BO_LE : BO_GT, ToExp, ToTy, - /*RetTy=*/nullptr); - - return Solver.fromBinOp(LHS, InRange ? 
BO_LAnd : BO_LOr, RHS, - SymTy->isSignedIntegerOrEnumerationType()); -} - -//===------------------------------------------------------------------===// -// Helper functions. -//===------------------------------------------------------------------===// - -QualType Z3ConstraintManager::getAPSIntType(const llvm::APSInt &Int) const { - ASTContext &Ctx = getBasicVals().getContext(); - return Ctx.getIntTypeForBitwidth(Int.getBitWidth(), Int.isSigned()); -} - -std::pair -Z3ConstraintManager::fixAPSInt(const llvm::APSInt &Int) const { - llvm::APSInt NewInt; - - // FIXME: This should be a cast from a 1-bit integer type to a boolean type, - // but the former is not available in Clang. Instead, extend the APSInt - // directly. - if (Int.getBitWidth() == 1 && getAPSIntType(Int).isNull()) { - ASTContext &Ctx = getBasicVals().getContext(); - NewInt = Int.extend(Ctx.getTypeSize(Ctx.BoolTy)); - } else - NewInt = Int; - - return std::make_pair(NewInt, getAPSIntType(NewInt)); -} - -void Z3ConstraintManager::doTypeConversion(Z3Expr &LHS, Z3Expr &RHS, - QualType <y, QualType &RTy) const { - ASTContext &Ctx = getBasicVals().getContext(); - - assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!"); - // Perform type conversion - if (LTy->isIntegralOrEnumerationType() && - RTy->isIntegralOrEnumerationType()) { - if (LTy->isArithmeticType() && RTy->isArithmeticType()) - return doIntTypeConversion(LHS, LTy, RHS, - RTy); - } else if (LTy->isRealFloatingType() || RTy->isRealFloatingType()) { - return doFloatTypeConversion(LHS, LTy, RHS, - RTy); - } else if ((LTy->isAnyPointerType() || RTy->isAnyPointerType()) || - (LTy->isBlockPointerType() || RTy->isBlockPointerType()) || - (LTy->isReferenceType() || RTy->isReferenceType())) { - // TODO: Refactor to Sema::FindCompositePointerType(), and - // Sema::CheckCompareOperands(). - - uint64_t LBitWidth = Ctx.getTypeSize(LTy); - uint64_t RBitWidth = Ctx.getTypeSize(RTy); - - // Cast the non-pointer type to the pointer type. 
- // TODO: Be more strict about this. - if ((LTy->isAnyPointerType() ^ RTy->isAnyPointerType()) || - (LTy->isBlockPointerType() ^ RTy->isBlockPointerType()) || - (LTy->isReferenceType() ^ RTy->isReferenceType())) { - if (LTy->isNullPtrType() || LTy->isBlockPointerType() || - LTy->isReferenceType()) { - LHS = Solver.fromCast(LHS, RTy, RBitWidth, LTy, LBitWidth); - LTy = RTy; - } else { - RHS = Solver.fromCast(RHS, LTy, LBitWidth, RTy, RBitWidth); - RTy = LTy; - } - } - - // Cast the void pointer type to the non-void pointer type. - // For void types, this assumes that the casted value is equal to the value - // of the original pointer, and does not account for alignment requirements. - if (LTy->isVoidPointerType() ^ RTy->isVoidPointerType()) { - assert((Ctx.getTypeSize(LTy) == Ctx.getTypeSize(RTy)) && - "Pointer types have different bitwidths!"); - if (RTy->isVoidPointerType()) - RTy = LTy; - else - LTy = RTy; - } - - if (LTy == RTy) - return; - } - - // Fallback: for the solver, assume that these types don't really matter - if ((LTy.getCanonicalType() == RTy.getCanonicalType()) || - (LTy->isObjCObjectPointerType() && RTy->isObjCObjectPointerType())) { - LTy = RTy; - return; - } - - // TODO: Refine behavior for invalid type casts -} - -template -void Z3ConstraintManager::doIntTypeConversion(T &LHS, QualType <y, T &RHS, - QualType &RTy) const { - ASTContext &Ctx = getBasicVals().getContext(); - uint64_t LBitWidth = Ctx.getTypeSize(LTy); - uint64_t RBitWidth = Ctx.getTypeSize(RTy); - - assert(!LTy.isNull() && !RTy.isNull() && "Input type is null!"); - // Always perform integer promotion before checking type equality. - // Otherwise, e.g. 
(bool) a + (bool) b could trigger a backend assertion - if (LTy->isPromotableIntegerType()) { - QualType NewTy = Ctx.getPromotedIntegerType(LTy); - uint64_t NewBitWidth = Ctx.getTypeSize(NewTy); - LHS = (Solver.*doCast)(LHS, NewTy, NewBitWidth, LTy, LBitWidth); - LTy = NewTy; - LBitWidth = NewBitWidth; - } - if (RTy->isPromotableIntegerType()) { - QualType NewTy = Ctx.getPromotedIntegerType(RTy); - uint64_t NewBitWidth = Ctx.getTypeSize(NewTy); - RHS = (Solver.*doCast)(RHS, NewTy, NewBitWidth, RTy, RBitWidth); - RTy = NewTy; - RBitWidth = NewBitWidth; - } - - if (LTy == RTy) - return; - - // Perform integer type conversion - // Note: Safe to skip updating bitwidth because this must terminate - bool isLSignedTy = LTy->isSignedIntegerOrEnumerationType(); - bool isRSignedTy = RTy->isSignedIntegerOrEnumerationType(); - - int order = Ctx.getIntegerTypeOrder(LTy, RTy); - if (isLSignedTy == isRSignedTy) { - // Same signedness; use the higher-ranked type - if (order == 1) { - RHS = (Solver.*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth); - RTy = LTy; - } else { - LHS = (Solver.*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth); - LTy = RTy; - } - } else if (order != (isLSignedTy ? 1 : -1)) { - // The unsigned type has greater than or equal rank to the - // signed type, so use the unsigned type - if (isRSignedTy) { - RHS = (Solver.*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth); - RTy = LTy; - } else { - LHS = (Solver.*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth); - LTy = RTy; - } - } else if (LBitWidth != RBitWidth) { - // The two types are different widths; if we are here, that - // means the signed type is larger than the unsigned type, so - // use the signed type. 
- if (isLSignedTy) { - RHS = (Solver.*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth); - RTy = LTy; - } else { - LHS = (Solver.*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth); - LTy = RTy; - } - } else { - // The signed type is higher-ranked than the unsigned type, - // but isn't actually any bigger (like unsigned int and long - // on most 32-bit systems). Use the unsigned type corresponding - // to the signed type. - QualType NewTy = Ctx.getCorrespondingUnsignedType(isLSignedTy ? LTy : RTy); - RHS = (Solver.*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth); - RTy = NewTy; - LHS = (Solver.*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth); - LTy = NewTy; - } -} - -template -void Z3ConstraintManager::doFloatTypeConversion(T &LHS, QualType <y, T &RHS, - QualType &RTy) const { - ASTContext &Ctx = getBasicVals().getContext(); - - uint64_t LBitWidth = Ctx.getTypeSize(LTy); - uint64_t RBitWidth = Ctx.getTypeSize(RTy); - - // Perform float-point type promotion - if (!LTy->isRealFloatingType()) { - LHS = (Solver.*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth); - LTy = RTy; - LBitWidth = RBitWidth; - } - if (!RTy->isRealFloatingType()) { - RHS = (Solver.*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth); - RTy = LTy; - RBitWidth = LBitWidth; - } - - if (LTy == RTy) - return; - - // If we have two real floating types, convert the smaller operand to the - // bigger result - // Note: Safe to skip updating bitwidth because this must terminate - int order = Ctx.getFloatingTypeOrder(LTy, RTy); - if (order > 0) { - RHS = (Solver.*doCast)(RHS, LTy, LBitWidth, RTy, RBitWidth); - RTy = LTy; - } else if (order == 0) { - LHS = (Solver.*doCast)(LHS, RTy, RBitWidth, LTy, LBitWidth); - LTy = RTy; - } else { - llvm_unreachable("Unsupported floating-point type cast!"); - } -} - -//==------------------------------------------------------------------------==/ -// Pretty-printing. 
-//==------------------------------------------------------------------------==/ - -void Z3ConstraintManager::print(ProgramStateRef St, raw_ostream &OS, - const char *nl, const char *sep) { - - ConstraintZ3Ty CZ = St->get(); - - OS << nl << sep << "Constraints:"; - for (ConstraintZ3Ty::iterator I = CZ.begin(), E = CZ.end(); I != E; ++I) { - OS << nl << ' ' << I->first << " : "; - I->second.print(OS); - } - OS << nl; -} - #endif std::unique_ptr