Index: clang/include/clang/AST/ExprCXX.h
===================================================================
--- clang/include/clang/AST/ExprCXX.h
+++ clang/include/clang/AST/ExprCXX.h
@@ -4438,7 +4438,9 @@
 ///    ( expr op ... )
 ///    ( ... op expr )
 ///    ( expr op ... op expr )
-class CXXFoldExpr : public Expr {
+class CXXFoldExpr final
+    : public Expr,
+      private llvm::TrailingObjects<CXXFoldExpr, DeclAccessPair> {
   friend class ASTStmtReader;
   friend class ASTStmtWriter;
 
@@ -4451,24 +4453,76 @@
   Stmt *SubExprs[2];
   BinaryOperatorKind Opcode;
 
-public:
+  /// The number of DeclAccessPair trailing objects, i.e. the operator
+  /// candidates found by unqualified lookup when the fold was parsed.
+  unsigned NumOverloadCands;
+
   CXXFoldExpr(QualType T, SourceLocation LParenLoc, Expr *LHS,
               BinaryOperatorKind Opcode, SourceLocation EllipsisLoc, Expr *RHS,
-              SourceLocation RParenLoc, Optional<unsigned> NumExpansions)
+              SourceLocation RParenLoc, Optional<unsigned> NumExpansions,
+              UnresolvedSetIterator OverloadCandsBegin,
+              UnresolvedSetIterator OverloadCandsEnd)
       : Expr(CXXFoldExprClass, T, VK_RValue, OK_Ordinary,
              /*Dependent*/ true, true, true,
              /*ContainsUnexpandedParameterPack*/ false),
         LParenLoc(LParenLoc), EllipsisLoc(EllipsisLoc), RParenLoc(RParenLoc),
-        NumExpansions(NumExpansions ? *NumExpansions + 1 : 0), Opcode(Opcode) {
+        NumExpansions(NumExpansions ? *NumExpansions + 1 : 0), Opcode(Opcode),
+        NumOverloadCands(OverloadCandsEnd - OverloadCandsBegin) {
     SubExprs[0] = LHS;
     SubExprs[1] = RHS;
+    DeclAccessPair *Results = getTrailingObjects<DeclAccessPair>();
+    memcpy(Results, OverloadCandsBegin.I,
+           NumOverloadCands * sizeof(DeclAccessPair));
   }
 
-  CXXFoldExpr(EmptyShell Empty) : Expr(CXXFoldExprClass, Empty) {}
+  CXXFoldExpr(EmptyShell Empty, unsigned NumOverloadCands)
+      : Expr(CXXFoldExprClass, Empty), NumOverloadCands(NumOverloadCands) {}
+
+public:
+  /// Creates a fold expression that keeps a copy of the given operator
+  /// overload candidates as trailing objects.
+  static CXXFoldExpr *
+  Create(const ASTContext &Ctx, QualType T, SourceLocation LParenLoc, Expr *LHS,
+         BinaryOperatorKind Opcode, SourceLocation EllipsisLoc, Expr *RHS,
+         SourceLocation RParenLoc, Optional<unsigned> NumExpansions,
+         UnresolvedSetIterator OverloadCandsBegin,
+         UnresolvedSetIterator OverloadCandsEnd) {
+    size_t Size = CXXFoldExpr::totalSizeToAlloc<DeclAccessPair>(
+        OverloadCandsEnd - OverloadCandsBegin);
+    void *Mem = Ctx.Allocate(Size, alignof(CXXFoldExpr));
+    return new (Mem)
+        CXXFoldExpr(T, LParenLoc, LHS, Opcode, EllipsisLoc, RHS, RParenLoc,
+                    NumExpansions, OverloadCandsBegin, OverloadCandsEnd);
+  }
+
+  /// Creates an empty fold expression with room for \p NumOverloadCands
+  /// candidates, whose contents are filled in later (e.g. by the AST reader).
+  static CXXFoldExpr *CreateEmpty(const ASTContext &Ctx,
+                                  unsigned NumOverloadCands) {
+    size_t Size =
+        CXXFoldExpr::totalSizeToAlloc<DeclAccessPair>(NumOverloadCands);
+    void *Mem = Ctx.Allocate(Size, alignof(CXXFoldExpr));
+    return new (Mem) CXXFoldExpr(EmptyShell(), NumOverloadCands);
+  }
 
   Expr *getLHS() const { return static_cast<Expr *>(SubExprs[0]); }
   Expr *getRHS() const { return static_cast<Expr *>(SubExprs[1]); }
 
+  unsigned getNumOverloadCands() const { return NumOverloadCands; }
+
+  size_t numTrailingObjects(OverloadToken<DeclAccessPair>) const {
+    return getNumOverloadCands();
+  }
+
+  UnresolvedSetIterator overloadCandsBegin() const {
+    return UnresolvedSetIterator(getTrailingObjects<DeclAccessPair>());
+  }
+
+  UnresolvedSetIterator overloadCandsEnd() const {
+    return UnresolvedSetIterator(getTrailingObjects<DeclAccessPair>() +
+                                 NumOverloadCands);
+  }
+
   /// Does this produce a right-associated sequence of operators?
   bool isRightFold() const {
     return getLHS() && getLHS()->containsUnexpandedParameterPack();
Index: clang/include/clang/AST/UnresolvedSet.h
===================================================================
--- clang/include/clang/AST/UnresolvedSet.h
+++ clang/include/clang/AST/UnresolvedSet.h
@@ -33,6 +33,7 @@
                                   std::random_access_iterator_tag, NamedDecl *,
                                   std::ptrdiff_t, NamedDecl *, NamedDecl *> {
   friend class ASTUnresolvedSet;
+  friend class CXXFoldExpr;
   friend class OverloadExpr;
   friend class UnresolvedSetImpl;
 
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -3391,7 +3391,6 @@
   bool LookupInSuper(LookupResult &R, CXXRecordDecl *Class);
 
   void LookupOverloadedOperatorName(OverloadedOperatorKind Op, Scope *S,
-                                    QualType T1, QualType T2,
                                     UnresolvedSetImpl &Functions);
 
   LabelDecl *LookupOrCreateLabel(IdentifierInfo *II, SourceLocation IdentLoc,
@@ -4654,8 +4653,9 @@
 public:
   ExprResult ActOnBinOp(Scope *S, SourceLocation TokLoc,
                         tok::TokenKind Kind, Expr *LHSExpr, Expr *RHSExpr);
-  ExprResult BuildBinOp(Scope *S, SourceLocation OpLoc,
-                        BinaryOperatorKind Opc, Expr *LHSExpr, Expr *RHSExpr);
+  ExprResult BuildBinOp(Scope *S, SourceLocation OpLoc, BinaryOperatorKind Opc,
+                        Expr *LHSExpr, Expr *RHSExpr,
+                        const UnresolvedSetImpl *OverloadCands = nullptr);
   ExprResult CreateBuiltinBinOp(SourceLocation OpLoc, BinaryOperatorKind Opc,
                                 Expr *LHSExpr, Expr *RHSExpr);
 
@@ -5314,7 +5314,7 @@
                                           SourceLocation RParenLoc);
 
   /// Handle a C++1z fold-expression: ( expr op ... op expr ).
-  ExprResult ActOnCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
+  ExprResult ActOnCXXFoldExpr(Scope *Sc, SourceLocation LParenLoc, Expr *LHS,
                               tok::TokenKind Operator,
                               SourceLocation EllipsisLoc, Expr *RHS,
                               SourceLocation RParenLoc);
@@ -5322,7 +5322,9 @@
                               BinaryOperatorKind Operator,
                               SourceLocation EllipsisLoc, Expr *RHS,
                               SourceLocation RParenLoc,
-                              Optional<unsigned> NumExpansions);
+                              Optional<unsigned> NumExpansions,
+                              UnresolvedSetIterator OverloadCandsBegin,
+                              UnresolvedSetIterator OverloadCandsEnd);
   ExprResult BuildEmptyCXXFoldExpr(SourceLocation EllipsisLoc,
                                    BinaryOperatorKind Operator);
 
Index: clang/lib/Parse/ParseExpr.cpp
===================================================================
--- clang/lib/Parse/ParseExpr.cpp
+++ clang/lib/Parse/ParseExpr.cpp
@@ -2862,8 +2862,9 @@
                         : diag::ext_fold_expression);
 
   T.consumeClose();
-  return Actions.ActOnCXXFoldExpr(T.getOpenLocation(), LHS.get(), Kind,
-                                  EllipsisLoc, RHS.get(), T.getCloseLocation());
+  return Actions.ActOnCXXFoldExpr(getCurScope(), T.getOpenLocation(), LHS.get(),
+                                  Kind, EllipsisLoc, RHS.get(),
+                                  T.getCloseLocation());
 }
 
 /// ParseExpressionList - Used for C/C++ (argument-)expression-list.
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -13161,8 +13161,9 @@
 
 /// Build an overloaded binary operator expression in the given scope.
 static ExprResult BuildOverloadedBinOp(Sema &S, Scope *Sc, SourceLocation OpLoc,
-                                       BinaryOperatorKind Opc,
-                                       Expr *LHS, Expr *RHS) {
+                                       BinaryOperatorKind Opc, Expr *LHS,
+                                       Expr *RHS,
+                                       const UnresolvedSetImpl *OverloadCands) {
   switch (Opc) {
   case BO_Assign:
   case BO_DivAssign:
@@ -13186,8 +13187,11 @@
   OverloadedOperatorKind OverOp
     = BinaryOperator::getOverloadedOperator(Opc);
   if (Sc && OverOp != OO_None && OverOp != OO_Equal)
-    S.LookupOverloadedOperatorName(OverOp, Sc, LHS->getType(),
-                                   RHS->getType(), Functions);
+    S.LookupOverloadedOperatorName(OverOp, Sc, Functions);
+
+  // Add candidates found earlier by the caller (e.g. recorded by a fold).
+  if (OverloadCands)
+    Functions.append(OverloadCands->begin(), OverloadCands->end());
 
   // Build the (potentially-overloaded, potentially-dependent)
   // binary operation.
@@ -13195,8 +13199,9 @@
 }
 
 ExprResult Sema::BuildBinOp(Scope *S, SourceLocation OpLoc,
-                            BinaryOperatorKind Opc,
-                            Expr *LHSExpr, Expr *RHSExpr) {
+                            BinaryOperatorKind Opc, Expr *LHSExpr,
+                            Expr *RHSExpr,
+                            const UnresolvedSetImpl *OverloadCands) {
   ExprResult LHS, RHS;
   std::tie(LHS, RHS) = CorrectDelayedTyposInBinOp(*this, Opc, LHSExpr, RHSExpr);
   if (!LHS.isUsable() || !RHS.isUsable())
@@ -13230,7 +13235,8 @@
 
       if (RHSExpr->isTypeDependent() ||
           RHSExpr->getType()->isOverloadableType())
-        return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr);
+        return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr,
+                                    OverloadCands);
     }
 
     // If we're instantiating "a.x < b" or "A::x < b" and 'x' names a function
