Summary of this patch (free text before the first "Index:" header):

  * Sema::BuildAttributedStmt no longer special-cases attr::MustTail; the
    checkAndRewriteMustTailAttr() / setFunctionHasMustTail() calls move into
    handleMustTailAttr() in SemaStmtAttr.cpp.
  * TreeTransform::TransformAttr and every generated TransformXXXAttr hook gain
    a `Stmt *TransformedStmt` parameter (null when transforming type
    attributes), and hooks are now generated for all attributes via ATTR(X),
    not only for PRAGMA_SPELLING_ATTR(X).
  * TreeTransform::TransformAttributedStmt transforms the sub-statement before
    its attributes so the transformed statement can be handed to the hooks.
  * TemplateInstantiator adds TransformMustTailAttr, which re-runs the MustTail
    check on the transformed statement.

Index: clang/lib/Sema/SemaStmt.cpp
===================================================================
--- clang/lib/Sema/SemaStmt.cpp
+++ clang/lib/Sema/SemaStmt.cpp
@@ -559,17 +559,6 @@
 StmtResult Sema::BuildAttributedStmt(SourceLocation AttrsLoc,
                                      ArrayRef<const Attr *> Attrs,
                                      Stmt *SubStmt) {
-  // FIXME: this code should move when a planned refactoring around statement
-  // attributes lands.
-  for (const auto *A : Attrs) {
-    if (A->getKind() == attr::MustTail) {
-      if (!checkAndRewriteMustTailAttr(SubStmt, *A)) {
-        return SubStmt;
-      }
-      setFunctionHasMustTail();
-    }
-  }
-
   return AttributedStmt::Create(Context, AttrsLoc, Attrs, SubStmt);
 }
 
Index: clang/lib/Sema/SemaStmtAttr.cpp
===================================================================
--- clang/lib/Sema/SemaStmtAttr.cpp
+++ clang/lib/Sema/SemaStmtAttr.cpp
@@ -211,7 +211,11 @@
 
 static Attr *handleMustTailAttr(Sema &S, Stmt *St, const ParsedAttr &A,
                                 SourceRange Range) {
-  // Validation is in Sema::ActOnAttributedStmt().
+  MustTailAttr TmpAttr(S.Context, A);
+  if (!S.checkAndRewriteMustTailAttr(St, TmpAttr))
+    return nullptr;
+  S.setFunctionHasMustTail();
+
   return ::new (S.Context) MustTailAttr(S.Context, A);
 }
 
Index: clang/lib/Sema/SemaTemplateInstantiate.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -1074,7 +1074,10 @@
                           NamedDecl *FirstQualifierInScope = nullptr,
                           bool AllowInjectedClassName = false);
 
-    const LoopHintAttr *TransformLoopHintAttr(const LoopHintAttr *LH);
+    const LoopHintAttr *TransformLoopHintAttr(const LoopHintAttr *LH,
+                                              Stmt *TransformedStmt);
+    const MustTailAttr *TransformMustTailAttr(const MustTailAttr *MTA,
+                                              Stmt *TransformedStmt);
 
     ExprResult TransformPredefinedExpr(PredefinedExpr *E);
     ExprResult TransformDeclRefExpr(DeclRefExpr *E);
@@ -1485,7 +1488,7 @@
 }
 
 const LoopHintAttr *
-TemplateInstantiator::TransformLoopHintAttr(const LoopHintAttr *LH) {
+TemplateInstantiator::TransformLoopHintAttr(const LoopHintAttr *LH, Stmt *) {
   Expr *TransformedExpr = getDerived().TransformExpr(LH->getValue()).get();
 
   if (TransformedExpr == LH->getValue())
@@ -1501,6 +1504,18 @@
                                       LH->getState(), TransformedExpr, *LH);
 }
 
