diff --git a/clang/include/clang/AST/ExprCXX.h b/clang/include/clang/AST/ExprCXX.h
--- a/clang/include/clang/AST/ExprCXX.h
+++ b/clang/include/clang/AST/ExprCXX.h
@@ -4532,7 +4532,9 @@
 /// ( expr op ... )
 /// ( ... op expr )
 /// ( expr op ... op expr )
-class CXXFoldExpr : public Expr {
+class CXXFoldExpr final
+    : public Expr,
+      private llvm::TrailingObjects<CXXFoldExpr, DeclAccessPair> {
   friend class ASTStmtReader;
   friend class ASTStmtWriter;
 
@@ -4545,24 +4547,68 @@
   Stmt *SubExprs[2];
   BinaryOperatorKind Opcode;
 
-public:
+  /// Number of DeclAccessPair trailing objects: the operator functions found
+  /// by unqualified lookup where this fold-expression was written.
+  unsigned NumOverloadCands;
+
   CXXFoldExpr(QualType T, SourceLocation LParenLoc, Expr *LHS,
               BinaryOperatorKind Opcode, SourceLocation EllipsisLoc, Expr *RHS,
-              SourceLocation RParenLoc, Optional<unsigned> NumExpansions)
+              SourceLocation RParenLoc, Optional<unsigned> NumExpansions,
+              llvm::iterator_range<UnresolvedSetIterator> OverloadCands)
       : Expr(CXXFoldExprClass, T, VK_RValue, OK_Ordinary,
              /*Dependent*/ true, true, true,
              /*ContainsUnexpandedParameterPack*/ false),
         LParenLoc(LParenLoc), EllipsisLoc(EllipsisLoc), RParenLoc(RParenLoc),
-        NumExpansions(NumExpansions ? *NumExpansions + 1 : 0), Opcode(Opcode) {
+        NumExpansions(NumExpansions ? *NumExpansions + 1 : 0), Opcode(Opcode),
+        NumOverloadCands(
+            std::distance(OverloadCands.begin(), OverloadCands.end())) {
     SubExprs[0] = LHS;
     SubExprs[1] = RHS;
+    DeclAccessPair *Results = getTrailingObjects<DeclAccessPair>();
+    std::uninitialized_copy(OverloadCands.begin().I, OverloadCands.end().I,
+                            Results);
   }
 
-  CXXFoldExpr(EmptyShell Empty) : Expr(CXXFoldExprClass, Empty) {}
+  CXXFoldExpr(EmptyShell Empty, unsigned NumOverloadCands)
+      : Expr(CXXFoldExprClass, Empty), NumOverloadCands(NumOverloadCands) {}
+
+public:
+  static CXXFoldExpr *
+  Create(const ASTContext &Ctx, QualType T, SourceLocation LParenLoc, Expr *LHS,
+         BinaryOperatorKind Opcode, SourceLocation EllipsisLoc, Expr *RHS,
+         SourceLocation RParenLoc, Optional<unsigned> NumExpansions,
+         llvm::iterator_range<UnresolvedSetIterator> OverloadCands) {
+    size_t Size = CXXFoldExpr::totalSizeToAlloc<DeclAccessPair>(
+        std::distance(OverloadCands.begin(), OverloadCands.end()));
+    void *Mem = Ctx.Allocate(Size, alignof(CXXFoldExpr));
+    return new (Mem) CXXFoldExpr(T, LParenLoc, LHS, Opcode, EllipsisLoc, RHS,
+                                 RParenLoc, NumExpansions, OverloadCands);
+  }
+
+  static CXXFoldExpr *CreateEmpty(const ASTContext &Ctx,
+                                  unsigned NumOverloadCands) {
+    size_t Size =
+        CXXFoldExpr::totalSizeToAlloc<DeclAccessPair>(NumOverloadCands);
+    void *Mem = Ctx.Allocate(Size, alignof(CXXFoldExpr));
+    return new (Mem) CXXFoldExpr(EmptyShell(), NumOverloadCands);
+  }
 
   Expr *getLHS() const { return static_cast<Expr*>(SubExprs[0]); }
   Expr *getRHS() const { return static_cast<Expr*>(SubExprs[1]); }
 
+  unsigned getNumOverloadCands() const { return NumOverloadCands; }
+
+  unsigned numTrailingObjects(OverloadToken<DeclAccessPair>) const {
+    return getNumOverloadCands();
+  }
+
+  /// The candidate operator functions found by unqualified lookup where this
+  /// fold-expression was written; reused when the fold is expanded.
+  llvm::iterator_range<UnresolvedSetIterator> overloadCands() const {
+    auto Begin = UnresolvedSetIterator(getTrailingObjects<DeclAccessPair>());
+    return llvm::make_range(Begin, Begin + NumOverloadCands);
+  }
+
   /// Does this produce a right-associated sequence of operators?
   bool isRightFold() const {
     return getLHS() && getLHS()->containsUnexpandedParameterPack();
diff --git a/clang/include/clang/AST/UnresolvedSet.h b/clang/include/clang/AST/UnresolvedSet.h
--- a/clang/include/clang/AST/UnresolvedSet.h
+++ b/clang/include/clang/AST/UnresolvedSet.h
@@ -33,6 +33,7 @@
                                   std::random_access_iterator_tag, NamedDecl *,
                                   std::ptrdiff_t, NamedDecl *, NamedDecl *> {
   friend class ASTUnresolvedSet;
+  friend class CXXFoldExpr;
   friend class OverloadExpr;
   friend class UnresolvedSetImpl;
 
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -3491,7 +3491,6 @@
   bool LookupInSuper(LookupResult &R, CXXRecordDecl *Class);
 
   void LookupOverloadedOperatorName(OverloadedOperatorKind Op, Scope *S,
-                                    QualType T1, QualType T2,
                                     UnresolvedSetImpl &Functions);
 
   LabelDecl *LookupOrCreateLabel(IdentifierInfo *II, SourceLocation IdentLoc,
@@ -4780,8 +4779,9 @@
 public:
   ExprResult ActOnBinOp(Scope *S, SourceLocation TokLoc,
                         tok::TokenKind Kind, Expr *LHSExpr, Expr *RHSExpr);
-  ExprResult BuildBinOp(Scope *S, SourceLocation OpLoc,
-                        BinaryOperatorKind Opc, Expr *LHSExpr, Expr *RHSExpr);
+  ExprResult BuildBinOp(Scope *S, SourceLocation OpLoc, BinaryOperatorKind Opc,
+                        Expr *LHSExpr, Expr *RHSExpr,
+                        const UnresolvedSetImpl *OverloadCands = nullptr);
   ExprResult CreateBuiltinBinOp(SourceLocation OpLoc, BinaryOperatorKind Opc,
                                 Expr *LHSExpr, Expr *RHSExpr);
 
@@ -5440,15 +5440,16 @@
                                   SourceLocation RParenLoc);
 
   /// Handle a C++1z fold-expression: ( expr op ... op expr ).
-  ExprResult ActOnCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
+  ExprResult ActOnCXXFoldExpr(Scope *Sc, SourceLocation LParenLoc, Expr *LHS,
                               tok::TokenKind Operator,
                               SourceLocation EllipsisLoc, Expr *RHS,
                               SourceLocation RParenLoc);
-  ExprResult BuildCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
-                              BinaryOperatorKind Operator,
-                              SourceLocation EllipsisLoc, Expr *RHS,
-                              SourceLocation RParenLoc,
-                              Optional<unsigned> NumExpansions);
+  ExprResult
+  BuildCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
+                   BinaryOperatorKind Operator, SourceLocation EllipsisLoc,
+                   Expr *RHS, SourceLocation RParenLoc,
+                   Optional<unsigned> NumExpansions,
+                   llvm::iterator_range<UnresolvedSetIterator> OverloadCands);
   ExprResult BuildEmptyCXXFoldExpr(SourceLocation EllipsisLoc,
                                    BinaryOperatorKind Operator);
 
diff --git a/clang/lib/Parse/ParseExpr.cpp b/clang/lib/Parse/ParseExpr.cpp
--- a/clang/lib/Parse/ParseExpr.cpp
+++ b/clang/lib/Parse/ParseExpr.cpp
@@ -2870,8 +2870,9 @@
                         : diag::ext_fold_expression);
 
   T.consumeClose();
-  return Actions.ActOnCXXFoldExpr(T.getOpenLocation(), LHS.get(), Kind,
-                                  EllipsisLoc, RHS.get(), T.getCloseLocation());
+  return Actions.ActOnCXXFoldExpr(getCurScope(), T.getOpenLocation(), LHS.get(),
+                                  Kind, EllipsisLoc, RHS.get(),
+                                  T.getCloseLocation());
 }
 
 /// ParseExpressionList - Used for C/C++ (argument-)expression-list.
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -13291,8 +13291,9 @@
 
 /// Build an overloaded binary operator expression in the given scope.