@@ -13268,7 +13274,8 @@
       if (getLangOpts().CPlusPlus &&
           (LHSExpr->isTypeDependent() || RHSExpr->isTypeDependent() ||
            LHSExpr->getType()->isOverloadableType()))
-        return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr);
+        return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr,
+                                    OverloadCands);
 
       return CreateBuiltinBinOp(OpLoc, Opc, LHSExpr, RHSExpr);
     }
@@ -13276,7 +13283,8 @@
     // Don't resolve overloads if the other type is overloadable.
     if (getLangOpts().CPlusPlus && pty->getKind() == BuiltinType::Overload &&
         LHSExpr->getType()->isOverloadableType())
-      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr);
+      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr,
+                                  OverloadCands);
 
     ExprResult resolvedRHS = CheckPlaceholderExpr(RHSExpr);
     if (!resolvedRHS.isUsable()) return ExprError();
@@ -13287,13 +13295,15 @@
     // If either expression is type-dependent, always build an
     // overloaded op.
     if (LHSExpr->isTypeDependent() || RHSExpr->isTypeDependent())
-      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr);
+      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr,
+                                  OverloadCands);
 
     // Otherwise, build an overloaded op if either expression has an
     // overloadable type.
     if (LHSExpr->getType()->isOverloadableType() ||
         RHSExpr->getType()->isOverloadableType())
