diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h
--- a/clang/include/clang/AST/ASTNodeTraverser.h
+++ b/clang/include/clang/AST/ASTNodeTraverser.h
@@ -49,6 +49,7 @@
   void Visit(const OMPClause *C);
   void Visit(const BlockDecl::Capture &C);
   void Visit(const GenericSelectionExpr::ConstAssociation &A);
+  void Visit(const TransformClause *C); // Delegates must also accept transform clauses.
 };
 */
 template
@@ -222,6 +223,14 @@
     });
   }

+  void Visit(const TransformClause *C) { // Emit C as a child node, then visit each of its child statements.
+    getNodeDelegate().AddChild([=] {
+      getNodeDelegate().Visit(C);
+      for (const auto *S : C->children())
+        Visit(S);
+    });
+  }
+
   void Visit(const ast_type_traits::DynTypedNode &N) {
     // FIXME: Improve this with a switch or a visitor pattern.
     if (const auto *D = N.get())
@@ -238,6 +247,8 @@
       Visit(C);
     else if (const auto *T = N.get())
       Visit(*T);
+    else if (const auto *C = N.get())
+      Visit(C);
   }

   void dumpDeclContext(const DeclContext *DC) {
@@ -620,6 +631,12 @@
     Visit(C);
   }

+  void
+  VisitTransformExecutableDirective(const TransformExecutableDirective *Node) { // Visit every clause attached to the directive.
+    for (const auto *C : Node->clauses())
+      Visit(C);
+  }
+
   void VisitInitListExpr(const InitListExpr *ILE) {
     if (auto *Filler = ILE->getArrayFiller()) {
       Visit(Filler, "array_filler");
diff --git a/clang/include/clang/AST/ASTTypeTraits.h b/clang/include/clang/AST/ASTTypeTraits.h
--- a/clang/include/clang/AST/ASTTypeTraits.h
+++ b/clang/include/clang/AST/ASTTypeTraits.h
@@ -17,6 +17,7 @@

 #include "clang/AST/ASTFwd.h"
 #include "clang/AST/NestedNameSpecifier.h"
+#include "clang/AST/StmtTransform.h" // Declares TransformClause, used by getFromNode.
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/Basic/LLVM.h"
@@ -68,6 +69,7 @@
   static ASTNodeKind getFromNode(const Stmt &S);
   static ASTNodeKind getFromNode(const Type &T);
   static ASTNodeKind getFromNode(const OMPClause &C);
+  static ASTNodeKind getFromNode(const TransformClause &C);
   /// \}

   /// Returns \c true if \c this and \c Other represent the same kind.
@@ -149,6 +151,9 @@
     NKI_OMPClause,
 #define OPENMP_CLAUSE(TextualSpelling, Class) NKI_##Class,
 #include "clang/Basic/OpenMPKinds.def"
+    NKI_TransformClause, // Base kind; one NKI_*Clause entry per TRANSFORM_CLAUSE follows.
+#define TRANSFORM_CLAUSE(Keyword, Name) NKI_##Name##Clause,
+#include "clang/AST/TransformClauseKinds.def"
     NKI_NumberOfKinds
   };

@@ -205,6 +210,8 @@
 #include "clang/AST/TypeNodes.inc"
 #define OPENMP_CLAUSE(TextualSpelling, Class) KIND_TO_KIND_ID(Class)
 #include "clang/Basic/OpenMPKinds.def"
+#define TRANSFORM_CLAUSE(Keyword, Name) KIND_TO_KIND_ID(Name##Clause)
+#include "clang/AST/TransformClauseKinds.def"
 #undef KIND_TO_KIND_ID

 inline raw_ostream &operator<<(raw_ostream &OS, ASTNodeKind K) {
diff --git a/clang/include/clang/AST/JSONNodeDumper.h b/clang/include/clang/AST/JSONNodeDumper.h
--- a/clang/include/clang/AST/JSONNodeDumper.h
+++ b/clang/include/clang/AST/JSONNodeDumper.h
@@ -201,6 +201,8 @@
   void Visit(const OMPClause *C);
   void Visit(const BlockDecl::Capture &C);
   void Visit(const GenericSelectionExpr::ConstAssociation &A);
+  void Visit(const TransformClause *C);
+  void Visit(const Transform *T);

   void VisitTypedefType(const TypedefType *TT);
   void VisitFunctionType(const FunctionType *T);
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -33,6 +33,7 @@
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtObjC.h"
 #include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtTransform.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TemplateName.h"
 #include "clang/AST/Type.h"
@@ -536,6 +537,11 @@
   bool VisitOMPClauseWithPreInit(OMPClauseWithPreInit *Node);
   bool VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *Node);

+  bool TraverseTransformClause(TransformClause *C); // Dispatches on C->getKind() to the matching Visit*Clause.
+#define TRANSFORM_CLAUSE(Keyword, Name)                                        \
+  bool Visit##Name##Clause(Name##Clause *C);
+#include "clang/AST/TransformClauseKinds.def"
+
   bool dataTraverseNode(Stmt *S,
                         DataRecursionQueue *Queue);
   bool PostVisitStmt(Stmt *S);
 };
@@ -2703,6 +2709,11 @@
 // Traverse OpenCL: AsType, Convert.
 DEF_TRAVERSE_STMT(AsTypeExpr, {})

+DEF_TRAVERSE_STMT(TransformExecutableDirective, {
+  for (auto *C : S->clauses()) // Traverse each clause attached to the directive.
+    TRY_TO(TraverseTransformClause(C));
+})
+
 // OpenMP directives.
 template
 bool RecursiveASTVisitor::TraverseOMPExecutableDirective(
@@ -2877,6 +2888,22 @@
 DEF_TRAVERSE_STMT(OMPTargetTeamsDistributeSimdDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })

+template
+bool RecursiveASTVisitor::TraverseTransformClause(TransformClause *C) {
+  if (!C) // A null clause is skipped rather than treated as a failure.
+    return true;
+  switch (C->getKind()) {
+  case TransformClause::Kind::UnknownKind:
+    llvm_unreachable("Cannot process unknown clause");
+#define TRANSFORM_CLAUSE(Keyword, Name)                                        \
+  case TransformClause::Kind::Name##Kind:                                      \
+    TRY_TO(Visit##Name##Clause(static_cast(C)));                               \
+    break;
+#include "clang/AST/TransformClauseKinds.def"
+  }
+  return true;
+}
+
 // OpenMP clauses.
 template
 bool RecursiveASTVisitor::TraverseOMPClause(OMPClause *C) {
@@ -3374,6 +3401,29 @@
   return true;
 }

+template
+bool RecursiveASTVisitor::VisitFullClause(FullClause *C) { // FullClause has no sub-expressions.
+  return true;
+}
+
+template
+bool RecursiveASTVisitor::VisitPartialClause(PartialClause *C) {
+  TRY_TO(TraverseStmt(C->getFactor())); // Traverse the clause argument.
+  return true;
+}
+
+template
+bool RecursiveASTVisitor::VisitWidthClause(WidthClause *C) {
+  TRY_TO(TraverseStmt(C->getWidth())); // Traverse the clause argument.
+  return true;
+}
+
+template
+bool RecursiveASTVisitor::VisitFactorClause(FactorClause *C) {
+  TRY_TO(TraverseStmt(C->getFactor())); // Traverse the clause argument.
+  return true;
+}
+
 template
 bool RecursiveASTVisitor::VisitOMPNontemporalClause(
     OMPNontemporalClause *C) {
diff --git a/clang/include/clang/AST/StmtTransform.h b/clang/include/clang/AST/StmtTransform.h
--- a/clang/include/clang/AST/StmtTransform.h
+++ b/clang/include/clang/AST/StmtTransform.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CLANG_AST_STMTTRANSFROM_H
 #define LLVM_CLANG_AST_STMTTRANSFROM_H

+#include "clang/AST/Expr.h" // Expr is used by the clause accessors.
 #include "clang/AST/Stmt.h"
 #include
"clang/Basic/Transform.h" #include "llvm/Support/raw_ostream.h" @@ -34,16 +35,320 @@ static Kind getClauseKind(Transform::Kind TransformKind, llvm::StringRef Str); static llvm::StringRef getClauseKeyword(TransformClause::Kind ClauseKind); - // TODO: implement +private: + Kind ClauseKind; + SourceRange LocRange; + +protected: + TransformClause(Kind K, SourceRange Range) : ClauseKind(K), LocRange(Range) {} + TransformClause(Kind K) : ClauseKind(K) {} + +public: + Kind getKind() const { return ClauseKind; } + + SourceRange getRange() const { return LocRange; } + SourceLocation getBeginLoc() const { return LocRange.getBegin(); } + SourceLocation getEndLoc() const { return LocRange.getEnd(); } + void setRange(SourceRange L) { LocRange = L; } + void setRange(SourceLocation BeginLoc, SourceLocation EndLoc) { + LocRange = SourceRange(BeginLoc, EndLoc); + } + + using child_iterator = Stmt::child_iterator; + using const_child_iterator = Stmt::const_child_iterator; + using child_range = Stmt::child_range; + using const_child_range = Stmt::const_child_range; + + child_range children(); + const_child_range children() const { + auto Children = const_cast(this)->children(); + return const_child_range(Children.begin(), Children.end()); + } + + static llvm::StringRef getClauseName(Kind K); + + void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const; +}; + +class TransformClauseImpl : public TransformClause { +protected: + TransformClauseImpl(Kind ClauseKind, SourceRange Range) + : TransformClause(ClauseKind, Range) {} + explicit TransformClauseImpl(Kind ClauseKind) : TransformClause(ClauseKind) {} + +public: + void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const { + llvm_unreachable("implement the print function"); + } + + child_range children() { + // By default, clauses have no children. 
+ return child_range(child_iterator(), child_iterator()); + } +}; + +class FullClause final : public TransformClauseImpl { +private: + explicit FullClause(SourceRange Range) + : TransformClauseImpl(TransformClause::FullKind, Range) {} + FullClause() : TransformClauseImpl(TransformClause::FullKind) {} + +public: + static bool classof(const FullClause *T) { return true; } + static bool classof(const TransformClause *T) { + return T->getKind() == FullKind; + } + + static FullClause *create(ASTContext &Context, SourceRange Range) { + return new (Context) FullClause(Range); + } + static FullClause *createEmpty(ASTContext &Context) { + return new (Context) FullClause(); + } + + void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const; +}; + +class PartialClause final : public TransformClauseImpl { +private: + Stmt *Factor; + + PartialClause(SourceRange Range, Expr *Factor) + : TransformClauseImpl(TransformClause::PartialKind, Range), + Factor(Factor) {} + PartialClause() : TransformClauseImpl(TransformClause::PartialKind) {} + +public: + static bool classof(const PartialClause *T) { return true; } + static bool classof(const TransformClause *T) { + return T->getKind() == PartialKind; + } + + static PartialClause *create(ASTContext &Context, SourceRange Range, + Expr *Factor) { + return new (Context) PartialClause(Range, Factor); + } + static PartialClause *createEmpty(ASTContext &Context) { + return new (Context) PartialClause(); + } + + child_range children() { return child_range(&Factor, &Factor + 1); } + + Expr *getFactor() const { return cast(Factor); } + void setFactor(Expr *E) { Factor = E; } + + void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const; +}; + +class WidthClause final : public TransformClauseImpl { +private: + Stmt *Width; + + WidthClause(SourceRange Range, Expr *Width) + : TransformClauseImpl(TransformClause::WidthKind, Range), Width(Width) {} + WidthClause() : TransformClauseImpl(TransformClause::WidthKind) {} + +public: 
+ static bool classof(const WidthClause *T) { return true; } + static bool classof(const TransformClause *T) { + return T->getKind() == WidthKind; + } + + static WidthClause *create(ASTContext &Context, SourceRange Range, + Expr *Width) { + return new (Context) WidthClause(Range, Width); + } + static WidthClause *createEmpty(ASTContext &Context) { + return new (Context) WidthClause(); + } + + child_range children() { return child_range(&Width, &Width + 1); } + + Expr *getWidth() const { return cast(Width); } + void setWidth(Expr *E) { Width = E; } + + void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const; }; +class FactorClause final : public TransformClauseImpl { +private: + Stmt *Factor; + + FactorClause(SourceRange Range, Expr *Factor) + : TransformClauseImpl(TransformClause::FactorKind, Range), + Factor(Factor) {} + FactorClause() : TransformClauseImpl(TransformClause::FactorKind) {} + +public: + static bool classof(const FactorClause *T) { return true; } + static bool classof(const TransformClause *T) { + return T->getKind() == FactorKind; + } + + static FactorClause *create(ASTContext &Context, SourceRange Range, + Expr *Factor) { + return new (Context) FactorClause(Range, Factor); + } + static FactorClause *createEmpty(ASTContext &Context) { + return new (Context) FactorClause(); + } + + child_range children() { return child_range(&Factor, &Factor + 1); } + + Expr *getFactor() const { return cast(Factor); } + void setFactor(Expr *E) { Factor = E; } + + void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const; +}; + +/// Visitor pattern for transform clauses. 
+template class Ptr, typename RetTy>
+class TransformClauseVisitorBase { // CRTP visitor: Visit(C) switches on C->getKind() and calls the derived class's Visit*Clause.
+protected:
+  ImplClass &getDerived() { return *static_cast(this); }
+  const ImplClass &getDerived() const {
+    return *static_cast(this);
+  }
+
+public:
+#define PTR(CLASS) typename Ptr::type
+
+#define TRANSFORM_CLAUSE(Keyword, Name)                                        \
+  RetTy Visit##Name##Clause(PTR(Name##Clause) S) {                             \
+    return getDerived().VisitTransformClause(S);                               \
+  }
+#include "clang/AST/TransformClauseKinds.def"
+
+  RetTy Visit(PTR(TransformClause) C) {
+    switch (C->getKind()) {
+#define TRANSFORM_CLAUSE(Keyword, Name)                                        \
+  case TransformClause::Kind::Name##Kind:                                      \
+    return getDerived().Visit##Name##Clause(static_cast(C));
+#include "clang/AST/TransformClauseKinds.def"
+    default:
+      llvm_unreachable("Unknown transform clause kind!");
+    }
+  }
+
+  // Base case
+  RetTy VisitTransformClause(PTR(TransformClause) C) { return RetTy(); } // Fallback for every Visit*Clause not overridden.
+
+#undef PTR
+};
+
+template
+class TransformClauseVisitor
+    : public TransformClauseVisitorBase {};
+
+template
+class ConstTransformClauseVisitor
+    : public TransformClauseVisitorBase {};
+
 /// Represents
 ///
 ///   #pragma clang transform
 ///
 /// in the AST.