static ExprResult BuildOverloadedBinOp(Sema &S, Scope *Sc, SourceLocation OpLoc, - BinaryOperatorKind Opc, - Expr *LHS, Expr *RHS) { + BinaryOperatorKind Opc, Expr *LHS, + Expr *RHS, + const UnresolvedSetImpl *OverloadCands) { switch (Opc) { case BO_Assign: case BO_DivAssign: @@ -13308,22 +13309,25 @@ break; } + UnresolvedSet<16> Functions; + if (OverloadCands) + Functions.append(OverloadCands->begin(), OverloadCands->end()); + // Find all of the overloaded operators visible from this // point. We perform both an operator-name lookup from the local // scope and an argument-dependent lookup based on the types of // the arguments. - UnresolvedSet<16> Functions; - OverloadedOperatorKind OverOp - = BinaryOperator::getOverloadedOperator(Opc); - if (Sc && OverOp != OO_None && OverOp != OO_Equal) - S.LookupOverloadedOperatorName(OverOp, Sc, LHS->getType(), - RHS->getType(), Functions); + OverloadedOperatorKind OverOp = BinaryOperator::getOverloadedOperator(Opc); + if (Sc) { + if (OverOp != OO_None && OverOp != OO_Equal) + S.LookupOverloadedOperatorName(OverOp, Sc, Functions); - // In C++20 onwards, we may have a second operator to look up. - if (S.getLangOpts().CPlusPlus2a) { - if (OverloadedOperatorKind ExtraOp = getRewrittenOverloadedOperator(OverOp)) - S.LookupOverloadedOperatorName(ExtraOp, Sc, LHS->getType(), - RHS->getType(), Functions); + // In C++20 onwards, we may have a second operator to look up. 
+ if (S.getLangOpts().CPlusPlus2a) { + if (OverloadedOperatorKind ExtraOp = + getRewrittenOverloadedOperator(OverOp)) + S.LookupOverloadedOperatorName(ExtraOp, Sc, Functions); + } } // Build the (potentially-overloaded, potentially-dependent) @@ -13332,8 +13336,9 @@ } ExprResult Sema::BuildBinOp(Scope *S, SourceLocation OpLoc, - BinaryOperatorKind Opc, - Expr *LHSExpr, Expr *RHSExpr) { + BinaryOperatorKind Opc, Expr *LHSExpr, + Expr *RHSExpr, + const UnresolvedSetImpl *OverloadCands) { ExprResult LHS, RHS; std::tie(LHS, RHS) = CorrectDelayedTyposInBinOp(*this, Opc, LHSExpr, RHSExpr); if (!LHS.isUsable() || !RHS.isUsable()) @@ -13367,7 +13372,8 @@ if (RHSExpr->isTypeDependent() || RHSExpr->getType()->isOverloadableType()) - return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr); + return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr, + OverloadCands); } // If we're instantiating "a.x < b" or "A::x < b" and 'x' names a function @@ -13405,7 +13411,8 @@ if (getLangOpts().CPlusPlus && (LHSExpr->isTypeDependent() || RHSExpr->isTypeDependent() || LHSExpr->getType()->isOverloadableType())) - return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr); + return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr, + OverloadCands); return CreateBuiltinBinOp(OpLoc, Opc, LHSExpr, RHSExpr); } @@ -13413,7 +13420,8 @@ // Don't resolve overloads if the other type is overloadable. if (getLangOpts().CPlusPlus && pty->getKind() == BuiltinType::Overload && LHSExpr->getType()->isOverloadableType()) - return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr); + return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr, + OverloadCands); ExprResult resolvedRHS = CheckPlaceholderExpr(RHSExpr); if (!resolvedRHS.isUsable()) return ExprError(); @@ -13424,13 +13432,15 @@ // If either expression is type-dependent, always build an // overloaded op. 
if (LHSExpr->isTypeDependent() || RHSExpr->isTypeDependent()) - return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr); + return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr, + OverloadCands); // Otherwise, build an overloaded op if either expression has an // overloadable type. if (LHSExpr->getType()->isOverloadableType() || RHSExpr->getType()->isOverloadableType()) - return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr); + return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr, + OverloadCands); } // Build a built-in binary operation. @@ -13746,8 +13756,7 @@ UnresolvedSet<16> Functions; OverloadedOperatorKind OverOp = UnaryOperator::getOverloadedOperator(Opc); if (S && OverOp != OO_None) - LookupOverloadedOperatorName(OverOp, S, Input->getType(), QualType(), - Functions); + LookupOverloadedOperatorName(OverOp, S, Functions); return CreateOverloadedUnaryOp(OpLoc, Opc, Functions, Input); } diff --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp --- a/clang/lib/Sema/SemaLookup.cpp +++ b/clang/lib/Sema/SemaLookup.cpp @@ -3037,7 +3037,6 @@ } void Sema::LookupOverloadedOperatorName(OverloadedOperatorKind Op, Scope *S, - QualType T1, QualType T2, UnresolvedSetImpl &Functions) { // C++ [over.match.oper]p3: // -- The set of non-member candidates is the result of the diff --git a/clang/lib/Sema/SemaTemplateVariadic.cpp b/clang/lib/Sema/SemaTemplateVariadic.cpp --- a/clang/lib/Sema/SemaTemplateVariadic.cpp +++ b/clang/lib/Sema/SemaTemplateVariadic.cpp @@ -1153,8 +1153,8 @@ } } -ExprResult Sema::ActOnCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS, - tok::TokenKind Operator, +ExprResult Sema::ActOnCXXFoldExpr(Scope *Sc, SourceLocation LParenLoc, + Expr *LHS, tok::TokenKind Operator, SourceLocation EllipsisLoc, Expr *RHS, SourceLocation RParenLoc) { // LHS and RHS must be cast-expressions. 
We allow an arbitrary expression @@ -1195,18 +1195,34 @@ } BinaryOperatorKind Opc = ConvertTokenKindToBinaryOpcode(Operator); + + UnresolvedSet<8> Functions; + OverloadedOperatorKind OverOp = BinaryOperator::getOverloadedOperator(Opc); + if (Sc) { + if (OverOp != OO_None && OverOp != OO_Equal) + LookupOverloadedOperatorName(OverOp, Sc, Functions); + + // In C++20 onwards, we may have a second operator to look up. + if (getLangOpts().CPlusPlus2a) { + if (OverloadedOperatorKind ExtraOp = + getRewrittenOverloadedOperator(OverOp)) + LookupOverloadedOperatorName(ExtraOp, Sc, Functions); + } + } + return BuildCXXFoldExpr(LParenLoc, LHS, Opc, EllipsisLoc, RHS, RParenLoc, - None); + None, + llvm::make_range(Functions.begin(), Functions.end())); } -ExprResult Sema::BuildCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS, - BinaryOperatorKind Operator, - SourceLocation EllipsisLoc, Expr *RHS, - SourceLocation RParenLoc, - Optional NumExpansions) { - return new (Context) CXXFoldExpr(Context.DependentTy, LParenLoc, LHS, - Operator, EllipsisLoc, RHS, RParenLoc, - NumExpansions); +ExprResult Sema::BuildCXXFoldExpr( + SourceLocation LParenLoc, Expr *LHS, BinaryOperatorKind Operator, + SourceLocation EllipsisLoc, Expr *RHS, SourceLocation RParenLoc, + Optional NumExpansions, + llvm::iterator_range OverloadCands) { + return CXXFoldExpr::Create(Context, Context.DependentTy, LParenLoc, LHS, + Operator, EllipsisLoc, RHS, RParenLoc, + NumExpansions, OverloadCands); } ExprResult Sema::BuildEmptyCXXFoldExpr(SourceLocation EllipsisLoc, diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -2347,10 +2347,12 @@ /// /// By default, performs semantic analysis to build the new expression. /// Subclasses may override this routine to provide different behavior. 
- ExprResult RebuildBinaryOperator(SourceLocation OpLoc, - BinaryOperatorKind Opc, - Expr *LHS, Expr *RHS) { - return getSema().BuildBinOp(/*Scope=*/nullptr, OpLoc, Opc, LHS, RHS); + ExprResult + RebuildBinaryOperator(SourceLocation OpLoc, BinaryOperatorKind Opc, Expr *LHS, + Expr *RHS, + const UnresolvedSetImpl *OverloadCands = nullptr) { + return getSema().BuildBinOp(/*Scope=*/nullptr, OpLoc, Opc, LHS, RHS, + OverloadCands); } /// Build a new rewritten operator expression. @@ -3317,13 +3319,14 @@ /// /// By default, performs semantic analysis in order to build a new fold /// expression. - ExprResult RebuildCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS, - BinaryOperatorKind Operator, - SourceLocation EllipsisLoc, Expr *RHS, - SourceLocation RParenLoc, - Optional NumExpansions) { + ExprResult RebuildCXXFoldExpr( + SourceLocation LParenLoc, Expr *LHS, BinaryOperatorKind Operator, + SourceLocation EllipsisLoc, Expr *RHS, SourceLocation RParenLoc, + Optional NumExpansions, + llvm::iterator_range OverloadCands) { return getSema().BuildCXXFoldExpr(LParenLoc, LHS, Operator, EllipsisLoc, - RHS, RParenLoc, NumExpansions); + RHS, RParenLoc, NumExpansions, + OverloadCands); } /// Build an empty C++1z fold-expression with the given operator. 
@@ -12191,7 +12194,7 @@ return getDerived().RebuildCXXFoldExpr( E->getBeginLoc(), LHS.get(), E->getOperator(), E->getEllipsisLoc(), - RHS.get(), E->getEndLoc(), NumExpansions); + RHS.get(), E->getEndLoc(), NumExpansions, E->overloadCands()); } // The transform has determined that we should perform an elementwise @@ -12212,11 +12215,13 @@ Result = getDerived().RebuildCXXFoldExpr( E->getBeginLoc(), Out.get(), E->getOperator(), E->getEllipsisLoc(), - Result.get(), E->getEndLoc(), OrigNumExpansions); + Result.get(), E->getEndLoc(), OrigNumExpansions, E->overloadCands()); if (Result.isInvalid()) return true; } + UnresolvedSet<8> OverloadCands; + OverloadCands.append(E->overloadCands().begin(), E->overloadCands().end()); for (unsigned I = 0; I != *NumExpansions; ++I) { Sema::ArgumentPackSubstitutionIndexRAII SubstIndex( getSema(), LeftFold ? I : *NumExpansions - I - 1); @@ -12230,13 +12235,13 @@ E->getBeginLoc(), LeftFold ? Result.get() : Out.get(), E->getOperator(), E->getEllipsisLoc(), LeftFold ? Out.get() : Result.get(), E->getEndLoc(), - OrigNumExpansions); + OrigNumExpansions, E->overloadCands()); } else if (Result.isUsable()) { // We've got down to a single element; build a binary operator. Result = getDerived().RebuildBinaryOperator( E->getEllipsisLoc(), E->getOperator(), LeftFold ? Result.get() : Out.get(), - LeftFold ? Out.get() : Result.get()); + LeftFold ? 
Out.get() : Result.get(), &OverloadCands); } else Result = Out; @@ -12255,7 +12260,7 @@ Result = getDerived().RebuildCXXFoldExpr( E->getBeginLoc(), Result.get(), E->getOperator(), E->getEllipsisLoc(), - Out.get(), E->getEndLoc(), OrigNumExpansions); + Out.get(), E->getEndLoc(), OrigNumExpansions, E->overloadCands()); if (Result.isInvalid()) return true; } diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -1908,6 +1908,7 @@ void ASTStmtReader::VisitCXXFoldExpr(CXXFoldExpr *E) { VisitExpr(E); + unsigned NumOverloadCands = Record.readInt(); E->LParenLoc = ReadSourceLocation(); E->EllipsisLoc = ReadSourceLocation(); E->RParenLoc = ReadSourceLocation(); @@ -1915,6 +1916,13 @@ E->SubExprs[0] = Record.readSubExpr(); E->SubExprs[1] = Record.readSubExpr(); E->Opcode = (BinaryOperatorKind)Record.readInt(); + + DeclAccessPair *OverloadCands = E->getTrailingObjects(); + for (unsigned I = 0; I != NumOverloadCands; ++I) { + auto *D = ReadDeclAs(); + auto AS = (AccessSpecifier)Record.readInt(); + OverloadCands[I].set(D, AS); + } } void ASTStmtReader::VisitOpaqueValueExpr(OpaqueValueExpr *E) { @@ -3479,7 +3487,8 @@ break; case EXPR_CXX_FOLD: - S = new (Context) CXXFoldExpr(Empty); + S = CXXFoldExpr::CreateEmpty(Context, + Record[ASTStmtReader::NumExprFields]); break; case EXPR_OPAQUE_VALUE: diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -1843,6 +1843,7 @@ void ASTStmtWriter::VisitCXXFoldExpr(CXXFoldExpr *E) { VisitExpr(E); + Record.push_back(E->getNumOverloadCands()); Record.AddSourceLocation(E->LParenLoc); Record.AddSourceLocation(E->EllipsisLoc); Record.AddSourceLocation(E->RParenLoc); @@ -1850,6 +1851,12 @@ Record.AddStmt(E->SubExprs[0]); Record.AddStmt(E->SubExprs[1]); 
Record.push_back(E->Opcode); + for (UnresolvedSetIterator I = E->overloadCands().begin(), + End = E->overloadCands().end(); + I != End; ++I) { + Record.AddDeclRef(I.getDecl()); + Record.push_back(I.getAccess()); + } Code = serialization::EXPR_CXX_FOLD; } diff --git a/clang/test/SemaTemplate/cxx1z-fold-expressions.cpp b/clang/test/SemaTemplate/cxx1z-fold-expressions.cpp --- a/clang/test/SemaTemplate/cxx1z-fold-expressions.cpp +++ b/clang/test/SemaTemplate/cxx1z-fold-expressions.cpp @@ -79,6 +79,36 @@ static_assert(&apply(a, &A::b, &A::B::c, &A::B::C::d, &A::B::C::D::e) == &a.b.c.d.e); #if __cplusplus > 201703L + +namespace N { + + struct Bool { + constexpr Bool(const bool& b) : b(b) {} + bool b; + }; + +} + +constexpr bool operator==(const N::Bool& b1, const N::Bool& b2) { return b1.b == b2.b; } +constexpr int operator<=>(const N::Bool& b1, const N::Bool& b2) { return b1.b - b2.b; } + +template constexpr auto fold_eq(T ...t) { return (t == ...); } +template constexpr auto fold_neq(T ...t) { return (t != ...); } +template constexpr auto fold_le(T ...t) { return (t < ...); } +template constexpr auto fold_leq(T ...t) { return (t <= ...); } +template constexpr auto fold_ge(T ...t) { return (t > ...); } +template constexpr auto fold_geq(T ...t) { return (t >= ...); } + +static_assert(fold_eq(N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true})); +static_assert(!fold_eq(N::Bool{true}, N::Bool{true}, N::Bool{false}, N::Bool{true}, N::Bool{true})); +static_assert(fold_neq(N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true})); +static_assert(!fold_neq(N::Bool{true}, N::Bool{true}, N::Bool{false}, N::Bool{true}, N::Bool{true})); + +static_assert(!fold_le(N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true})); +static_assert(fold_leq(N::Bool{false}, N::Bool{true}, N::Bool{true}, N::Bool{true}, N::Bool{true})); +static_assert(fold_ge(N::Bool{true}, N::Bool{false}, N::Bool{true}, N::Bool{false}, N::Bool{false})); 
+static_assert(!fold_ge(N::Bool{false}, N::Bool{false}, N::Bool{true}, N::Bool{false}, N::Bool{false})); + // The <=> operator is unique among binary operators in not being a // fold-operator. // FIXME: This diagnostic is not great. @@ -102,3 +132,49 @@ Sum<1>::type<1, 2> x; // expected-note {{instantiation of}} } + +namespace N { + + struct A { int i; }; + struct B { int i; }; + + constexpr B operator+(const B& a, const B& b) { return { a.i + b.i }; } + +} + +struct C { int i; }; + +constexpr C operator+(const C& a, const C& b) { return { a.i + b.i }; } +constexpr N::A operator+(const N::A& a, const N::A& b) { return { a.i + b.i }; } + +template constexpr auto custom_fold(T1 t1, T2 ...t2) { + return (t2 + ...) + (... + t2) + (t2 + ... + t1) + (t1 + ... + t2); +} + +static_assert(custom_fold(N::A{1}, N::A{2}, N::A{3}, N::A{4}, N::A{5}).i == 58); +static_assert(custom_fold(N::B{1}, N::B{2}, N::B{3}, N::B{4}, N::B{5}).i == 58); +static_assert(custom_fold(C{1}, C{2}, C{3}, C{4}, C{5}).i == 58); + +template constexpr auto func_fold( + decltype((T{ I2 } + ...) + (... + T{ I2 }) + (T{ I2 } + ... + T{ I1 }) + (T{ I1 } + ... + T{ I2 })) t) { + return t.i; +} + +static_assert(func_fold(N::A{ 42 }) == 42); +static_assert(func_fold(N::B{ 42 }) == 42); +static_assert(func_fold(C{ 42 }) == 42); + +struct D { int i; }; + +namespace N { + + constexpr D operator+(const D& a, const D& b) { return { a.i + b.i }; } + +} + +template constexpr auto custom_fold_using(T1 t1, T2 ...t2) { + using N::operator+; + return (t2 + ...) + (... + t2) + (t2 + ... + t1) + (t1 + ... + t2); +} + +static_assert(custom_fold_using(D{1}, D{2}, D{3}, D{4}, D{5}).i == 58);