-      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr);
+      return BuildOverloadedBinOp(*this, S, OpLoc, Opc, LHSExpr, RHSExpr,
+                                  OverloadCands);
   }
 
   // Build a built-in binary operation.
@@ -13610,8 +13620,7 @@
     UnresolvedSet<16> Functions;
     OverloadedOperatorKind OverOp = UnaryOperator::getOverloadedOperator(Opc);
     if (S && OverOp != OO_None)
-      LookupOverloadedOperatorName(OverOp, S, Input->getType(), QualType(),
-                                   Functions);
+      LookupOverloadedOperatorName(OverOp, S, Functions);
 
     return CreateOverloadedUnaryOp(OpLoc, Opc, Functions, Input);
   }
Index: clang/lib/Sema/SemaLookup.cpp
===================================================================
--- clang/lib/Sema/SemaLookup.cpp
+++ clang/lib/Sema/SemaLookup.cpp
@@ -3030,7 +3030,6 @@
 }
 
 void Sema::LookupOverloadedOperatorName(OverloadedOperatorKind Op, Scope *S,
-                                        QualType T1, QualType T2,
                                         UnresolvedSetImpl &Functions) {
   // C++ [over.match.oper]p3:
   //   -- The set of non-member candidates is the result of the
Index: clang/lib/Sema/SemaTemplateVariadic.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateVariadic.cpp
+++ clang/lib/Sema/SemaTemplateVariadic.cpp
@@ -1153,8 +1153,8 @@
   }
 }
 