+const MustTailAttr *
+TemplateInstantiator::TransformMustTailAttr(const MustTailAttr *MTA,
+                                            Stmt *TransformedStmt) {
+  // The attribute itself does not have arguments that need to be transformed,
+  // but the transformed return statement may require additional diagnostic
+  // checking and marking.
+  if (!getSema().checkAndRewriteMustTailAttr(TransformedStmt, *MTA))
+    return nullptr;
+  getSema().setFunctionHasMustTail();
+  return MTA;
+}
+
 ExprResult TemplateInstantiator::transformNonTypeTemplateParmRef(
                                                  NonTypeTemplateParmDecl *parm,
                                                  SourceLocation loc,
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -379,8 +379,14 @@
   /// of attribute. Subclasses may override this function to transform
   /// attributed statements using some other mechanism.
   ///
+  /// The \c TransformedStmt parameter points to the substituted statement and
+  /// is non-const explicitly so that transformations of the attribute that
+  /// need to set state on the statement they apply to can do so if needed.
+  /// Note that \c TransformedStmt will be null when transforming type
+  /// attributes.
+  ///
   /// \returns the transformed attribute
-  const Attr *TransformAttr(const Attr *S);
+  const Attr *TransformAttr(const Attr *S, Stmt *TransformedStmt);
 
   /// Transform the specified attribute.
   ///
@@ -388,10 +394,12 @@
   /// spelling to transform expressions stored within the attribute.
   ///
   /// \returns the transformed attribute.
-#define ATTR(X)
-#define PRAGMA_SPELLING_ATTR(X)                                                \
-  const X##Attr *Transform##X##Attr(const X##Attr *R) { return R; }
+#define ATTR(X)                                                                \
+  const X##Attr *Transform##X##Attr(const X##Attr *R, Stmt *TransformedStmt) { \
+    return R;                                                                  \
+  }
 #include "clang/Basic/AttrList.inc"
+#undef ATTR
 
   /// Transform the given expression.
   ///
@@ -6745,7 +6753,8 @@
 
   // oldAttr can be null if we started with a QualType rather than a TypeLoc.
   const Attr *oldAttr = TL.getAttr();
-  const Attr *newAttr = oldAttr ? getDerived().TransformAttr(oldAttr) : nullptr;
+  const Attr *newAttr =
+      oldAttr ? getDerived().TransformAttr(oldAttr, nullptr) : nullptr;
   if (oldAttr && !newAttr)
     return QualType();
 
@@ -7298,19 +7307,20 @@
 }
 
 template <typename Derived>
-const Attr *TreeTransform<Derived>::TransformAttr(const Attr *R) {
-  if (!R)
-    return R;
+const Attr *TreeTransform<Derived>::TransformAttr(const Attr *A,
+                                                  Stmt *TransformedStmt) {
+  if (!A)
+    return A;
 
-  switch (R->getKind()) {
-// Transform attributes with a pragma spelling by calling TransformXXXAttr.
-#define ATTR(X)
-#define PRAGMA_SPELLING_ATTR(X)                                                \
+  switch (A->getKind()) {
+// Transform attributes by calling TransformXXXAttr.
+#define ATTR(X)                                                                \
   case attr::X:                                                                \
-    return getDerived().Transform##X##Attr(cast<X##Attr>(R));
+    return getDerived().Transform##X##Attr(cast<X##Attr>(A), TransformedStmt);
 #include "clang/Basic/AttrList.inc"
+#undef ATTR
   default:
-    return R;
+    return A;
   }
 }
 
@@ -7318,21 +7328,27 @@
 StmtResult
 TreeTransform<Derived>::TransformAttributedStmt(AttributedStmt *S,
                                                 StmtDiscardKind SDK) {
+  // Transform the attributed statement first so that we can pass the
+  // transformed version in to the attribute handler. This is necessary because
+  // attributes that need to perform semantic checking are more likely to need
+  // the transformed statement than statement semantic checking is likely to
+  // need the transformed attributes. In such a case, the TransformFooAttr()
+  // function can mutate the non-const transformed statement that it is passed.
+  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt(), SDK);
+  if (SubStmt.isInvalid())
+    return StmtError();
+
   bool AttrsChanged = false;
   SmallVector<const Attr *, 1> Attrs;
 
   // Visit attributes and keep track if any are transformed.
   for (const auto *I : S->getAttrs()) {
-    const Attr *R = getDerived().TransformAttr(I);
+    const Attr *R = getDerived().TransformAttr(I, SubStmt.get());
     AttrsChanged |= (I != R);
     if (R)
       Attrs.push_back(R);
   }
 
-  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt(), SDK);
-  if (SubStmt.isInvalid())
-    return StmtError();
-
   if (SubStmt.get() == S->getSubStmt() && !AttrsChanged)
     return S;
 