diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.h b/clang/lib/AST/Interp/ByteCodeExprGen.h --- a/clang/lib/AST/Interp/ByteCodeExprGen.h +++ b/clang/lib/AST/Interp/ByteCodeExprGen.h @@ -64,6 +64,7 @@ bool VisitIntegerLiteral(const IntegerLiteral *E); bool VisitParenExpr(const ParenExpr *E); bool VisitBinaryOperator(const BinaryOperator *E); + bool VisitPointerArithBinOp(const BinaryOperator *E); bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E); bool VisitCallExpr(const CallExpr *E); bool VisitCXXMemberCallExpr(const CXXMemberCallExpr *E); diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp --- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp @@ -190,67 +190,120 @@ // Typecheck the args. Optional LT = classify(LHS->getType()); Optional RT = classify(RHS->getType()); - if (!LT || !RT) { + Optional T = classify(BO->getType()); + if (!LT || !RT || !T) { return this->bail(BO); } - if (Optional T = classify(BO->getType())) { - if (!visit(LHS)) + auto Discard = [this, T, BO](bool Result) { + if (!Result) return false; - if (!visit(RHS)) + return DiscardResult ? this->emitPop(*T, BO) : true; + }; + + // Pointer arithmetic special case. 
+ if (BO->getOpcode() == BO_Add || BO->getOpcode() == BO_Sub) { + if (*T == PT_Ptr || (*LT == PT_Ptr && *RT == PT_Ptr)) + return this->VisitPointerArithBinOp(BO); + } + + if (!visit(LHS) || !visit(RHS)) + return false; + + switch (BO->getOpcode()) { + case BO_EQ: + return Discard(this->emitEQ(*LT, BO)); + case BO_NE: + return Discard(this->emitNE(*LT, BO)); + case BO_LT: + return Discard(this->emitLT(*LT, BO)); + case BO_LE: + return Discard(this->emitLE(*LT, BO)); + case BO_GT: + return Discard(this->emitGT(*LT, BO)); + case BO_GE: + return Discard(this->emitGE(*LT, BO)); + case BO_Sub: + return Discard(this->emitSub(*T, BO)); + case BO_Add: + return Discard(this->emitAdd(*T, BO)); + case BO_Mul: + return Discard(this->emitMul(*T, BO)); + case BO_Rem: + return Discard(this->emitRem(*T, BO)); + case BO_Div: + return Discard(this->emitDiv(*T, BO)); + case BO_Assign: + if (!this->emitStore(*T, BO)) return false; + return DiscardResult ? this->emitPopPtr(BO) : true; + case BO_And: + return Discard(this->emitBitAnd(*T, BO)); + case BO_Or: + return Discard(this->emitBitOr(*T, BO)); + case BO_Shl: + return Discard(this->emitShl(*LT, *RT, BO)); + case BO_Shr: + return Discard(this->emitShr(*LT, *RT, BO)); + case BO_Xor: + return Discard(this->emitBitXor(*T, BO)); + case BO_LAnd: + case BO_LOr: + default: + return this->bail(BO); + } - auto Discard = [this, T, BO](bool Result) { - if (!Result) - return false; - return DiscardResult ? 
this->emitPop(*T, BO) : true; - }; - - switch (BO->getOpcode()) { - case BO_EQ: - return Discard(this->emitEQ(*LT, BO)); - case BO_NE: - return Discard(this->emitNE(*LT, BO)); - case BO_LT: - return Discard(this->emitLT(*LT, BO)); - case BO_LE: - return Discard(this->emitLE(*LT, BO)); - case BO_GT: - return Discard(this->emitGT(*LT, BO)); - case BO_GE: - return Discard(this->emitGE(*LT, BO)); - case BO_Sub: - return Discard(this->emitSub(*T, BO)); - case BO_Add: - return Discard(this->emitAdd(*T, BO)); - case BO_Mul: - return Discard(this->emitMul(*T, BO)); - case BO_Rem: - return Discard(this->emitRem(*T, BO)); - case BO_Div: - return Discard(this->emitDiv(*T, BO)); - case BO_Assign: - if (!this->emitStore(*T, BO)) - return false; - return DiscardResult ? this->emitPopPtr(BO) : true; - case BO_And: - return Discard(this->emitBitAnd(*T, BO)); - case BO_Or: - return Discard(this->emitBitOr(*T, BO)); - case BO_Shl: - return Discard(this->emitShl(*LT, *RT, BO)); - case BO_Shr: - return Discard(this->emitShr(*LT, *RT, BO)); - case BO_Xor: - return Discard(this->emitBitXor(*T, BO)); - case BO_LAnd: - case BO_LOr: - default: - return this->bail(BO); - } + llvm_unreachable("Unhandled binary op"); +} + +/// Perform addition/subtraction of a pointer and an integer or +/// subtraction of two pointers. 
+template +bool ByteCodeExprGen::VisitPointerArithBinOp(const BinaryOperator *E) { + BinaryOperatorKind Op = E->getOpcode(); + const Expr *LHS = E->getLHS(); + const Expr *RHS = E->getRHS(); + + if ((Op != BO_Add && Op != BO_Sub) || + (!LHS->getType()->isPointerType() && !RHS->getType()->isPointerType())) + return false; + + Optional LT = classify(LHS); + Optional RT = classify(RHS); + + if (!LT || !RT) + return false; + + if (LHS->getType()->isPointerType() && RHS->getType()->isPointerType()) { + if (Op != BO_Sub) + return false; + + assert(E->getType()->isIntegerType()); + if (!visit(RHS) || !visit(LHS)) + return false; + + return this->emitSubPtr(classifyPrim(E->getType()), E); + } + + PrimType OffsetType; + if (LHS->getType()->isIntegerType()) { + if (!visit(RHS) || !visit(LHS)) + return false; + OffsetType = *LT; + } else if (RHS->getType()->isIntegerType()) { + if (!visit(LHS) || !visit(RHS)) + return false; + OffsetType = *RT; + } else { + return false; } - return this->bail(BO); + if (Op == BO_Add) + return this->emitAddOffset(OffsetType, E); + else if (Op == BO_Sub) + return this->emitSubOffset(OffsetType, E); + + return this->bail(E); } template diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h --- a/clang/lib/AST/Interp/Interp.h +++ b/clang/lib/AST/Interp/Interp.h @@ -466,6 +466,16 @@ } else { unsigned VL = LHS.getByteOffset(); unsigned VR = RHS.getByteOffset(); + + // In our Pointer class, a pointer to an array and a pointer to the first + // element in the same array are NOT equal. They have the same Base value, + // but a different Offset. This is a pretty rare case, so we fix this here + // by comparing pointers to the first elements. + if (LHS.inArray() && LHS.isRoot()) + VL = LHS.atIndex(0).getByteOffset(); + if (RHS.inArray() && RHS.isRoot()) + VR = RHS.atIndex(0).getByteOffset(); + S.Stk.push(BoolT::from(Fn(Compare(VL, VR)))); return true; } @@ -991,23 +1001,25 @@ // Fetch the pointer and the offset. 
const T &Offset = S.Stk.pop(); const Pointer &Ptr = S.Stk.pop(); - if (!CheckNull(S, OpPC, Ptr, CSK_ArrayIndex)) - return false; + if (!CheckRange(S, OpPC, Ptr, CSK_ArrayToPointer)) return false; - // Get a version of the index comparable to the type. - T Index = T::from(Ptr.getIndex(), Offset.bitWidth()); - // A zero offset does not change the pointer, but in the case of an array - // it has to be adjusted to point to the first element instead of the array. + // A zero offset does not change the pointer. if (Offset.isZero()) { - S.Stk.push(Index.isZero() ? Ptr.atIndex(0) : Ptr); + S.Stk.push(Ptr); return true; } + + if (!CheckNull(S, OpPC, Ptr, CSK_ArrayIndex)) + return false; + // Arrays of unknown bounds cannot have pointers into them. if (!CheckArray(S, OpPC, Ptr)) return false; + // Get a version of the index comparable to the type. + T Index = T::from(Ptr.getIndex(), Offset.bitWidth()); // Compute the largest index into the array. unsigned MaxIndex = Ptr.getNumElems(); @@ -1061,6 +1073,23 @@ return OffsetHelper(S, OpPC); } +/// 1) Pops a Pointer from the stack. +/// 2) Pops another Pointer from the stack. +/// 3) Pushes the difference of the indices of the two pointers on the stack. +template ::T> +inline bool SubPtr(InterpState &S, CodePtr OpPC) { + const Pointer &LHS = S.Stk.pop(); + const Pointer &RHS = S.Stk.pop(); + + if (!Pointer::hasSameArray(LHS, RHS)) { + // TODO: Diagnose.
+ return false; + } + + T A = T::from(LHS.getIndex()); + T B = T::from(RHS.getIndex()); + return AddSubMulHelper(S, OpPC, A.bitWidth(), A, B); +} //===----------------------------------------------------------------------===// // Destroy diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td --- a/clang/lib/AST/Interp/Opcodes.td +++ b/clang/lib/AST/Interp/Opcodes.td @@ -390,6 +390,12 @@ // [Pointer, Integral] -> [Pointer] def SubOffset : AluOpcode; +// [Pointer, Pointer] -> [Integral] +def SubPtr : Opcode { + let Types = [IntegerTypeClass]; + let HasGroup = 1; +} + //===----------------------------------------------------------------------===// // Binary operators. //===----------------------------------------------------------------------===// diff --git a/clang/lib/AST/Interp/Pointer.cpp b/clang/lib/AST/Interp/Pointer.cpp --- a/clang/lib/AST/Interp/Pointer.cpp +++ b/clang/lib/AST/Interp/Pointer.cpp @@ -202,5 +202,5 @@ } bool Pointer::hasSameArray(const Pointer &A, const Pointer &B) { - return A.Base == B.Base && A.getFieldDesc()->IsArray; + return hasSameBase(A, B) && A.Base == B.Base && A.getFieldDesc()->IsArray; } diff --git a/clang/test/AST/Interp/arrays.cpp b/clang/test/AST/Interp/arrays.cpp --- a/clang/test/AST/Interp/arrays.cpp +++ b/clang/test/AST/Interp/arrays.cpp @@ -37,6 +37,60 @@ static_assert(getElement(1) == 4, ""); static_assert(getElement(5) == 36, ""); +constexpr int data[] = {5, 4, 3, 2, 1}; +constexpr int getElement(const int *Arr, int index) { + return *(Arr + index); +} + +static_assert(getElement(data, 1) == 4, ""); +static_assert(getElement(data, 4) == 1, ""); + +constexpr int getElementFromEnd(const int *Arr, int size, int index) { + return *(Arr + size - index - 1); +} +static_assert(getElementFromEnd(data, 5, 0) == 1, ""); +static_assert(getElementFromEnd(data, 5, 4) == 5, ""); + + +constexpr static int arr[2] = {1,2}; +constexpr static int arr2[2] = {3,4}; +constexpr int *p1 = nullptr; +constexpr int *p2 = p1 + 1;
// expected-error {{must be initialized by a constant expression}} \ + // expected-note {{cannot perform pointer arithmetic on null pointer}} \ + // ref-error {{must be initialized by a constant expression}} \ + // ref-note {{cannot perform pointer arithmetic on null pointer}} +constexpr int *p3 = p1 + 0; +constexpr int *p4 = p1 - 0; +constexpr int *p5 = 0 + p1; +constexpr int *p6 = 0 - p1; // expected-error {{invalid operands to binary expression}} \ + // ref-error {{invalid operands to binary expression}} + +constexpr int const * ap1 = &arr[0]; +constexpr int const * ap2 = ap1 + 3; // expected-error {{must be initialized by a constant expression}} \ + // expected-note {{cannot refer to element 3 of array of 2}} \ + // ref-error {{must be initialized by a constant expression}} \ + // ref-note {{cannot refer to element 3 of array of 2}} + +constexpr auto ap3 = arr - 1; // expected-error {{must be initialized by a constant expression}} \ + // expected-note {{cannot refer to element -1}} \ + // ref-error {{must be initialized by a constant expression}} \ + // ref-note {{cannot refer to element -1}} +constexpr int k1 = &arr[1] - &arr[0]; +static_assert(k1 == 1, ""); +static_assert((&arr[0] - &arr[1]) == -1, ""); + +constexpr int k2 = &arr2[1] - &arr[0]; // expected-error {{must be initialized by a constant expression}} \ + // ref-error {{must be initialized by a constant expression}} + +static_assert((arr + 0) == arr, ""); +static_assert(&arr[0] == arr, ""); +static_assert(*(&arr[0]) == 1, ""); +static_assert(*(&arr[1]) == 2, ""); + +constexpr const int *OOB = (arr + 3) - 3; // expected-error {{must be initialized by a constant expression}} \ + // expected-note {{cannot refer to element 3 of array of 2}} \ + // ref-error {{must be initialized by a constant expression}} \ + // ref-note {{cannot refer to element 3 of array of 2}} template constexpr T getElementOf(T* array, int i) { @@ -52,7 +106,6 @@ static_assert(getElementOfArray(foo[2], 3) == &m, ""); -constexpr int 
data[] = {5, 4, 3, 2, 1}; static_assert(data[0] == 4, ""); // expected-error{{failed}} \ // expected-note{{5 == 4}} \ // ref-error{{failed}} \