-ExprResult Sema::ActOnCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
-                                  tok::TokenKind Operator,
+ExprResult Sema::ActOnCXXFoldExpr(Scope *Sc, SourceLocation LParenLoc,
+                                  Expr *LHS, tok::TokenKind Operator,
                                   SourceLocation EllipsisLoc, Expr *RHS,
                                   SourceLocation RParenLoc) {
   // LHS and RHS must be cast-expressions. We allow an arbitrary expression
@@ -1195,18 +1195,28 @@
   }
 
   BinaryOperatorKind Opc = ConvertTokenKindToBinaryOpcode(Operator);
+
+  // Perform unqualified lookup for the operator in the fold's scope now; the
+  // candidates are stored in the CXXFoldExpr and used at instantiation time.
+  UnresolvedSet<8> Functions;
+  OverloadedOperatorKind OverOp = BinaryOperator::getOverloadedOperator(Opc);
+  if (Sc && OverOp != OO_None && OverOp != OO_Equal)
+    LookupOverloadedOperatorName(OverOp, Sc, Functions);
+
   return BuildCXXFoldExpr(LParenLoc, LHS, Opc, EllipsisLoc, RHS, RParenLoc,
-                          None);
+                          None, Functions.begin(), Functions.end());
 }
 
 ExprResult Sema::BuildCXXFoldExpr(SourceLocation LParenLoc, Expr *LHS,
                                   BinaryOperatorKind Operator,
                                   SourceLocation EllipsisLoc, Expr *RHS,
                                   SourceLocation RParenLoc,
-                                  Optional<unsigned> NumExpansions) {
-  return new (Context) CXXFoldExpr(Context.DependentTy, LParenLoc, LHS,
-                                   Operator, EllipsisLoc, RHS, RParenLoc,
-                                   NumExpansions);
+                                  Optional<unsigned> NumExpansions,
+                                  UnresolvedSetIterator OverloadCandsBegin,
+                                  UnresolvedSetIterator OverloadCandsEnd) {
+  return CXXFoldExpr::Create(
+      Context, Context.DependentTy, LParenLoc, LHS, Operator, EllipsisLoc, RHS,
+      RParenLoc, NumExpansions, OverloadCandsBegin, OverloadCandsEnd);
 }
 
 ExprResult Sema::BuildEmptyCXXFoldExpr(SourceLocation EllipsisLoc,
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -2346,10 +2346,12 @@
   ///
   /// By default, performs semantic analysis to build the new expression.
   /// Subclasses may override this routine to provide different behavior.
-  ExprResult RebuildBinaryOperator(SourceLocation OpLoc,
-                                   BinaryOperatorKind Opc,
-                                   Expr *LHS, Expr *RHS) {
-    return getSema().BuildBinOp(/*Scope=*/nullptr, OpLoc, Opc, LHS, RHS);
+  ExprResult
+  RebuildBinaryOperator(SourceLocation OpLoc, BinaryOperatorKind Opc, Expr *LHS,
+                        Expr *RHS,
+                        const UnresolvedSetImpl *OverloadCands = nullptr) {
+    return getSema().BuildBinOp(/*Scope=*/nullptr, OpLoc, Opc, LHS, RHS,
+                                OverloadCands);
   }
 
   /// Build a new conditional operator expression.
@@ -3290,9 +3292,12 @@
                                 BinaryOperatorKind Operator,
                                 SourceLocation EllipsisLoc, Expr *RHS,
                                 SourceLocation RParenLoc,
-                                Optional<unsigned> NumExpansions) {
+                                Optional<unsigned> NumExpansions,
+                                UnresolvedSetIterator OverloadCandsBegin,
+                                UnresolvedSetIterator OverloadCandsEnd) {
     return getSema().BuildCXXFoldExpr(LParenLoc, LHS, Operator, EllipsisLoc,
-                                      RHS, RParenLoc, NumExpansions);
+                                      RHS, RParenLoc, NumExpansions,
+                                      OverloadCandsBegin, OverloadCandsEnd);
   }
 
   /// Build an empty C++1z fold-expression with the given operator.