-class TransformExecutableDirective final {
-  // TODO: implement
+class TransformExecutableDirective final // Clauses are stored as trailing objects after the Stmt.
+    : public Stmt,
+      private llvm::TrailingObjects {
+public:
+  friend TransformClause;
+  friend TrailingObjects;
+
+private:
+  SourceRange LocRange;
+  Stmt *Associated = nullptr; // The statement the directive applies to; the only child.
+  Transform::Kind TransKind = Transform::Kind::UnknownKind;
+  unsigned NumClauses; // Number of trailing clause pointers.
+
+protected:
+  explicit TransformExecutableDirective(SourceRange LocRange, Stmt *Associated,
+                                        ArrayRef Clauses,
+                                        Transform::Kind TransKind)
+      : Stmt(Stmt::TransformExecutableDirectiveClass), LocRange(LocRange),
+        Associated(Associated), TransKind(TransKind),
+        NumClauses(Clauses.size()) {
+    setClauses(Clauses);
+  }
+  explicit TransformExecutableDirective(unsigned NumClauses) // Clause slots are not initialized here; presumably filled via setClauses() after createEmpty() — TODO confirm.
+      : Stmt(Stmt::StmtClass::TransformExecutableDirectiveClass),
+        NumClauses(NumClauses) {}
+
+  size_t numTrailingObjects(OverloadToken) const {
+    return NumClauses;
+  }
+
+public:
+  static bool classof(const TransformExecutableDirective *T) { return true; }
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == Stmt::TransformExecutableDirectiveClass;
+  }
+
+  static TransformExecutableDirective *
+  create(ASTContext &Ctx, SourceRange Range, Stmt *Associated,
+         ArrayRef Clauses, Transform::Kind TransKind);
+  static TransformExecutableDirective *createEmpty(ASTContext &Ctx,
+                                                   unsigned NumClauses);
+
+  SourceRange getRange() const { return LocRange; }
+  SourceLocation getBeginLoc() const { return LocRange.getBegin(); }
+  SourceLocation getEndLoc() const { return LocRange.getEnd(); }
+  void setRange(SourceRange Loc) { LocRange = Loc; }
+  void setRange(SourceLocation BeginLoc, SourceLocation EndLoc) {
+    LocRange = SourceRange(BeginLoc, EndLoc);
+  }
+
+  Stmt *getAssociated() const { return Associated; }
+  void setAssociated(Stmt *S) { Associated = S; }
+
+  Transform::Kind getTransformKind() const { return TransKind; }
+
+  child_range children() { return child_range(&Associated, &Associated + 1); } // Clauses are not exposed as children.
+  const_child_range
 children() const {
+    return const_child_range(&Associated, &Associated + 1);
+  }
+
+  unsigned getNumClauses() const { return NumClauses; }
+  MutableArrayRef clauses() {
+    return llvm::makeMutableArrayRef(getTrailingObjects(),
+                                     NumClauses);
+  }
+  ArrayRef clauses() const {
+    return llvm::makeArrayRef(getTrailingObjects(),
+                              NumClauses);
+  }
+
+  void setClauses(llvm::ArrayRef List) { // Copies List into the trailing storage; sizes must match.
+    assert(List.size() == NumClauses);
+    for (auto p : llvm::zip_first(List, clauses()))
+      std::get<1>(p) = std::get<0>(p);
+  }
+
+  auto getClausesOfKind(TransformClause::Kind Kind) const { // Lazily filtered view of clauses() with the given kind.
+    return llvm::make_filter_range(clauses(), [Kind](TransformClause *Clause) {
+      return Clause->getKind() == Kind;
+    });
+  }
+
+  template auto getClausesOfKind() const { // Filtered view of clauses() of one clause class, already cast to it.
+    return llvm::map_range(
+        llvm::make_filter_range(clauses(),
+                                [](TransformClause *Clause) {
+                                  return isa(Clause);
+                                }),
+        [](TransformClause *Clause) { return cast(Clause); });
+  }
+
+  template
+  SpecificClause *getFirstClauseOfKind() const { // First matching clause, or nullptr if there is none.
+    auto Range = getClausesOfKind();
+    if (Range.begin() == Range.end())
+      return nullptr;
+    return *Range.begin();
+  }
 };

 const Stmt *getAssociatedLoop(const Stmt *S);
diff --git a/clang/include/clang/AST/TextNodeDumper.h b/clang/include/clang/AST/TextNodeDumper.h
--- a/clang/include/clang/AST/TextNodeDumper.h
+++ b/clang/include/clang/AST/TextNodeDumper.h
@@ -172,6 +172,10 @@

   void Visit(const OMPClause *C);

+  void VisitTransformExecutableDirective(const TransformExecutableDirective *S);
+
+  void Visit(const TransformClause *C);
+
   void Visit(const BlockDecl::Capture &C);

   void Visit(const GenericSelectionExpr::ConstAssociation &A);