@@ -12058,7 +12063,8 @@
 
     return getDerived().RebuildCXXFoldExpr(
         E->getBeginLoc(), LHS.get(), E->getOperator(), E->getEllipsisLoc(),
-        RHS.get(), E->getEndLoc(), NumExpansions);
+        RHS.get(), E->getEndLoc(), NumExpansions, E->overloadCandsBegin(),
+        E->overloadCandsEnd());
   }
 
   // The transform has determined that we should perform an elementwise
@@ -12079,11 +12085,17 @@
 
     Result = getDerived().RebuildCXXFoldExpr(
         E->getBeginLoc(), Out.get(), E->getOperator(), E->getEllipsisLoc(),
-        Result.get(), E->getEndLoc(), OrigNumExpansions);
+        Result.get(), E->getEndLoc(), OrigNumExpansions,
+        E->overloadCandsBegin(), E->overloadCandsEnd());
     if (Result.isInvalid())
       return true;
   }
 
+  // Operator candidates recorded when the fold was parsed; they are needed
+  // when an expansion collapses to a plain binary operator below.
+  UnresolvedSet<8> OverloadCands;
+  OverloadCands.append(E->overloadCandsBegin(), E->overloadCandsEnd());
+
   for (unsigned I = 0; I != *NumExpansions; ++I) {
     Sema::ArgumentPackSubstitutionIndexRAII SubstIndex(
         getSema(), LeftFold ? I : *NumExpansions - I - 1);
@@ -12097,13 +12109,13 @@
           E->getBeginLoc(), LeftFold ? Result.get() : Out.get(),
           E->getOperator(), E->getEllipsisLoc(),
           LeftFold ? Out.get() : Result.get(), E->getEndLoc(),
-          OrigNumExpansions);
+          OrigNumExpansions, E->overloadCandsBegin(), E->overloadCandsEnd());
     } else if (Result.isUsable()) {
       // We've got down to a single element; build a binary operator.
       Result = getDerived().RebuildBinaryOperator(
           E->getEllipsisLoc(), E->getOperator(),
           LeftFold ? Result.get() : Out.get(),
-          LeftFold ? Out.get() : Result.get());
+          LeftFold ? Out.get() : Result.get(), &OverloadCands);
     } else
       Result = Out;
 
@@ -12122,7 +12134,8 @@
 
     Result = getDerived().RebuildCXXFoldExpr(
         E->getBeginLoc(), Result.get(), E->getOperator(), E->getEllipsisLoc(),
-        Out.get(), E->getEndLoc(), OrigNumExpansions);
+        Out.get(), E->getEndLoc(), OrigNumExpansions, E->overloadCandsBegin(),
+        E->overloadCandsEnd());
     if (Result.isInvalid())
       return true;
   }