diff --git a/clang/include/clang/Analysis/AnalysisTransform.h b/clang/include/clang/Analysis/AnalysisTransform.h
new file mode 100644
--- /dev/null
+++ b/clang/include/clang/Analysis/AnalysisTransform.h
@@ -0,0 +1,190 @@
+//===---- AnalysisTransform.h --------------------------------- -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with
LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Extract the transformation to apply from a #pragma clang transform AST node. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_ANALYSIS_ANALYSISTRANSFORM_H +#define LLVM_CLANG_ANALYSIS_ANALYSISTRANSFORM_H + +#include "clang/AST/StmtTransform.h" +#include "clang/Basic/DiagnosticSema.h" +#include "clang/Basic/Transform.h" + +namespace clang { + +static bool isTemplateDependent(Expr *E) { + return E->isValueDependent() || E->isTypeDependent() || + E->isInstantiationDependent() || E->containsUnexpandedParameterPack(); +} + +/// Extract which transformation to apply from a TransformExecutableDirective +/// and its clauses. +template struct ExtractTransform { + Derived &getDerived() { return *static_cast(this); } + const Derived &getDerived() const { + return *static_cast(this); + } + + ASTContext &ASTCtx; + const TransformExecutableDirective *Directive; + bool AnyError = false; + bool TemplateDependent = false; + ExtractTransform(ASTContext &ASTCtx, + const TransformExecutableDirective *Directive) + : ASTCtx(ASTCtx), Directive(Directive) { + assert(Directive); + } + + auto DiagError(SourceLocation Loc, unsigned DiagID) { + AnyError = true; + return getDerived().Diag(Loc, DiagID); + } + + template ClauseTy *assumeSingleClause() { + ClauseTy *Result = nullptr; + for (ClauseTy *C : Directive->getClausesOfKind()) { + if (Result) { + DiagError(C->getBeginLoc(), diag::err_sema_transform_clause_one_max) + << TransformClause::getClauseKeyword(C->getKind()) << C->getRange(); + break; + } + + Result = C; + } + + return Result; + } + + llvm::Optional evalIntArg(Expr *E, int MinVal) { + if (isTemplateDependent(E)) { + TemplateDependent = true; + return None; + } + Expr::EvalResult Res; + 
E->EvaluateAsInt(Res, ASTCtx); + if (!Res.Val.isInt()) { + DiagError(E->getExprLoc(), + diag::err_sema_transform_clause_arg_expect_int); + return None; + } + llvm::APSInt Val = Res.Val.getInt(); + int64_t Int = Val.getSExtValue(); + if (Int < MinVal) { + DiagError(E->getExprLoc(), diag::err_sema_transform_clause_arg_min_val) + << MinVal << SourceRange(E->getBeginLoc(), E->getEndLoc()); + return None; + } + return Int; + } + + void allowedClauses(ArrayRef ClauseKinds) { +#ifndef NDEBUG + for (TransformClause *C : Directive->clauses()) { + assert(llvm::find(ClauseKinds, C->getKind()) != ClauseKinds.end() && + "Parser must have rejected unknown clause"); + } +#endif + } + + static std::unique_ptr wrap(Transform *Trans) { + return std::unique_ptr(Trans); + } + + std::unique_ptr createTransform() { + Transform::Kind Kind = Directive->getTransformKind(); + SourceRange Loc = Directive->getRange(); + + switch (Kind) { + case Transform::LoopUnrollKind: { + allowedClauses({TransformClause::FullKind, TransformClause::PartialKind}); + FullClause *Full = assumeSingleClause(); + PartialClause *Partial = assumeSingleClause(); + if (Full && Partial) + DiagError(Full->getBeginLoc(), + diag::err_sema_transform_unroll_full_or_partial); + + if (AnyError) + return nullptr; + + if (Full) { + return wrap(LoopUnrollTransform::createFull(Loc)); + } else if (Partial) { + llvm::Optional Factor = evalIntArg(Partial->getFactor(), 2); + if (AnyError || !Factor.hasValue()) + return nullptr; + return wrap(LoopUnrollTransform::createPartial(Loc, Factor.getValue())); + } + + return wrap(LoopUnrollTransform::createHeuristic(Loc)); + } + + case Transform::LoopUnrollAndJamKind: { + allowedClauses({TransformClause::PartialKind}); + PartialClause *Partial = assumeSingleClause(); + + if (AnyError) + return nullptr; + + if (Partial) { + llvm::Optional Factor = evalIntArg(Partial->getFactor(), 2); + if (AnyError || !Factor.hasValue()) + return nullptr; + return wrap( + 
LoopUnrollAndJamTransform::createPartial(Loc, Factor.getValue())); + } + + return wrap(LoopUnrollAndJamTransform::createHeuristic(Loc)); + } + + case clang::Transform::LoopDistributionKind: + allowedClauses({}); + return wrap(LoopDistributionTransform::create(Loc)); + + case clang::Transform::LoopVectorizationKind: { + allowedClauses({TransformClause::WidthKind}); + WidthClause *Width = assumeSingleClause(); + if (AnyError) + return nullptr; + + int64_t Simdlen = 0; + if (Width) { + llvm::Optional WidthInt = evalIntArg(Width->getWidth(), 2); + if (AnyError || !WidthInt.hasValue()) + return nullptr; + Simdlen = WidthInt.getValue(); + } + + return wrap(LoopVectorizationTransform::create(Loc, Simdlen)); + } + + case clang::Transform::LoopInterleavingKind: { + allowedClauses({TransformClause::FactorKind}); + FactorClause *Factor = assumeSingleClause(); + if (AnyError) + return nullptr; + + int64_t InterleaveFactor = 0; + if (Factor) { + llvm::Optional FactorInt = evalIntArg(Factor->getFactor(), 2); + if (AnyError || !FactorInt.hasValue()) + return nullptr; + InterleaveFactor = FactorInt.getValue(); + } + + return wrap(LoopInterleavingTransform::create(Loc, InterleaveFactor)); + } + default: + llvm_unreachable("unimplemented"); + } + } +}; + +} // namespace clang +#endif /* LLVM_CLANG_ANALYSIS_ANALYSISTRANSFORM_H */ diff --git a/clang/include/clang/Analysis/TransformedTree.h b/clang/include/clang/Analysis/TransformedTree.h new file mode 100644 --- /dev/null +++ b/clang/include/clang/Analysis/TransformedTree.h @@ -0,0 +1,809 @@ +//===--- TransformedTree.h --------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Applies code transformations. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_ANALYSIS_TRANSFORMEDTREE_H
+#define LLVM_CLANG_ANALYSIS_TRANSFORMEDTREE_H
+
+#include "clang/AST/OpenMPClause.h"
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/AST/Stmt.h"
+#include "clang/AST/StmtOpenMP.h"
+#include "clang/Analysis/AnalysisTransform.h"
+#include "clang/Basic/DiagnosticSema.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace clang {
+template class TransformedTreeBuilder;
+
+struct DefaultExtractTransform : ExtractTransform { // ExtractTransform that silently drops all diagnostics.
+  DefaultExtractTransform(ASTContext &ASTCtx,
+                          const TransformExecutableDirective *Directive)
+      : ExtractTransform(ASTCtx, Directive) {}
+
+  // Ignore any diagnostic and its arguments.
+  struct DummyDiag {
+    template DummyDiag operator<<(const T &) const { return {}; }
+  };
+  DummyDiag Diag(SourceLocation Loc, unsigned DiagID) { return {}; }
+};
+
+/// Represents an input of a code representation.
+/// Currently, the input code can only be referenced by its AST node; in future
+/// loops may also be given identifiers to reference them.
+class TransformInput {
+  const Stmt *StmtInput = nullptr;
+  Transform *PrecTrans = nullptr;
+  int FollowupIdx = -1;
+
+  TransformInput(const Stmt *StmtInput, Transform *PrecTrans, int FollowupIdx)
+      : StmtInput(StmtInput), PrecTrans(PrecTrans), FollowupIdx(FollowupIdx) {}
+
+public:
+  TransformInput() {}
+
+  static TransformInput createByStmt(const Stmt *StmtInput) {
+    assert(StmtInput);
+    return TransformInput(StmtInput, nullptr, -1);
+  }
+
+  // In general, the same clang::Transform can be reused multiple times with
+  // different inputs. When referencing its followup using this constructor, the
+  // clang::Transform can only be used once per function to ensure that its
+  // followup can be uniquely identified.
+  static TransformInput createByFollowup(Transform *Transform,
+                                         int FollowupIdx) {
+    assert(Transform);
+    assert(0 <= FollowupIdx && FollowupIdx < Transform->getNumFollowups());
+    return TransformInput(nullptr, Transform, FollowupIdx);
+  }
+
+  bool isByStmt() const { return StmtInput != nullptr; }
+  bool isByFollowup() const { return PrecTrans != nullptr; }
+
+  const Stmt *getStmtInput() const { return StmtInput; }
+
+  Transform *getPrecTrans() const { return PrecTrans; }
+  int getFollowupIdx() const { return FollowupIdx; }
+};
+
+/// Represents a transformation together with the input loops.
+/// In the future it will also identify the generated loop.
+struct NodeTransform {
+  Transform *Trans = nullptr;
+  llvm::SmallVector Inputs;
+
+  NodeTransform() {}
+  NodeTransform(Transform *Trans, TransformInput TopLevelInput) : Trans(Trans) {
+    assert(Trans->getNumInputs() >= 1);
+
+    Inputs.resize(Trans->getNumInputs()); // One slot per input role; slot 0 is the top-level input.
+    setInput(0, TopLevelInput);
+  }
+
+  void setInput(int Idx, TransformInput Input) { Inputs[Idx] = Input; }
+};
+
+/// This class represents a loop in a loop nest to which transformations are
+/// applied. It is intended to be instantiated for its specific purpose, that
+/// is SemaTransformedTree for semantic analysis (consistency warnings and
+/// errors) and CGTransformedTree for emitting IR.
+template class TransformedTree {
+  template friend class TransformedTreeBuilder;
+  using NodeTy = Derived;
+
+  Derived &getDerived() { return *static_cast(this); }
+  const Derived &getDerived() const {
+    return *static_cast(this);
+  }
+
+protected:
+  /// Is this the root node of the loop hierarchy?
+  bool IsRoot = false;
+
+  /// Does this node have a loop hint applied to it?
+  bool HasLoopHint = false;
+
+  /// Nested loops.
+  llvm::SmallVector Subloops;
+
+  /// Origin of this loop.
+  /// @{
+
+  /// If not the result of a transformation, this is the loop statement that
+  /// this node represents.
+  Stmt *Original;
+
+  /// If the result of a transformation, this points to the primary node that
+  /// the transformation is applied to. BasedOn->Followups has to contain this
+  /// node.
+  Derived *BasedOn;
+
+  /// If the result of a transformation, this is the followup role as defined
+  /// by the transformation applied to @p BasedOn.
+  int FollowupRole;
+
+  /// If the result of a transformation, points to the node that was transformed
+  /// into this node. The predecessor's @p Successors must point to this node.
+  Derived *Predecessor;
+  /// @}
+
+  /// Transformations applied to this loop.
+  /// @{
+
+  /// Points to the primary input this loop is transformed by (the one that
+  /// #pragma clang transform is applied to).
+  Derived *PrimaryInput = nullptr;
+
+  /// If this is the primary transformation input, contains the transformation
+  /// that is applied to the loop nest. For non-primary inputs, it is nullptr.
+  /// To find out which transformation is applied to this loop, one must follow
+  /// the @p PrimaryInput.
+  Transform *TransformedBy = nullptr;
+
+  /// If this is the primary transformation input, contains the followups as
+  /// defined by TransformedBy->getNumFollowups(). The @p BasedOn attribute of a
+  /// followup node must point back to this node.
+  llvm::SmallVector Followups;
+
+  /// List of loops that inherit loop properties from this loop after a
+  /// transformation. For instance, if this loop is marked as 'executable in
+  /// parallel', depending on the transformation, successor loops will as well.
+  /// A successor's @p Predecessor field must point back to this node. The first
+  /// successor in the list is the primary successor: A #pragma clang transform
+  /// applied to the output of the transformation will be applied to the primary
+  /// successor.
+  llvm::SmallVector Successors;
+
+  /// Input role of this loop as defined by the primary input's transformation.
+  int InputRole = -1;
+  /// @}
+
+protected:
+  TransformedTree(llvm::ArrayRef SubLoops, Derived *BasedOn,
+                  clang::Stmt *Original, int FollowupRole, Derived *Predecessor)
+      : Subloops(SubLoops.begin(), SubLoops.end()), Original(Original),
+        BasedOn(BasedOn), FollowupRole(FollowupRole), Predecessor(Predecessor) {
+    assert(!BasedOn == (FollowupRole == -1) &&
+           "Role must be defined if the result of a transformation");
+    assert(!BasedOn == !Predecessor &&
+           "Predecessor must be defined if the result of a transformation");
+    assert(!Original || !BasedOn);
+  }
+
+public:
+  ArrayRef getSubLoops() const { return Subloops; }
+
+  void getLatestSubLoops(SmallVectorImpl &Result) { // Appends the latest successors of every direct subloop.
+    Result.reserve(Result.size() + Subloops.size()); // Lower bound; Result may already hold elements.
+    for (auto SubL : Subloops)
+      SubL->getLatestSuccessors(Result);
+  }
+
+  Derived *getPrimaryInput() const { return PrimaryInput; }
+  Transform *getTransformedBy() const { return TransformedBy; }
+
+  /// Return the transformation that generated this loop. Return nullptr if not
+  /// the result of any transformation, i.e. it is an original loop.
+  Transform *getSourceTransformation() const {
+    assert(!BasedOn == isOriginal() &&
+           "Non-original loops must be based on some other loop");
+    if (isOriginal())
+      return nullptr;
+
+    assert(BasedOn);
+    assert(BasedOn->isTransformationInput());
+    Transform *Result = BasedOn->PrimaryInput->TransformedBy;
+    assert(Result &&
+           "Non-original loops must have a generating transformation");
+    return Result;
+  }
+
+  Stmt *getOriginal() const { return Original; }
+  Stmt *getInheritedOriginal() const { // Follows primary-successor links back to an original loop statement.
+    if (Original)
+      return Original;
+    if (Predecessor && Predecessor->getSuccessors()[0] == &getDerived())
+      return Predecessor->getInheritedOriginal();
+    return nullptr;
+  }
+
+  Derived *getBasedOn() const { return BasedOn; }
+
+  bool isRoot() const { return IsRoot; }
+
+  bool hasLoopHint() const { return HasLoopHint; }
+
+  void markLoopHint() { HasLoopHint = true; }
+
+  ArrayRef getSuccessors() const { return Successors; }
+
+  void getLatestSuccessors(SmallVectorImpl &Result) {
+    // If the loop is not being consumed, this is the latest successor.
+    if (!isTransformationInput()) {
+      Result.push_back(&getDerived());
+      return;
+    }
+
+    for (Derived *Succ : Successors)
+      Succ->getLatestSuccessors(Result);
+  }
+
+  bool isOriginal() const { return Original != nullptr; }
+
+  bool isTransformationInput() const {
+    bool Result = InputRole >= 0;
+    assert(Result == (PrimaryInput != nullptr));
+    return Result;
+  }
+
+  bool isTransformationFollowup() const {
+    bool Result = FollowupRole >= 0;
+    assert(Result == (BasedOn != nullptr));
+    return Result;
+  }
+
+  bool isPrimaryInput() const {
+    bool Result = (InputRole == 0);
+    assert(Result == (PrimaryInput == &getDerived()));
+    return Result;
+  }
+
+  int getFollowupRole() const { return FollowupRole; }
+
+  ArrayRef getFollowups() const { return Followups; }
+
+  void applyTransformation(Transform *Trans,
+                           llvm::ArrayRef Followups,
+                           ArrayRef Successors) {
+    assert(!isTransformationInput());
+    assert(llvm::find(Successors, nullptr) == Successors.end());
+    assert(Trans->getNumFollowups() == Followups.size());
+
+    this->TransformedBy = Trans;
+    this->Followups.insert(this->Followups.end(), Followups.begin(),
+                           Followups.end());
+    this->Successors.assign(Successors.begin(), Successors.end());
+    this->PrimaryInput = &getDerived();
+    this->InputRole = 0; // for primary
+
+#ifndef NDEBUG
+    assert(isTransformationInput() && isPrimaryInput());
+    for (NodeTy *F : Followups) {
+      assert(F->BasedOn == &getDerived());
+    }
+    for (NodeTy *S : Successors) {
+      assert(S->Predecessor == &getDerived());
+    }
+#endif
+  }
+
+  void applySuccessors(Derived *PrimaryInput, int InputRole,
+                       ArrayRef Successors) {
+    assert(!isTransformationInput());
+    assert(InputRole > 0);
+    assert(llvm::find(Successors, nullptr) == Successors.end());
+
+    this->PrimaryInput = PrimaryInput;
+    this->Successors.assign(Successors.begin(), Successors.end());
+    this->InputRole = InputRole;
+
+#ifndef NDEBUG
+    assert(isTransformationInput() && !isPrimaryInput());
+    for (NodeTy *S : Successors) { // Non-primary inputs own no followups; check the successor back-links instead.
+      assert(S->Predecessor ==
&getDerived());
+    }
+#endif
+  }
+};
+
+/// Constructs a loop nest from source and applies transformations on it.
+template class TransformedTreeBuilder {
+  using BuilderTy = Derived;
+
+  Derived &getDerived() { return *static_cast(this); }
+  const Derived &getDerived() const {
+    return *static_cast(this);
+  }
+
+  ASTContext &ASTCtx;
+  const LangOptions &LangOpts;
+  llvm::SmallVectorImpl &AllNodes;
+  llvm::SmallVectorImpl &AllTransforms;
+
+private:
+  /// Build the original loop nest hierarchy from the AST.
+  void buildPhysicalLoopTree(Stmt *S, SmallVectorImpl &SubLoops,
+                             llvm::DenseMap &StmtToTree,
+                             bool MarkLoopHint = false) {
+    if (!S)
+      return;
+
+    Stmt *Body;
+    switch (S->getStmtClass()) {
+    case Stmt::ForStmtClass:
+      Body = cast(S)->getBody();
+      break;
+    case Stmt::WhileStmtClass:
+      Body = cast(S)->getBody();
+      break;
+    case Stmt::DoStmtClass:
+      Body = cast(S)->getBody();
+      break;
+    case Stmt::CXXForRangeStmtClass:
+      Body = cast(S)->getBody();
+      break;
+    case Stmt::CapturedStmtClass:
+      buildPhysicalLoopTree(cast(S)->getCapturedStmt(), SubLoops,
+                            StmtToTree, MarkLoopHint);
+      return;
+    case Expr::LambdaExprClass:
+      // Call to getBody materializes its body, children() (which is called in
+      // the default case) does not.
+ buildPhysicalLoopTree(cast(S)->getBody(), SubLoops, + StmtToTree); + return; + case Expr::BlockExprClass: + buildPhysicalLoopTree(cast(S)->getBody(), SubLoops, + StmtToTree); + return; + case Expr::AttributedStmtClass: + buildPhysicalLoopTree( + cast(S)->getSubStmt(), SubLoops, StmtToTree, + llvm::any_of(cast(S)->getAttrs(), + [](const Attr *A) { return isa(A); })); + return; + default: + if (auto *O = dyn_cast(S)) { + if (!O->hasAssociatedStmt()) + return; + MarkLoopHint = true; + Stmt *Associated = O->getAssociatedStmt(); + buildPhysicalLoopTree(Associated, SubLoops, StmtToTree, true); + return; + } + + for (Stmt *Child : S->children()) + buildPhysicalLoopTree(Child, SubLoops, StmtToTree); + + return; + } + + SmallVector SubSubLoops; + buildPhysicalLoopTree(Body, SubSubLoops, StmtToTree); + + NodeTy *L = getDerived().createPhysical(SubSubLoops, S); + if (MarkLoopHint) + L->markLoopHint(); + + SubLoops.push_back(L); + assert(StmtToTree.count(S) == 0); + StmtToTree[S] = L; + + getDerived().applyOriginal(L); + } + + /// Collect all loop transformations in the function's AST. + class CollectTransformationsVisitor + : public RecursiveASTVisitor { + Derived &Builder; + llvm::DenseMap &StmtToTree; + + public: + CollectTransformationsVisitor(Derived &Builder, + llvm::DenseMap &StmtToTree) + : Builder(Builder), StmtToTree(StmtToTree) {} + + /// Transformations collected so far. + llvm::SmallVector Transforms; + + bool shouldTraversePostOrder() const { return true; } + + bool + VisitTransformExecutableDirective(const TransformExecutableDirective *S) { + // TODO: Check if AttributeStmt with LoopHint or OpenMP is also also + // present and error-out if it is. + + // This ExtractTransform does not emit any diagnostics. Diagnostics should + // have been emitted in Sema::ActOnTransformExecutableDirective. 
+ DefaultExtractTransform ExtractTransform(Builder.ASTCtx, S); + std::unique_ptr Trans = ExtractTransform.createTransform(); + + // We might not get a transform in non-instantiated templates or + // inconsistent clauses. + if (Trans) { + const Stmt *TheLoop = getAssociatedLoop(S->getAssociated()); + Transforms.emplace_back(Trans.get(), + TransformInput::createByStmt(TheLoop)); + Builder.AllTransforms.push_back(Trans.release()); + } + + return true; + } + }; + + /// Applies collected transformations to the loop nest representation. + struct TransformApplicator { + Derived &Builder; + llvm::DenseMap> + TransByStmt; + llvm::DenseMap> + TransformByFollowup; + + TransformApplicator(Derived &Builder) : Builder(Builder) {} + + void addNodeTransform(NodeTransform *NT) { + TransformInput &TopLevelInput = NT->Inputs[0]; + if (TopLevelInput.isByStmt()) { + TransByStmt[TopLevelInput.getStmtInput()].push_back(NT); + } else if (TopLevelInput.isByFollowup()) { + TransformByFollowup[TopLevelInput.getPrecTrans()].push_back(NT); + } else + llvm_unreachable("Transformation must apply to something"); + } + + void checkStageOrder(ArrayRef PrevLoops, Transform *NewTrans) { + for (NodeTy *PrevLoop : PrevLoops) { + + // Cannot combine legacy disable pragmas (e.g. #pragma clang loop + // unroll(disable)) and new transformations (#pragma clang transform). + if (PrevLoop->hasLoopHint()) { + Builder.Diag(NewTrans->getBeginLoc(), + diag::err_sema_transform_legacy_mix); + return; + } + + Transform *PrevSourceTrans = PrevLoop->getSourceTransformation(); + if (!PrevSourceTrans) + continue; + + int PrevStage = PrevSourceTrans->getLoopPipelineStage(); + int NewStage = NewTrans->getLoopPipelineStage(); + if (PrevStage >= 0 && NewStage >= 0 && PrevStage > NewStage) { + Builder.Diag(NewTrans->getBeginLoc(), + diag::warn_sema_transform_pass_order); + + // At most one warning per transformation. 
+ return; + } + } + } + + NodeTy *applyTransform(Transform *Trans, NodeTy *MainLoop) { + switch (Trans->getKind()) { + case Transform::Kind::LoopUnrollKind: + return applyUnrolling(cast(Trans), MainLoop); + case Transform::Kind::LoopUnrollAndJamKind: + return applyUnrollAndJam(cast(Trans), + MainLoop); + case Transform::Kind::LoopDistributionKind: + return applyDistribution(cast(Trans), + MainLoop); + case Transform::Kind::LoopVectorizationKind: + return applyVectorize(cast(Trans), + MainLoop); + case Transform::Kind::LoopInterleavingKind: + return applyInterleave(cast(Trans), + MainLoop); + default: + llvm_unreachable("unimplemented transformation"); + } + } + + void inheritLoopAttributes(NodeTy *Dst, NodeTy *Src, bool IsAll, + bool IsSuccessor) { + Builder.inheritLoopAttributes(Dst, Src, IsAll, IsSuccessor); + } + + NodeTy *applyUnrolling(LoopUnrollTransform *Trans, NodeTy *MainLoop) { + checkStageOrder({MainLoop}, Trans); + + NodeTy *Successor = nullptr; + if (Trans->isFull()) { + // Full unrolling has no followup-loop. + MainLoop->applyTransformation(Trans, {}, {}); + } else { + NodeTy *All = + Builder.createFollowup(MainLoop->Subloops, MainLoop, + LoopUnrollTransform::FollowupAll, nullptr); + NodeTy *Unrolled = Builder.createFollowup( + MainLoop->Subloops, MainLoop, LoopUnrollTransform::FollowupUnrolled, + MainLoop); + NodeTy *Remainder = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopUnrollTransform::FollowupRemainder, MainLoop); + Successor = Unrolled; + inheritLoopAttributes(All, MainLoop, true, All == Successor); + MainLoop->applyTransformation(Trans, {All, Unrolled, Remainder}, + Successor); + } + + Builder.applyUnroll(Trans, MainLoop); + return Successor; + } + + NodeTy *applyUnrollAndJam(LoopUnrollAndJamTransform *Trans, + NodeTy *MainLoop) { + // Search for the innermost loop that is being jammed. 
+ NodeTy *Cur = MainLoop; + NodeTy *Inner = nullptr; + SmallVector LatestInner; + Cur->getLatestSubLoops(LatestInner); + if (LatestInner.size() == 1) { + Inner = LatestInner[0]; + } else { + Builder.Diag(Trans->getBeginLoc(), + diag::err_sema_transform_unrollandjam_expect_nested_loop); + return nullptr; + } + + if (!Inner) { + Builder.Diag(Trans->getBeginLoc(), + diag::err_sema_transform_unrollandjam_expect_nested_loop); + return nullptr; + } + + if (Inner->Subloops.size() != 0) { + Builder.Diag(Trans->getBeginLoc(), + diag::err_sema_transform_unrollandjam_not_innermost); + return nullptr; + } + + checkStageOrder({MainLoop, Inner}, Trans); + + NodeTy *PrimarySuccessor = nullptr; + NodeTy *TransformedAll = Builder.createFollowup( + {}, MainLoop, LoopUnrollAndJamTransform::FollowupAll, nullptr); + inheritLoopAttributes(TransformedAll, MainLoop, true, false); + + if (Trans->isPartial()) { + NodeTy *UnrolledOuter = Builder.createFollowup( + {Inner}, MainLoop, LoopUnrollAndJamTransform::FollowupOuter, + MainLoop); + inheritLoopAttributes(UnrolledOuter, MainLoop, false, true); + + NodeTy *TransformedInner = Builder.createFollowup( + Inner->Subloops, MainLoop, LoopUnrollAndJamTransform::FollowupInner, + Inner); + inheritLoopAttributes(TransformedInner, Inner, false, false); + + MainLoop->applyTransformation( + Trans, {TransformedAll, UnrolledOuter, TransformedInner}, + UnrolledOuter); + Inner->applySuccessors(MainLoop, LoopUnrollAndJamTransform::InputInner, + TransformedInner); + PrimarySuccessor = UnrolledOuter; + } else { + MainLoop->applyTransformation(Trans, {TransformedAll}, {}); + Inner->applySuccessors(MainLoop, LoopUnrollAndJamTransform::InputInner, + {}); + } + + Builder.applyUnrollAndJam(Trans, MainLoop, Inner); + return PrimarySuccessor; + } + + NodeTy *applyDistribution(LoopDistributionTransform *Trans, + NodeTy *MainLoop) { + checkStageOrder({MainLoop}, Trans); + + NodeTy *All = Builder.createFollowup( + MainLoop->Subloops, MainLoop, 
LoopDistributionTransform::FollowupAll, + nullptr); + + inheritLoopAttributes(All, MainLoop, true, false); + MainLoop->applyTransformation(Trans, {All}, {}); + + Builder.applyDistribution(Trans, MainLoop); + return nullptr; + } + + NodeTy *applyVectorize(LoopVectorizationTransform *Trans, + NodeTy *MainLoop) { + checkStageOrder({MainLoop}, Trans); + + NodeTy *All = Builder.createFollowup( + MainLoop->Subloops, MainLoop, LoopVectorizationTransform::FollowupAll, + nullptr); + NodeTy *Vectorized = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopVectorizationTransform::FollowupVectorized, MainLoop); + NodeTy *Epilogue = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopVectorizationTransform::FollowupEpilogue, MainLoop); + + inheritLoopAttributes(All, MainLoop, true, false); + MainLoop->applyTransformation(Trans, {All, Vectorized, Epilogue}, + Vectorized); + Builder.applyVectorization(Trans, MainLoop); + return Vectorized; + } + + NodeTy *applyInterleave(LoopInterleavingTransform *Trans, + NodeTy *MainLoop) { + checkStageOrder({MainLoop}, Trans); + + NodeTy *All = Builder.createFollowup( + MainLoop->Subloops, MainLoop, LoopInterleavingTransform::FollowupAll, + nullptr); + NodeTy *Interleaved = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopInterleavingTransform::FollowupInterleaved, MainLoop); + NodeTy *Epilogue = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopInterleavingTransform::FollowupEpilogue, MainLoop); + + inheritLoopAttributes(All, MainLoop, true, false); + MainLoop->applyTransformation(Trans, {All, Interleaved, Epilogue}, + Interleaved); + Builder.applyInterleaving(Trans, MainLoop); + return Interleaved; + } + + void traverseSubloops(NodeTy *L) { + // TODO: Instead of recursively traversing the entire subtree, in case we + // are re-traversing after a transformation, only traverse the followups + // of that transformation. 
+ SmallVector Latest; + for (NodeTy *SubL : L->getSubLoops()) { + Latest.clear(); + SubL->getLatestSuccessors(Latest); + for (NodeTy *SubL : Latest) + traverse(SubL); + } + } + + bool applyTransform(NodeTy *L) { + if (L->isRoot()) + return false; + + // Look for transformations that apply syntactically to this loop. + Stmt *OrigStmt = L->getInheritedOriginal(); + auto TransformsOnStmt = TransByStmt.find(OrigStmt); + if (TransformsOnStmt != TransByStmt.end()) { + auto &List = TransformsOnStmt->second; + if (!List.empty()) { + NodeTransform *Trans = List.front(); + applyTransform(Trans->Trans, L); + List.erase(List.begin()); + + return true; + } + } + + // Look for transformations that are chained to one of the followups. + auto SourceTrans = L->getSourceTransformation(); + if (!SourceTrans) + return false; + auto Chained = TransformByFollowup.find(SourceTrans); + if (Chained == TransformByFollowup.end()) + return false; + int LIdx = L->getFollowupRole(); + auto &List = Chained->second; + for (auto It = List.begin(), E = List.end(); It != E; ++It) { + NodeTransform *Item = *It; + int FollowupIdx = Item->Inputs[0].getFollowupIdx(); + if (LIdx != FollowupIdx) + continue; + + applyTransform(Item->Trans, L); + List.erase(It); + return true; + } + + return false; + } + + void traverse(NodeTy *N) { + SmallVector Latest; + N->getLatestSuccessors(Latest); + for (NodeTy *L : Latest) { + traverseSubloops(L); + if (applyTransform(L)) { + // Apply transformations on nested followups. 
+ traverse(L); + } + } + } + }; + +protected: + TransformedTreeBuilder(ASTContext &ASTCtx, const LangOptions &LangOpts, + llvm::SmallVectorImpl &AllNodes, + llvm::SmallVectorImpl &AllTransforms) + : ASTCtx(ASTCtx), LangOpts(LangOpts), AllNodes(AllNodes), + AllTransforms(AllTransforms) {} + + NodeTy *createRoot(llvm::ArrayRef SubLoops) { + auto *Result = new NodeTy(SubLoops, nullptr, nullptr, -1, nullptr); + AllNodes.push_back(Result); + Result->IsRoot = true; + return Result; + } + + NodeTy *createPhysical(llvm::ArrayRef SubLoops, + clang::Stmt *Original) { + assert(Original); + auto *Result = new NodeTy(SubLoops, nullptr, Original, -1, nullptr); + AllNodes.push_back(Result); + return Result; + } + + NodeTy *createFollowup(llvm::ArrayRef SubLoops, NodeTy *BasedOn, + int FollowupRole, NodeTy *Predecessor) { + auto *Result = + new NodeTy(SubLoops, BasedOn, nullptr, FollowupRole, Predecessor); + AllNodes.push_back(Result); + return Result; + } + +public: + NodeTy * + computeTransformedStructure(Stmt *Body, + llvm::DenseMap &StmtToTree) { + if (!Body) + return nullptr; + + // Create original tree. + SmallVector TopLevelLoops; + buildPhysicalLoopTree(Body, TopLevelLoops, StmtToTree); + NodeTy *Root = getDerived().createRoot(TopLevelLoops); + + // Collect all loop transformations. + CollectTransformationsVisitor Collector(getDerived(), StmtToTree); + Collector.TraverseStmt(Body); + auto &TransformList = Collector.Transforms; + + // Local function to apply every transformation that the predicate returns + // true to. This is to emulate LLVM's pass pipeline that would apply + // transformation in this order. + auto SelectiveApplicator = [this, &TransformList, Root](auto Pred) { + TransformApplicator OMPApplicator(getDerived()); + bool AnyActive = false; + for (NodeTransform &NT : TransformList) { + if (Pred(NT)) { + OMPApplicator.addNodeTransform(&NT); + AnyActive = true; + } + } + + // No traversal needed if no transformations to apply. 
+ if (!AnyActive) + return; + + OMPApplicator.traverse(Root); + + // Report leftover transformations. + for (auto &P : OMPApplicator.TransByStmt) { + for (NodeTransform *NT : P.second) { + getDerived().Diag(NT->Trans->getBeginLoc(), + diag::err_sema_transform_missing_loop); + } + } + + // Remove applied transformations from list. + auto NewEnd = + std::remove_if(TransformList.begin(), TransformList.end(), Pred); + TransformList.erase(NewEnd, TransformList.end()); + }; + + // Apply all others. + SelectiveApplicator([](NodeTransform &NT) -> bool { return true; }); + assert(TransformList.size() == 0 && "Must apply all transformations"); + + getDerived().finalize(Root); + + return Root; + } +}; + +} // namespace clang +#endif /* LLVM_CLANG_ANALYSIS_TRANSFORMEDTREE_H */ diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -10249,4 +10249,28 @@ "function template with 'sycl_kernel' attribute must have a 'void' return type">, InGroup; +// Pragma transform support. 
+def err_sema_transform_expected_loop : Error< + "cannot find loop to transform">; +def err_sema_transform_unroll_full_or_partial : Error< + "the full and partial clauses are mutually exclusive">; +def err_sema_transform_unrollandjam_expect_nested_loop : Error< + "unroll-and-jam requires exactly one nested loop">; +def err_sema_transform_unrollandjam_not_innermost : Error< + "inner loop of unroll-and-jam is not the innermost">; +def err_sema_transform_missing_loop : Error< + "transformation did not find its loop to transform; it might have been consumed by another transformation">; +def err_sema_transform_clause_one_max : Error< + "the %0 clause can be specified at most once">; +def err_sema_transform_clause_arg_expect_int : Error< + "clause expects an integer argument">; +def err_sema_transform_clause_arg_min_val : Error< + "clause argument must be at least %0">; +def err_sema_transform_legacy_mix : Error< + "cannot combine #pragma clang transform with other transformations">; + +def warn_sema_transform_pass_order : Warning< + "the LLVM pass structure currently is not able to apply the transformations in this order">, + InGroup; + } // end of sema component. diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -209,6 +209,9 @@ // OpenCL Extensions. def AsTypeExpr : StmtNode; +// Transform Directives. +def TransformExecutableDirective : StmtNode; + // OpenMP Directives.
def OMPExecutableDirective : StmtNode; def OMPLoopDirective : StmtNode; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -11793,6 +11793,10 @@ TransformClause *ActOnPartialClause(SourceRange Loc, Expr *Factor); TransformClause *ActOnWidthClause(SourceRange Loc, Expr *Width); TransformClause *ActOnFactorClause(SourceRange Loc, Expr *Factor); + + /// Build the loop transformation tree for the body of \p FD and diagnose + /// transformations that cannot be applied. Does nothing for null or invalid decls. + void HandleLoopTransformations(FunctionDecl *FD); }; /// RAII object that enters a new expression evaluation context. diff --git a/clang/include/clang/Sema/SemaTransform.h b/clang/include/clang/Sema/SemaTransform.h new file mode 100644 --- /dev/null +++ b/clang/include/clang/Sema/SemaTransform.h @@ -0,0 +1,81 @@ +//===---- SemaTransform.h ------------------------------------- -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Semantic analysis for code transformations.
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_SEMA_SEMATRANSFORM_H +#define LLVM_CLANG_SEMA_SEMATRANSFORM_H + +#include "clang/Analysis/AnalysisTransform.h" +#include "clang/Analysis/TransformedTree.h" +#include "clang/Sema/Sema.h" + +namespace clang { + +/// Loop-nest node used by Sema; it adds no state of its own to TransformedTree. +class SemaTransformedTree : public TransformedTree { + friend class TransformedTree; + using BaseTy = TransformedTree; + using NodeTy = SemaTransformedTree; + + BaseTy &getBase() { return *this; } + const BaseTy &getBase() const { return *this; } + +public: + SemaTransformedTree(llvm::ArrayRef SubLoops, NodeTy *BasedOn, + clang::Stmt *Original, int FollowupRole, + NodeTy *Predecessor) + : TransformedTree(SubLoops, BasedOn, Original, FollowupRole, + Predecessor) {} +}; + +/// Builds the transformed loop tree for Sema. Diagnostics are forwarded to +/// Sema; all other builder hooks are no-ops. +class SemaTransformedTreeBuilder + : public TransformedTreeBuilder { + using NodeTy = SemaTransformedTree; + + Sema &Sem; + +public: + SemaTransformedTreeBuilder(ASTContext &ASTCtx, const LangOptions &LangOpts, + llvm::SmallVectorImpl &AllNodes, + llvm::SmallVectorImpl &AllTransforms, + Sema &Sem) + : TransformedTreeBuilder(ASTCtx, LangOpts, AllNodes, AllTransforms), + Sem(Sem) {} + + auto Diag(SourceLocation Loc, unsigned DiagID) { + return Sem.Diag(Loc, DiagID); + } + + void applyOriginal(SemaTransformedTree *L) {} + + void applyUnrollAndJam(LoopUnrollAndJamTransform *Trans, + SemaTransformedTree *OuterLoop, + SemaTransformedTree *InnerLoop) {} + void applyUnroll(LoopUnrollTransform *Trans, + SemaTransformedTree *OriginalLoop) {} + void applyDistribution(LoopDistributionTransform *Trans, + SemaTransformedTree *InputLoop) {} + void applyVectorization(LoopVectorizationTransform *Trans, + SemaTransformedTree *InputLoop) {} + void applyInterleaving(LoopInterleavingTransform *Trans, + SemaTransformedTree *InputLoop) {} + + void inheritLoopAttributes(SemaTransformedTree *Dst, SemaTransformedTree *Src, + bool IsMeta, bool IsSuccessor) {} + + void finalize(NodeTy *Root) {} +}; +
+} // namespace clang +#endif /* LLVM_CLANG_SEMA_SEMATRANSFORM_H */ diff --git a/clang/lib/AST/ASTTypeTraits.cpp b/clang/lib/AST/ASTTypeTraits.cpp --- a/clang/lib/AST/ASTTypeTraits.cpp +++ b/clang/lib/AST/ASTTypeTraits.cpp @@ -42,6 +42,9 @@ {NKI_None, "OMPClause"}, #define OPENMP_CLAUSE(TextualSpelling, Class) {NKI_OMPClause, #Class}, #include "clang/Basic/OpenMPKinds.def" + {NKI_TransformClause, "TransformClause"}, +#define TRANSFORM_CLAUSE(Keyword, Name) {NKI_##Name##Clause, #Name "Clause"}, +#include "clang/AST/TransformClauseKinds.def" }; bool ASTNodeKind::isBaseOf(ASTNodeKind Other, unsigned *Distance) const { @@ -125,6 +128,18 @@ llvm_unreachable("invalid stmt kind"); } +ASTNodeKind ASTNodeKind::getFromNode(const TransformClause &C) { + switch (C.getKind()) { +#define TRANSFORM_CLAUSE(Keyword, Name) \ + case TransformClause::Kind::Name##Kind: \ + return ASTNodeKind(NKI_##Name##Clause); +#include "clang/AST/TransformClauseKinds.def" + case TransformClause::Kind::UnknownKind: + llvm_unreachable("unexpected transform clause kind"); + } + llvm_unreachable("invalid transform clause kind"); +} + void DynTypedNode::print(llvm::raw_ostream &OS, const PrintingPolicy &PP) const { if (const TemplateArgument *TA = get()) diff --git a/clang/lib/AST/JSONNodeDumper.cpp b/clang/lib/AST/JSONNodeDumper.cpp --- a/clang/lib/AST/JSONNodeDumper.cpp +++ b/clang/lib/AST/JSONNodeDumper.cpp @@ -167,6 +167,10 @@ void JSONNodeDumper::Visit(const OMPClause *C) {} +void JSONNodeDumper::Visit(const TransformClause *C) {} + +void JSONNodeDumper::Visit(const Transform *T) {} + void JSONNodeDumper::Visit(const BlockDecl::Capture &C) { JOS.attribute("kind", "Capture"); attributeOnlyIfTrue("byref", C.isByRef()); diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -30,6 +30,7 @@ #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtTransform.h"
#include "clang/AST/StmtVisitor.h" #include "clang/AST/TemplateBase.h" #include "clang/AST/Type.h" @@ -635,6 +636,19 @@ if (Policy.IncludeNewlines) OS << NL; } +void StmtPrinter::VisitTransformExecutableDirective( + TransformExecutableDirective *Node) { + Indent() << "#pragma clang transform " + << Transform::getTransformDirectiveKeyword( + Node->getTransformKind()); + for (TransformClause *Clause : Node->clauses()) { + OS << ' '; + Clause->print(OS, Policy); + } + OS << NL; + PrintStmt(Node->getAssociated()); +} + //===----------------------------------------------------------------------===// // OpenMP directives printing methods //===----------------------------------------------------------------------===// diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -774,6 +774,11 @@ } } +void StmtProfiler::VisitTransformExecutableDirective( + const TransformExecutableDirective *S) { + VisitStmt(S); +} + void StmtProfiler::VisitOMPExecutableDirective(const OMPExecutableDirective *S) { VisitStmt(S); diff --git a/clang/lib/AST/StmtTransform.cpp b/clang/lib/AST/StmtTransform.cpp --- a/clang/lib/AST/StmtTransform.cpp +++ b/clang/lib/AST/StmtTransform.cpp @@ -17,6 +17,32 @@ using namespace clang; +TransformExecutableDirective *TransformExecutableDirective::create( + ASTContext &Ctx, SourceRange Range, Stmt *Associated, + ArrayRef Clauses, Transform::Kind TransKind) { + void *Mem = Ctx.Allocate(totalSizeToAlloc(Clauses.size())); + return new (Mem) + TransformExecutableDirective(Range, Associated, Clauses, TransKind); +} + +TransformExecutableDirective * +TransformExecutableDirective::createEmpty(ASTContext &Ctx, + unsigned NumClauses) { + void *Mem = Ctx.Allocate(totalSizeToAlloc(NumClauses)); + return new (Mem) TransformExecutableDirective(NumClauses); +} + +llvm::StringRef TransformClause::getClauseName(Kind K) { + assert(K >= UnknownKind); + assert(K <= LastKind); + static const
char *const Names[LastKind + 1] = { + "Unknown", +#define TRANSFORM_CLAUSE(Keyword, Name) #Name, +#include "clang/AST/TransformClauseKinds.def" + }; + return Names[K]; +} + bool TransformClause::isValidForTransform(Transform::Kind TransformKind, TransformClause::Kind ClauseKind) { switch (TransformKind) { @@ -56,6 +82,56 @@ return ClauseKeyword[ClauseKind - 1]; } +TransformClause::child_range TransformClause::children() { + switch (getKind()) { + case UnknownKind: + llvm_unreachable("Unknown child"); +#define TRANSFORM_CLAUSE(Keyword, Name) \ + case TransformClause::Kind::Name##Kind: \ + return static_cast(this)->children(); +#include "clang/AST/TransformClauseKinds.def" + } + llvm_unreachable("Unhandled clause kind"); +} + +void TransformClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + assert(getKind() > UnknownKind); + assert(getKind() <= LastKind); + static decltype(&TransformClause::print) PrintFuncs[LastKind] = { +#define TRANSFORM_CLAUSE(Keyword, Name) \ + static_cast(&Name##Clause::print), +#include "clang/AST/TransformClauseKinds.def" + }; + (this->*PrintFuncs[getKind() - 1])(OS, Policy); +} + +void FullClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + OS << "full"; +} + +void PartialClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + OS << "partial("; + Factor->printPretty(OS, nullptr, Policy, 0); + OS << ')'; +} + +void WidthClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + OS << "width("; + Width->printPretty(OS, nullptr, Policy, 0); + OS << ')'; +} + +void FactorClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + OS << "factor("; + Factor->printPretty(OS, nullptr, Policy, 0); + OS << ')'; +} + const Stmt *clang::getAssociatedLoop(const Stmt *S) { switch (S->getStmtClass()) { case Stmt::ForStmtClass: @@ -63,6 +139,9 @@ case Stmt::DoStmtClass: case Stmt::CXXForRangeStmtClass: return S; + case
Stmt::TransformExecutableDirectiveClass: + return getAssociatedLoop( + cast(S)->getAssociated()); default: return nullptr; } diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp --- a/clang/lib/AST/TextNodeDumper.cpp +++ b/clang/lib/AST/TextNodeDumper.cpp @@ -320,6 +320,27 @@ OS << " "; } +void TextNodeDumper::VisitTransformExecutableDirective( + const TransformExecutableDirective *S) { + // Clauses are already dumped as child nodes by + // ASTNodeTraverser::VisitTransformExecutableDirective. +} + +void TextNodeDumper::Visit(const TransformClause *C) { + if (!C) { + ColorScope Color(OS, ShowColors, NullColor); + OS << "<<>> TransformClause"; + return; + } + { + ColorScope Color(OS, ShowColors, AttrColor); + StringRef ClauseName = TransformClause::getClauseName(C->getKind()); + OS << ClauseName << "Clause"; + } + dumpPointer(C); + dumpSourceRange(C->getRange()); +} + void TextNodeDumper::Visit(const GenericSelectionExpr::ConstAssociation &A) { const TypeSourceInfo *TSI = A.getTypeSourceInfo(); if (TSI) { diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp --- a/clang/lib/CodeGen/CGStmt.cpp +++ b/clang/lib/CodeGen/CGStmt.cpp @@ -358,6 +358,9 @@ EmitOMPTargetTeamsDistributeSimdDirective( cast(*S)); break; + case Stmt::TransformExecutableDirectiveClass: + llvm_unreachable("not implemented"); + break; } } diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -14159,6 +14159,8 @@ "Leftover expressions for odr-use checking"); } + HandleLoopTransformations(FD); + if (!IsInstantiation) PopDeclContext(); diff --git a/clang/lib/Sema/SemaTransform.cpp b/clang/lib/Sema/SemaTransform.cpp --- a/clang/lib/Sema/SemaTransform.cpp +++ b/clang/lib/Sema/SemaTransform.cpp @@ -10,8 +10,11 @@ // //===----------------------------------------------------------------------===// +#include "clang/Sema/SemaTransform.h" #include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtTransform.h" +#include "clang/Analysis/AnalysisTransform.h" +#include "clang/Analysis/TransformedTree.h" #include "clang/Basic/Transform.h" #include "clang/Sema/Sema.h" #include "clang/Sema/SemaDiagnostic.h" @@ -20,30 +23,65 @@ using namespace clang; +struct SemaExtractTransform : ExtractTransform { + Sema &Sem; + SemaExtractTransform(TransformExecutableDirective *Directive, Sema &Sem) + : ExtractTransform(Sem.getASTContext(), Directive), Sem(Sem) {} + + auto Diag(SourceLocation Loc, unsigned DiagID) { + return Sem.Diag(Loc, DiagID); + } +}; + StmtResult Sema::ActOnLoopTransformDirective(Transform::Kind Kind, llvm::ArrayRef Clauses, Stmt *AStmt, SourceRange Loc) { - // TOOD: implement - return StmtError(); + const Stmt *Loop = getAssociatedLoop(AStmt); + if (!Loop) + return StmtError( + Diag(Loc.getBegin(), diag::err_sema_transform_expected_loop)); + + auto *Result = + TransformExecutableDirective::create(Context, Loc, AStmt, Clauses, Kind); + + // Emit errors and warnings. + SemaExtractTransform VerifyTransform(Result, *this); + VerifyTransform.createTransform(); + + return Result; } TransformClause *Sema::ActOnFullClause(SourceRange Loc) { - // TOOD: implement - return nullptr; + return FullClause::create(Context, Loc); } TransformClause *Sema::ActOnPartialClause(SourceRange Loc, Expr *Factor) { - // TOOD: implement - return nullptr; + return PartialClause::create(Context, Loc, Factor); } TransformClause *Sema::ActOnWidthClause(SourceRange Loc, Expr *Width) { - // TOOD: implement - return nullptr; + return WidthClause::create(Context, Loc, Width); } TransformClause *Sema::ActOnFactorClause(SourceRange Loc, Expr *Factor) { - // TOOD: implement - return nullptr; + return FactorClause::create(Context, Loc, Factor); +} + +void Sema::HandleLoopTransformations(FunctionDecl *FD) { + if (!FD || FD->isInvalidDecl()) + return; + + // Note: this is called on a template-code and the instantiated code. 
+ llvm::DenseMap StmtToTree; + llvm::SmallVector AllNodes; + llvm::SmallVector AllTransforms; + SemaTransformedTreeBuilder Builder(getASTContext(), getLangOpts(), AllNodes, + AllTransforms, *this); + Builder.computeTransformedStructure(FD->getBody(), StmtToTree); + + for (auto N : AllNodes) + delete N; + for (auto T : AllTransforms) + delete T; } diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -27,6 +27,7 @@ #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtTransform.h" #include "clang/Sema/Designator.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/Ownership.h" @@ -344,6 +345,51 @@ /// \returns the transformed statement. StmtResult TransformStmt(Stmt *S, StmtDiscardKind SDK = SDK_Discarded); + Transform *TransformTransform(Transform *T) { return T; } + + TransformClause *TransformTransformClause(TransformClause *S); + + TransformClause *TransformFullClause(FullClause *C) { + return RebuildFullClause(C->getRange()); + } + + TransformClause *RebuildFullClause(SourceRange Range) { + return getSema().ActOnFullClause(Range); + } + + TransformClause *TransformPartialClause(PartialClause *C) { + ExprResult E = getDerived().TransformExpr(C->getFactor()); + if (E.isInvalid()) + return nullptr; + return getDerived().RebuildPartialClause(C->getRange(), E.get()); + } + + TransformClause *RebuildPartialClause(SourceRange Range, Expr *Factor) { + return getSema().ActOnPartialClause(Range, Factor); + } + + TransformClause *TransformWidthClause(WidthClause *C) { + ExprResult E = getDerived().TransformExpr(C->getWidth()); + if (E.isInvalid()) + return nullptr; + return getDerived().RebuildWidthClause(C->getRange(), E.get()); + } + + TransformClause *RebuildWidthClause(SourceRange Range, Expr *Width) { + return getSema().ActOnWidthClause(Range, Width); + } + + TransformClause 
*TransformFactorClause(FactorClause *C) { + ExprResult E = getDerived().TransformExpr(C->getFactor()); + if (E.isInvalid()) + return nullptr; + return getDerived().RebuildFactorClause(C->getRange(), E.get()); + } + + TransformClause *RebuildFactorClause(SourceRange Range, Expr *Factor) { + return getSema().ActOnFactorClause(Range, Factor); + } + /// Transform the given statement. /// /// By default, this routine transforms a statement by delegating to the @@ -1506,6 +1552,17 @@ return getSema().BuildObjCAtThrowStmt(AtLoc, Operand); } + StmtResult + RebuildTransformExecutableDirective(Transform::Kind Kind, + llvm::ArrayRef Clauses, + Stmt *AStmt, SourceRange Loc) { + StmtResult Result = + getSema().ActOnLoopTransformDirective(Kind, Clauses, AStmt, Loc); + assert(!Result.isUsable() || + isa(Result.get())); + return Result; + } + /// Build a new OpenMP executable directive. /// /// By default, performs semantic analysis to build the new statement. @@ -7878,6 +7935,42 @@ return S; } +template +TransformClause * +TreeTransform::TransformTransformClause(TransformClause *S) { + if (!S) + return S; + + switch (S->getKind()) { +#define TRANSFORM_CLAUSE(Keyword, Name) \ + case TransformClause::Kind::Name##Kind: \ + return getDerived().Transform##Name##Clause(cast(S)); +#include "clang/AST/TransformClauseKinds.def" + case TransformClause::Kind::UnknownKind: + llvm_unreachable("Should not be unknown"); + } + + return S; +} + +template +StmtResult TreeTransform::TransformTransformExecutableDirective( + TransformExecutableDirective *D) { + llvm::SmallVector TClauses; + for (TransformClause *C : D->clauses()) { + TransformClause *TClause = getDerived().TransformTransformClause(C); + TClauses.push_back(TClause); + } + + Stmt *AStmt = D->getAssociated(); + StmtResult TBody = getDerived().TransformStmt(AStmt); + assert(TBody.isUsable()); + + StmtResult TDirective = getDerived().RebuildTransformExecutableDirective( + D->getTransformKind(), TClauses, TBody.get(), D->getRange()); + 
return TDirective; +} + //===----------------------------------------------------------------------===// // OpenMP directive transformation //===----------------------------------------------------------------------===// diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2017,6 +2017,15 @@ E->SrcExpr = Record.readSubExpr(); } +//===----------------------------------------------------------------------===// +// Transformation Directives. +//===----------------------------------------------------------------------===// + +void ASTStmtReader::VisitTransformExecutableDirective( + TransformExecutableDirective *D) { + llvm_unreachable("not implemented"); +} + //===----------------------------------------------------------------------===// // OpenMP Directives. //===----------------------------------------------------------------------===// diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -1971,6 +1971,14 @@ Record.AddSourceLocation(S->getLeaveLoc()); Code = serialization::STMT_SEH_LEAVE; } +//===----------------------------------------------------------------------===// +// Transformation Directives. +//===----------------------------------------------------------------------===// + +void ASTStmtWriter::VisitTransformExecutableDirective( + TransformExecutableDirective *D) { + llvm_unreachable("not implemented"); +} //===----------------------------------------------------------------------===// // OpenMP Directives. 
diff --git a/clang/test/AST/ast-dump-transform-unroll.c b/clang/test/AST/ast-dump-transform-unroll.c new file mode 100644 --- /dev/null +++ b/clang/test/AST/ast-dump-transform-unroll.c @@ -0,0 +1,25 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fexperimental-transform-pragma -ast-dump %s | FileCheck %s + +void unroll_full(int n) { +#pragma clang transform unroll full + for (int i = 0; i < 4; i+=1) + ; +} + +// CHECK-LABEL: FunctionDecl {{.*}} unroll_full +// CHECK: TransformExecutableDirective +// CHECK-NEXT: FullClause +// CHECK-NEXT: ForStmt + + +void unroll_partial(int n) { +#pragma clang transform unroll partial(4) + for (int i = 0; i < n; i+=1) + ; +} + +// CHECK-LABEL: FunctionDecl {{.*}} unroll_partial +// CHECK: TransformExecutableDirective +// CHECK-NEXT: PartialClause +// CHECK-NEXT: IntegerLiteral +// CHECK-NEXT: ForStmt diff --git a/clang/test/AST/ast-dump-transform-unrollandjam.c b/clang/test/AST/ast-dump-transform-unrollandjam.c new file mode 100644 --- /dev/null +++ b/clang/test/AST/ast-dump-transform-unrollandjam.c @@ -0,0 +1,26 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fexperimental-transform-pragma -ast-dump %s | FileCheck %s + +void unrollandjam_heuristic(int n) { +#pragma clang transform unrollandjam + for (int i = 0; i < n; i+=1) + for (int j = 0; j < n; j+=1) + ; +} + +// CHECK-LABEL: FunctionDecl {{.*}} unrollandjam_heuristic +// CHECK: TransformExecutableDirective +// CHECK-NEXT: ForStmt + + +void unrollandjam_partial(int n) { +#pragma clang transform unrollandjam partial(4) + for (int i = 0; i < n; i+=1) + for (int j = 0; j < n; j+=1) + ; +} + +// CHECK-LABEL: FunctionDecl {{.*}} unrollandjam_partial +// CHECK: TransformExecutableDirective +// CHECK-NEXT: PartialClause +// CHECK-NEXT: IntegerLiteral +// CHECK-NEXT: ForStmt diff --git a/clang/test/AST/ast-print-pragma-transform-distribute.cpp b/clang/test/AST/ast-print-pragma-transform-distribute.cpp new file mode 100644 --- /dev/null +++ 
b/clang/test/AST/ast-print-pragma-transform-distribute.cpp @@ -0,0 +1,9 @@ +// RUN: %clang_cc1 -fexperimental-transform-pragma -ast-print %s -o - | FileCheck %s + +void distribute(int *List, int Length) { +// CHECK: #pragma clang transform distribute +#pragma clang transform distribute +// CHECK-NEXT: for (int i = 0; i < Length; i += 1) + for (int i = 0; i < Length; i += 1) + List[i] = i; +} diff --git a/clang/test/AST/ast-print-pragma-transform-interleave.cpp b/clang/test/AST/ast-print-pragma-transform-interleave.cpp new file mode 100644 --- /dev/null +++ b/clang/test/AST/ast-print-pragma-transform-interleave.cpp @@ -0,0 +1,9 @@ +// RUN: %clang_cc1 -fexperimental-transform-pragma -ast-print %s -o - | FileCheck %s + +void unroll(int *List, int Length) { +// CHECK: #pragma clang transform interleave factor(4) +#pragma clang transform interleave factor(4) +// CHECK-NEXT: for (int i = 0; i < Length; i += 1) + for (int i = 0; i < Length; i += 1) + List[i] = i; +} diff --git a/clang/test/AST/ast-print-pragma-transform-unroll.cpp b/clang/test/AST/ast-print-pragma-transform-unroll.cpp new file mode 100644 --- /dev/null +++ b/clang/test/AST/ast-print-pragma-transform-unroll.cpp @@ -0,0 +1,15 @@ +// RUN: %clang_cc1 -fexperimental-transform-pragma -ast-print %s -o - | FileCheck %s + +void unroll(int *List, int Length) { +// CHECK: #pragma clang transform unroll partial(4) +#pragma clang transform unroll partial(4) +// CHECK-NEXT: for (int i = 0; i < Length; i += 1) + for (int i = 0; i < Length; i += 1) + List[i] = i; + +// CHECK: #pragma clang transform unroll full +#pragma clang transform unroll full +// CHECK-NEXT: for (int i = 0; i < Length; i += 1) + for (int i = 0; i < Length; i += 1) + List[i] = i; +} diff --git a/clang/test/AST/ast-print-pragma-transform-unrollandjam.cpp b/clang/test/AST/ast-print-pragma-transform-unrollandjam.cpp new file mode 100644 --- /dev/null +++ b/clang/test/AST/ast-print-pragma-transform-unrollandjam.cpp @@ -0,0 +1,10 @@ +// RUN: %clang_cc1 
-fexperimental-transform-pragma -ast-print %s -o - | FileCheck %s + +void unroll(int *List, int Length) { +// CHECK: #pragma clang transform unrollandjam partial(4) +#pragma clang transform unrollandjam partial(4) +// CHECK-NEXT: for (int i = 0; i < Length; i += 1) + for (int i = 0; i < Length; i += 1) + for (int j = 0; j < Length; j += 1) + List[i] += j; +} diff --git a/clang/test/AST/ast-print-pragma-transform-vectorize.cpp b/clang/test/AST/ast-print-pragma-transform-vectorize.cpp new file mode 100644 --- /dev/null +++ b/clang/test/AST/ast-print-pragma-transform-vectorize.cpp @@ -0,0 +1,9 @@ +// RUN: %clang_cc1 -fexperimental-transform-pragma -ast-print %s -o - | FileCheck %s + +void unroll(int *List, int Length) { +// CHECK: #pragma clang transform vectorize width(4) +#pragma clang transform vectorize width(4) +// CHECK-NEXT: for (int i = 0; i < Length; i += 1) + for (int i = 0; i < Length; i += 1) + List[i] = i; +} diff --git a/clang/test/SemaCXX/pragma-transform-interleave.cpp b/clang/test/SemaCXX/pragma-transform-interleave.cpp new file mode 100644 --- /dev/null +++ b/clang/test/SemaCXX/pragma-transform-interleave.cpp @@ -0,0 +1,15 @@ +// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -fsyntax-only -verify %s + +void interleave(int *List, int Length, int Value) { + +/* expected-error@+1 {{the factor clause can be specified at most once}} */ +#pragma clang transform interleave factor(4) factor(4) + for (int i = 0; i < Length; i++) + List[i] = Value; + +/* expected-error@+1 {{clause argument must me at least 2}} */ +#pragma clang transform interleave factor(-42) + for (int i = 0; i < Length; i++) + List[i] = Value; + +} diff --git a/clang/test/SemaCXX/pragma-transform-legacymix.cpp b/clang/test/SemaCXX/pragma-transform-legacymix.cpp new file mode 100644 --- /dev/null +++ b/clang/test/SemaCXX/pragma-transform-legacymix.cpp @@ -0,0 +1,62 @@ +// RUN: %clang_cc1 -std=c++11 -fopenmp -fexperimental-transform-pragma -fsyntax-only -verify %s + +void 
legacymix(int *List, int Length, int Value) { + +/* expected-error@+2 {{expected a for, while, or do-while loop to follow '#pragma unroll'}} */ +#pragma unroll +#pragma clang transform unroll + for (int i = 0; i < 8; i++) + List[i] = Value; + + +/* expected-error@+2 {{expected a for, while, or do-while loop to follow '#pragma clang loop'}} */ +#pragma clang loop unroll(enable) +#pragma clang transform unroll + for (int i = 0; i < 8; i++) + List[i] = Value; + +/* expected-error@+2 {{expected loop after transformation pragma}} */ +#pragma clang transform unroll +#pragma clang loop unroll(enable) + for (int i = 0; i < 8; i++) + List[i] = Value; + +/* expected-error@+2 {{expected loop after transformation pragma}} */ +#pragma clang transform unroll +#pragma clang loop unroll(disable) + for (int i = 0; i < 8; i++) + List[i] = Value; + +/* expected-error@+2 {{expected a for, while, or do-while loop to follow '#pragma clang loop'}} */ +#pragma clang loop unroll(disable) +#pragma clang transform unroll + for (int i = 0; i < 8; i++) + List[i] = Value; + +/* expected-error@+1 {{Cannot combine #pragma clang transform with other transformations}} */ +#pragma clang transform unrollandjam partial(2) + for (int i = 0; i < 8; i++) +#pragma clang loop unroll(disable) + for (int j = 0; j < 16; j++) + List[i] = Value; + +/* expected-error@+1 {{Cannot combine #pragma clang transform with other transformations}} */ +#pragma clang transform unrollandjam + for (int i = 0; i < 8; i++) +#pragma omp simd + for (int j = 0; j < 16; j++) + List[i] = Value; + +/* expected-error@+2 {{expected loop after transformation pragma}} */ +#pragma clang transform unroll +#pragma omp simd + for (int i = 0; i < 8; i++) + List[i] = Value; + +/* expected-error@+2 {{statement after '#pragma omp simd' must be a for loop}} */ +#pragma omp simd +#pragma clang transform unroll + for (int i = 0; i < 8; i++) + List[i] = Value; + +} diff --git a/clang/test/SemaCXX/pragma-transform-unroll.cpp 
b/clang/test/SemaCXX/pragma-transform-unroll.cpp new file mode 100644 --- /dev/null +++ b/clang/test/SemaCXX/pragma-transform-unroll.cpp @@ -0,0 +1,37 @@ +// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -fsyntax-only -verify %s + +void unroll(int *List, int Length, int Value) { + +/* expected-error@+1 {{clause argument must me at least 2}} */ +#pragma clang transform unroll partial(0) + for (int i = 0; i < 8; i++) + List[i] = Value; + +/* expected-error@+1 {{the full clause can be specified at most once}} */ +#pragma clang transform unroll full full + for (int i = 0; i < 8; i++) + List[i] = Value; + +/* expected-error@+1 {{the partial clause can be specified at most once}} */ +#pragma clang transform unroll partial(4) partial(4) + for (int i = 0; i < 8; i++) + List[i] = Value; + +/* expected-error@+1 {{the full and partial clauses are mutually exclusive}} */ +#pragma clang transform unroll full partial(4) + for (int i = 0; i < 8; i++) + List[i] = Value; + +/* expected-error@+1 {{transformation did not find its loop to transform}} */ +#pragma clang transform unroll +#pragma clang transform unroll full + for (int i = 0; i < 8; i++) + List[i] = Value; + +int f = 4; +/* expected-error@+1 {{clause expects an integer argument}} */ +#pragma clang transform unroll partial(f) + for (int i = 0; i < Length; i+=1) + List[i] = i; + +} diff --git a/clang/test/SemaCXX/pragma-transform-unrollandjam.cpp b/clang/test/SemaCXX/pragma-transform-unrollandjam.cpp new file mode 100644 --- /dev/null +++ b/clang/test/SemaCXX/pragma-transform-unrollandjam.cpp @@ -0,0 +1,30 @@ +// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -fsyntax-only -verify %s + +void unrollandjam(int *List, int Length, int Value) { +/* expected-error@+1 {{the partial clause can be specified at most once}} */ +#pragma clang transform unrollandjam partial(4) partial(4) + for (int i = 0; i < Length; i++) + for (int j = 0; j < Length; j++) + List[i] += j*Value; + +/* expected-error@+1 
{{unroll-and-jam requires exactly one nested loop}} */ +#pragma clang transform unrollandjam + for (int i = 0; i < Length; i++) + List[i] = Value; + +/* expected-error@+1 {{unroll-and-jam requires exactly one nested loop}} */ +#pragma clang transform unrollandjam + for (int i = 0; i < Length; i++) { + for (int j = 0; j < Length; j++) + List[i] += j*Value; + for (int j = 0; j < Length; j++) + List[i] += j*Value; + } + +/* expected-error@+1 {{inner loop of unroll-and-jam is not the innermost}} */ +#pragma clang transform unrollandjam + for (int i = 0; i < Length; i++) + for (int j = 0; j < Length; j++) + for (int k = 0; k < Length; k++) + List[i] += k + j*Value; +} diff --git a/clang/test/SemaCXX/pragma-transform-vectorize.cpp b/clang/test/SemaCXX/pragma-transform-vectorize.cpp new file mode 100644 --- /dev/null +++ b/clang/test/SemaCXX/pragma-transform-vectorize.cpp @@ -0,0 +1,14 @@ +// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -fsyntax-only -verify %s + +void vectorize(int *List, int Length, int Value) { +/* expected-error@+1 {{the width clause can be specified at most once}} */ +#pragma clang transform vectorize width(4) width(4) + for (int i = 0; i < Length; i++) + List[i] = Value; + +/* expected-error@+1 {{clause argument must me at least 2}} */ +#pragma clang transform vectorize width(-42) + for (int i = 0; i < Length; i++) + List[i] = Value; + +} diff --git a/clang/test/SemaCXX/pragma-transform-wrongorder.cpp b/clang/test/SemaCXX/pragma-transform-wrongorder.cpp new file mode 100644 --- /dev/null +++ b/clang/test/SemaCXX/pragma-transform-wrongorder.cpp @@ -0,0 +1,18 @@ +// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -fsyntax-only -verify %s + +void wrongorder(int *List, int Length, int Value) { + +/* expected-warning@+1 {{the LLVM pass structure currently is not able to apply the transformations in this order}} */ +#pragma clang transform distribute +#pragma clang transform vectorize + for (int i = 0; i < 8; i++) + List[i] = 
Value; + +/* expected-warning@+1 {{the LLVM pass structure currently is not able to apply the transformations in this order}} */ +#pragma clang transform vectorize +#pragma clang transform unrollandjam + for (int i = 0; i < 8; i++) + for (int j = 0; j < 8; j++) + List[i] += j; + +} diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp --- a/clang/tools/libclang/CXCursor.cpp +++ b/clang/tools/libclang/CXCursor.cpp @@ -735,6 +735,9 @@ break; case Stmt::BuiltinBitCastExprClass: K = CXCursor_BuiltinBitCastExpr; + break; + case Stmt::TransformExecutableDirectiveClass: + llvm_unreachable("not implemented"); } CXCursor C = { K, 0, { Parent, S, TU } };