Index: clang/lib/Serialization/ASTReaderStmt.cpp
===================================================================
--- clang/lib/Serialization/ASTReaderStmt.cpp
+++ clang/lib/Serialization/ASTReaderStmt.cpp
@@ -1883,6 +1883,9 @@
 
 void ASTStmtReader::VisitCXXFoldExpr(CXXFoldExpr *E) {
   VisitExpr(E);
+  unsigned NumOverloadCands = Record.readInt();
+  assert(NumOverloadCands == E->getNumOverloadCands() &&
+         "Mismatched number of overload candidates!");
   E->LParenLoc = ReadSourceLocation();
   E->EllipsisLoc = ReadSourceLocation();
   E->RParenLoc = ReadSourceLocation();
@@ -1890,6 +1893,13 @@
   E->SubExprs[0] = Record.readSubExpr();
   E->SubExprs[1] = Record.readSubExpr();
   E->Opcode = (BinaryOperatorKind)Record.readInt();
+
+  DeclAccessPair *OverloadCands = E->getTrailingObjects<DeclAccessPair>();
+  for (unsigned I = 0; I != NumOverloadCands; ++I) {
+    auto *D = ReadDeclAs<NamedDecl>();
+    auto AS = (AccessSpecifier)Record.readInt();
+    OverloadCands[I].set(D, AS);
+  }
 }
 
 void ASTStmtReader::VisitOpaqueValueExpr(OpaqueValueExpr *E) {
@@ -3411,7 +3421,8 @@
       break;
 
     case EXPR_CXX_FOLD:
-      S = new (Context) CXXFoldExpr(Empty);
+      S = CXXFoldExpr::CreateEmpty(Context,
+                                   Record[ASTStmtReader::NumExprFields]);
       break;
 
     case EXPR_OPAQUE_VALUE:
Index: clang/lib/Serialization/ASTWriterStmt.cpp
===================================================================
--- clang/lib/Serialization/ASTWriterStmt.cpp
+++ clang/lib/Serialization/ASTWriterStmt.cpp
@@ -1817,6 +1817,7 @@
 
 void ASTStmtWriter::VisitCXXFoldExpr(CXXFoldExpr *E) {
   VisitExpr(E);
+  Record.push_back(E->getNumOverloadCands());
   Record.AddSourceLocation(E->LParenLoc);
   Record.AddSourceLocation(E->EllipsisLoc);
   Record.AddSourceLocation(E->RParenLoc);
@@ -1824,6 +1825,12 @@
   Record.AddStmt(E->SubExprs[0]);
   Record.AddStmt(E->SubExprs[1]);
   Record.push_back(E->Opcode);
+  for (UnresolvedSetIterator I = E->overloadCandsBegin(),
+                             End = E->overloadCandsEnd();
+       I != End; ++I) {
+    Record.AddDeclRef(I.getDecl());
+    Record.push_back(I.getAccess());
+  }
   Code = serialization::EXPR_CXX_FOLD;
 }
 
Index: clang/test/SemaTemplate/cxx1z-fold-expressions.cpp
===================================================================
--- clang/test/SemaTemplate/cxx1z-fold-expressions.cpp
+++ clang/test/SemaTemplate/cxx1z-fold-expressions.cpp
@@ -102,3 +102,49 @@
 
   Sum<1>::type<1, 2> x; // expected-note {{instantiation of}}
 }
+
+namespace N {
+
+  struct A { int i; };
+  struct B { int i; };
+
+  constexpr B operator+(const B& a, const B& b) { return { a.i + b.i }; }
+
+}
+
+struct C { int i; };
+
+constexpr C operator+(const C& a, const C& b) { return { a.i + b.i }; }
+constexpr N::A operator+(const N::A& a, const N::A& b) { return { a.i + b.i }; }
+
+template <typename T1, typename ...T2> constexpr auto custom_fold(T1 t1, T2 ...t2) {
+  return (t2 + ...) + (... + t2) + (t2 + ... + t1) + (t1 + ... + t2);
+}
+
+static_assert(custom_fold(N::A{1}, N::A{2}, N::A{3}, N::A{4}, N::A{5}).i == 58);
+static_assert(custom_fold(N::B{1}, N::B{2}, N::B{3}, N::B{4}, N::B{5}).i == 58);
+static_assert(custom_fold(C{1}, C{2}, C{3}, C{4}, C{5}).i == 58);
+
+template <typename T, int I1, int ...I2> constexpr auto func_fold(
+    decltype((T{ I2 } + ...) + (... + T{ I2 }) + (T{ I2 } + ... + T{ I1 }) + (T{ I1 } + ... + T{ I2 })) t) {
+  return t.i;
+}
+
+static_assert(func_fold<N::A, 1, 2, 3, 4, 5>(N::A{ 42 }) == 42);
+static_assert(func_fold<N::B, 1, 2, 3, 4, 5>(N::B{ 42 }) == 42);
+static_assert(func_fold<C, 1, 2, 3, 4, 5>(C{ 42 }) == 42);
+
+struct D { int i; };
+
+namespace N {
+
+  constexpr D operator+(const D& a, const D& b) { return { a.i + b.i }; }
+
+}
+
+template <typename T1, typename ...T2> constexpr auto custom_fold_using(T1 t1, T2 ...t2) {
+  using N::operator+;
+  return (t2 + ...) + (... + t2) + (t2 + ... + t1) + (t1 + ... + t2);
+}
+
+static_assert(custom_fold_using(D{1}, D{2}, D{3}, D{4}, D{5}).i == 58);