Index: include/clang/AST/ASTNodeTraverser.h =================================================================== --- include/clang/AST/ASTNodeTraverser.h +++ include/clang/AST/ASTNodeTraverser.h @@ -49,6 +49,8 @@ void Visit(const OMPClause *C); void Visit(const BlockDecl::Capture &C); void Visit(const GenericSelectionExpr::ConstAssociation &A); + void Visit(const TransformClause *C); + void Visit(const Transform *T); }; */ template @@ -205,6 +207,22 @@ }); } + void Visit(const TransformClause *C) { + getNodeDelegate().AddChild([=] { + getNodeDelegate().Visit(C); + for (const auto *S : C->children()) + Visit(S); + }); + } + + void Visit(const Transform *T) { + getNodeDelegate().AddChild([=] { + getNodeDelegate().Visit(T); + for (const auto *S : T->children()) + Visit(S); + }); + } + void Visit(const ast_type_traits::DynTypedNode &N) { // FIXME: Improve this with a switch or a visitor pattern. if (const auto *D = N.get()) @@ -221,6 +239,10 @@ Visit(C); else if (const auto *T = N.get()) Visit(*T); + else if (const auto *C = N.get()) + Visit(C); + else if (const auto *T = N.get()) + Visit(T); } void dumpDeclContext(const DeclContext *DC) { @@ -603,6 +625,12 @@ Visit(C); } + void + VisitTransformExecutableDirective(const TransformExecutableDirective *Node) { + for (const auto *C : Node->clauses()) + Visit(C); + } + void VisitInitListExpr(const InitListExpr *ILE) { if (auto *Filler = ILE->getArrayFiller()) { Visit(Filler, "array_filler"); Index: include/clang/AST/ASTTypeTraits.h =================================================================== --- include/clang/AST/ASTTypeTraits.h +++ include/clang/AST/ASTTypeTraits.h @@ -20,6 +20,7 @@ #include "clang/AST/NestedNameSpecifier.h" #include "clang/AST/OpenMPClause.h" #include "clang/AST/Stmt.h" +#include "clang/AST/StmtTransform.h" #include "clang/AST/TemplateBase.h" #include "clang/AST/TypeLoc.h" #include "clang/Basic/LLVM.h" @@ -71,6 +72,8 @@ static ASTNodeKind getFromNode(const Stmt &S); static ASTNodeKind 
getFromNode(const Type &T); static ASTNodeKind getFromNode(const OMPClause &C); + static ASTNodeKind getFromNode(const TransformClause &C); + static ASTNodeKind getFromNode(const Transform &T); /// \} /// Returns \c true if \c this and \c Other represent the same kind. @@ -152,6 +155,12 @@ NKI_OMPClause, #define OPENMP_CLAUSE(TextualSpelling, Class) NKI_##Class, #include "clang/Basic/OpenMPKinds.def" + NKI_TransformClause, +#define TRANSFORM_CLAUSE(Keyword, Name) NKI_##Name##Clause, +#include "clang/AST/TransformKinds.def" + NKI_Transform, +#define TRANSFORM_DIRECTIVE(Keyword, Name) NKI_##Name##Transform, +#include "clang/AST/TransformKinds.def" NKI_NumberOfKinds }; @@ -208,6 +217,10 @@ #include "clang/AST/TypeNodes.inc" #define OPENMP_CLAUSE(TextualSpelling, Class) KIND_TO_KIND_ID(Class) #include "clang/Basic/OpenMPKinds.def" +#define TRANSFORM_CLAUSE(Keyword, Name) KIND_TO_KIND_ID(Name##Clause) +#include "clang/AST/TransformKinds.def" +#define TRANSFORM_DIRECTIVE(Keyword, Name) KIND_TO_KIND_ID(Name##Transform) +#include "clang/AST/TransformKinds.def" #undef KIND_TO_KIND_ID inline raw_ostream &operator<<(raw_ostream &OS, ASTNodeKind K) { Index: include/clang/AST/JSONNodeDumper.h =================================================================== --- include/clang/AST/JSONNodeDumper.h +++ include/clang/AST/JSONNodeDumper.h @@ -198,6 +198,8 @@ void Visit(const OMPClause *C); void Visit(const BlockDecl::Capture &C); void Visit(const GenericSelectionExpr::ConstAssociation &A); + void Visit(const TransformClause *C); + void Visit(const Transform *T); void VisitTypedefType(const TypedefType *TT); void VisitFunctionType(const FunctionType *T); Index: include/clang/AST/RecursiveASTVisitor.h =================================================================== --- include/clang/AST/RecursiveASTVisitor.h +++ include/clang/AST/RecursiveASTVisitor.h @@ -33,6 +33,7 @@ #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenMP.h" +#include 
"clang/AST/StmtTransform.h" #include "clang/AST/TemplateBase.h" #include "clang/AST/TemplateName.h" #include "clang/AST/Type.h" @@ -536,6 +537,16 @@ bool VisitOMPClauseWithPreInit(OMPClauseWithPreInit *Node); bool VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *Node); + bool TraverseTransformClause(TransformClause *C); +#define TRANSFORM_CLAUSE(Keyword, Name) \ + bool Visit##Name##Clause(Name##Clause *C); +#include "clang/AST/TransformKinds.def" + + bool TraverseTransform(Transform *T); +#define TRANSFORM_DIRECTIVE(Keyword, Name) \ + bool Visit##Name##Transform(Name##Transform *T); +#include "clang/AST/TransformKinds.def" + bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue); bool PostVisitStmt(Stmt *S); }; @@ -2682,6 +2693,12 @@ // Traverse OpenCL: AsType, Convert. DEF_TRAVERSE_STMT(AsTypeExpr, {}) +DEF_TRAVERSE_STMT(TransformExecutableDirective, { + TRY_TO(TraverseTransform(S->getTransform())); + for (auto *C : S->clauses()) + TRY_TO(TraverseTransformClause(C)); +}) + // OpenMP directives.
template bool RecursiveASTVisitor::TraverseOMPExecutableDirective( @@ -2847,6 +2864,38 @@ DEF_TRAVERSE_STMT(OMPTargetTeamsDistributeSimdDirective, { TRY_TO(TraverseOMPExecutableDirective(S)); }) +template +bool RecursiveASTVisitor::TraverseTransform(Transform *T) { + if (!T) + return true; + switch (T->getKind()) { + case Transform::Kind::UnknownKind: + llvm_unreachable("Cannot process unknown transformation"); +#define TRANSFORM_DIRECTIVE(Keyword, Name) \ + case Transform::Kind::Name##Kind: \ + TRY_TO(Visit##Name##Transform(static_cast(T))); \ + break; +#include "clang/AST/TransformKinds.def" + } + return true; +} + +template +bool RecursiveASTVisitor::TraverseTransformClause(TransformClause *C) { + if (!C) + return true; + switch (C->getKind()) { + case TransformClause::Kind::UnknownKind: + llvm_unreachable("Cannot process unknown clause"); +#define TRANSFORM_CLAUSE(Keyword, Name) \ + case TransformClause::Kind::Name##Kind: \ + TRY_TO(Visit##Name##Clause(static_cast(C))); \ + break; +#include "clang/AST/TransformKinds.def" + } + return true; +} + // OpenMP clauses. template bool RecursiveASTVisitor::TraverseOMPClause(OMPClause *C) { @@ -3344,6 +3393,37 @@ return true; } +template +bool RecursiveASTVisitor::VisitFullClause(FullClause *C) { + return true; +} + +template +bool RecursiveASTVisitor::VisitPartialClause(PartialClause *C) { + TRY_TO(TraverseStmt(C->getFactor())); + return true; +} + +template +bool RecursiveASTVisitor::VisitWidthClause(WidthClause *C) { + TRY_TO(TraverseStmt(C->getWidth())); + return true; +} + +template +bool RecursiveASTVisitor::VisitFactorClause(FactorClause *C) { + TRY_TO(TraverseStmt(C->getFactor())); + return true; +} + +#define TRANSFORM_DIRECTIVE(Keyword, Name) \ + template \ + bool RecursiveASTVisitor::Visit##Name##Transform( \ + Name##Transform *T) { \ + return true; \ + } +#include "clang/AST/TransformKinds.def" + // FIXME: look at the following tricky-seeming exprs to see if we // need to recurse on anything.
These are ones that have methods // returning decls or qualtypes or nestednamespecifier -- though I'm Index: include/clang/AST/StmtTransform.h =================================================================== --- include/clang/AST/StmtTransform.h +++ include/clang/AST/StmtTransform.h @@ -13,6 +13,7 @@ #ifndef LLVM_CLANG_AST_STMTTRANSFROM_H #define LLVM_CLANG_AST_STMTTRANSFROM_H +#include "clang/AST/Expr.h" #include "clang/AST/Stmt.h" #include "clang/AST/Transform.h" #include "llvm/Support/raw_ostream.h" @@ -33,7 +34,171 @@ TransformClause::Kind ClauseKind); static Kind getClauseKind(Transform::Kind TransformKind, llvm::StringRef Str); - // TODO: implement +private: + Kind ClauseKind; + SourceRange Loc; + +protected: + TransformClause(Kind K, SourceRange Loc) : ClauseKind(K), Loc(Loc) {} + explicit TransformClause(Kind K) : ClauseKind(K) {} + +public: + Kind getKind() const { return ClauseKind; } + + SourceRange getLoc() const { return Loc; } + SourceLocation getBeginLoc() const { return Loc.getBegin(); } + SourceLocation getEndLoc() const { return Loc.getEnd(); } + + void setLoc(SourceLocation BeginLoc, SourceLocation EndLoc) { + Loc = SourceRange(BeginLoc, EndLoc); + } + void setLoc(SourceRange L) { Loc = L; } + + using child_iterator = Stmt::child_iterator; + using const_child_iterator = Stmt::const_child_iterator; + using child_range = Stmt::child_range; + using const_child_range = Stmt::const_child_range; + + child_range children(); + const_child_range children() const { + auto Children = const_cast(this)->children(); + return const_child_range(Children.begin(), Children.end()); + } + + static llvm::StringRef getClauseName(Kind K); + + void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const; +}; + +class TransformClauseImpl : public TransformClause { +protected: + TransformClauseImpl(Kind ClauseKind, SourceRange Range) + : TransformClause(ClauseKind, Range) {} + explicit TransformClauseImpl(Kind ClauseKind) : TransformClause(ClauseKind) {} + +public: + 
void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const { + llvm_unreachable("implement the print function"); + } + + child_range children() { + // By default, clauses have no children. + return child_range(child_iterator(), child_iterator()); + } +}; + +class FullClause final : public TransformClauseImpl { +private: + explicit FullClause(SourceRange Range) + : TransformClauseImpl(TransformClause::FullKind, Range) {} + FullClause() : TransformClauseImpl(TransformClause::FullKind) {} + +public: + static bool classof(const FullClause *T) { return true; } + static bool classof(const TransformClause *T) { + return T->getKind() == FullKind; + } + + static FullClause *create(ASTContext &Context, SourceRange Range) { + return new (Context) FullClause(Range); + } + static FullClause *createEmpty(ASTContext &Context) { + return new (Context) FullClause(); + } + + void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const; +}; + +class PartialClause final : public TransformClauseImpl { +private: + Stmt *Factor = nullptr; // Null for createEmpty() clauses until setFactor(). + + PartialClause(SourceRange Range, Expr *Factor) + : TransformClauseImpl(TransformClause::PartialKind, Range), + Factor(Factor) {} + PartialClause() : TransformClauseImpl(TransformClause::PartialKind) {} + +public: + static bool classof(const PartialClause *T) { return true; } + static bool classof(const TransformClause *T) { + return T->getKind() == PartialKind; + } + + static PartialClause *create(ASTContext &Context, SourceRange Range, + Expr *Factor) { + return new (Context) PartialClause(Range, Factor); + } + static PartialClause *createEmpty(ASTContext &Context) { + return new (Context) PartialClause(); + } + + child_range children() { return child_range(&Factor, &Factor + 1); } + + Expr *getFactor() const { return cast(Factor); } + void setFactor(Expr *E) { Factor = E; } + + void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const; +}; + +class WidthClause final : public 
TransformClauseImpl { +private: + Stmt *Width = nullptr; // Null for createEmpty() clauses until setWidth(). + 
WidthClause(SourceRange Range, Expr *Width) + : TransformClauseImpl(TransformClause::WidthKind, Range), Width(Width) {} + WidthClause() : TransformClauseImpl(TransformClause::WidthKind) {} + +public: + static bool classof(const WidthClause *T) { return true; } + static bool classof(const TransformClause *T) { + return T->getKind() == WidthKind; + } + + static WidthClause *create(ASTContext &Context, SourceRange Range, + Expr *Width) { + return new (Context) WidthClause(Range, Width); + } + static WidthClause *createEmpty(ASTContext &Context) { + return new (Context) WidthClause(); + } + + child_range children() { return child_range(&Width, &Width + 1); } + + Expr *getWidth() const { return cast(Width); } + void setWidth(Expr *E) { Width = E; } + + void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const; +}; + +class FactorClause final : public TransformClauseImpl { +private: + Stmt *Factor = nullptr; // Null for createEmpty() clauses until setFactor(). + + FactorClause(SourceRange Range, Expr *Factor) + : TransformClauseImpl(TransformClause::FactorKind, Range), + Factor(Factor) {} + FactorClause() : TransformClauseImpl(TransformClause::FactorKind) {} + +public: + static bool classof(const FactorClause *T) { return true; } + static bool classof(const TransformClause *T) { + return T->getKind() == FactorKind; + } + + static FactorClause *create(ASTContext &Context, SourceRange Range, + Expr *Factor) { + return new (Context) FactorClause(Range, Factor); + } + static FactorClause *createEmpty(ASTContext &Context) { + return new (Context) FactorClause(); + } + + child_range children() { return child_range(&Factor, &Factor + 1); } + + Expr *getFactor() const { return cast(Factor); } + void setFactor(Expr *E) { Factor = E; } + + void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const; }; /// Represents @@ -41,8 +206,108 @@ /// #pragma clang transform /// /// in the AST.
-class TransformExecutableDirective final { - // TODO: implement +class TransformExecutableDirective final + : public Stmt, + private llvm::TrailingObjects { +public: + friend TransformClause; + friend TrailingObjects; + +private: + SourceRange Range; + Stmt *Associated = nullptr; + Transform *Trans = nullptr; + Transform::Kind TransKind = Transform::Kind::UnknownKind; + unsigned NumClauses; + +protected: + explicit TransformExecutableDirective(SourceRange Range, Stmt *Associated, + Transform *Trans, + ArrayRef Clauses, + Transform::Kind TransKind) + : Stmt(Stmt::TransformExecutableDirectiveClass), Range(Range), + Associated(Associated), Trans(Trans), TransKind(TransKind), + NumClauses(Clauses.size()) { + setClauses(Clauses); + } + explicit TransformExecutableDirective(unsigned NumClauses) + : Stmt(Stmt::StmtClass::TransformExecutableDirectiveClass), + NumClauses(NumClauses) {} + + size_t numTrailingObjects(OverloadToken) const { + return NumClauses; + } + +public: + static bool classof(const TransformExecutableDirective *T) { return true; } + static bool classof(const Stmt *T) { + return T->getStmtClass() == Stmt::TransformExecutableDirectiveClass; + } + + static TransformExecutableDirective * + create(ASTContext &Ctx, SourceRange Range, Stmt *Associated, Transform *Trans, + ArrayRef Clauses, Transform::Kind TransKind); + static TransformExecutableDirective *createEmpty(ASTContext &Ctx, + unsigned NumClauses); + + SourceRange getLoc() const { return Range; } + SourceLocation getBeginLoc() const { return Range.getBegin(); } + SourceLocation getEndLoc() const { return Range.getEnd(); } + void setLoc(SourceRange Loc) { Range = Loc; } + void setLoc(SourceLocation BeginLoc, SourceLocation EndLoc) { + Range = SourceRange(BeginLoc, EndLoc); + } + + Stmt *getAssociated() const { return Associated; } + void setAssociated(Stmt *S) { Associated = S; } + + Transform *getTransform() const { return Trans; } + Transform::Kind getTransformKind() const { return TransKind; } + 
child_range children() { return child_range(&Associated, &Associated + 1); } + const_child_range children() const { + return const_child_range(&Associated, &Associated + 1); + } + + unsigned getNumClauses() const { return NumClauses; } + MutableArrayRef clauses() { + return llvm::makeMutableArrayRef(getTrailingObjects(), + NumClauses); + } + ArrayRef clauses() const { + return llvm::makeArrayRef(getTrailingObjects(), + NumClauses); + } + + void setClauses(llvm::ArrayRef List) { + assert(List.size() == NumClauses); + for (auto p : llvm::zip_first(List, clauses())) + std::get<1>(p) = std::get<0>(p); + } + + auto getClausesOfKind(TransformClause::Kind Kind) const { + return llvm::make_filter_range(clauses(), [Kind](TransformClause *Clause) { + return Clause->getKind() == Kind; + }); + } + + template auto getClausesOfKind() const { + return llvm::map_range( + llvm::make_filter_range(clauses(), + [](TransformClause *Clause) { + return isa(Clause); + }), + [](TransformClause *Clause) { return cast(Clause); }); + } + + template + SpecificClause *getFirstClauseOfKind() const { + auto Clauses = getClausesOfKind(); + if (Clauses.begin() == Clauses.end()) + return nullptr; + return *Clauses.begin(); + } }; const Stmt *getAssociatedLoop(const Stmt *S); Index: include/clang/AST/TextNodeDumper.h =================================================================== --- include/clang/AST/TextNodeDumper.h +++ include/clang/AST/TextNodeDumper.h @@ -172,6 +172,12 @@ void Visit(const OMPClause *C); + void VisitTransformExecutableDirective(const TransformExecutableDirective *S); + + void Visit(const TransformClause *C); + + void Visit(const Transform *T); + void Visit(const BlockDecl::Capture &C); void Visit(const GenericSelectionExpr::ConstAssociation &A); Index: include/clang/AST/Transform.h =================================================================== --- include/clang/AST/Transform.h +++ include/clang/AST/Transform.h @@ -14,6 +14,7 @@ #ifndef LLVM_CLANG_AST_TRANSFORM_H #define 
LLVM_CLANG_AST_TRANSFORM_H +#include "clang/AST/Stmt.h" #include "llvm/ADT/StringRef.h" namespace clang { @@ -33,7 +34,414 @@ static llvm::StringRef getTransformDirectiveKeyword(Kind K); static llvm::StringRef getTransformDirectiveName(Kind K); - // TODO: implement +private: + Kind TransformKind; + SourceRange Loc; + bool IsLegacy; + +public: + Transform(Kind K, SourceRange Loc, bool IsLegacy) + : TransformKind(K), Loc(Loc), IsLegacy(IsLegacy) {} + + Kind getKind() const { return TransformKind; } + static bool classof(const Transform *Trans) { return true; } + + /// Source location of the code transformation directive. + /// @{ + SourceRange getLoc() const { return Loc; } + SourceLocation getBeginLoc() const { return Loc.getBegin(); } + SourceLocation getEndLoc() const { return Loc.getEnd(); } + + void setLoc(SourceLocation BeginLoc, SourceLocation EndLoc) { + Loc = SourceRange(BeginLoc, EndLoc); + } + void setLoc(SourceRange L) { Loc = L; } + /// @} + + /// Non-legacy directives originate from a #pragma clang transform directive. + /// Legacy transformations are introduced by #pragma clang loop, + /// #pragma omp simd, #pragma unroll, etc. + /// Differences include: + /// * The execution order of legacy transformations is defined by the + /// compiler. + /// * Some warnings that clang historically did not emit are disabled. + /// * Some differences in the emitted loop metadata for compatibility. + bool isLegacy() const { return IsLegacy; } + + using child_iterator = Stmt::child_iterator; + using const_child_iterator = Stmt::const_child_iterator; + using child_range = Stmt::child_range; + using const_child_range = Stmt::const_child_range; + + child_range children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range children() const { + return const_child_range(child_iterator(), child_iterator()); + } + + /// Each transformation defines how many loops it consumes and generates.
+ /// Users of this class can store arrays holding the information regarding the + /// loops, such as a pointer to the AST node or the loop name. The index in this + /// array is its "role". + /// @{ + int getNumInputs() const; + int getNumFollowups() const; + /// @} + + /// The "all" follow-up role is a meta output whose attributes are added to + /// all generated loops. + bool isAllRole(int R) const { return R == 0; } + + /// Used to warn users that the current LLVM pass pipeline cannot apply + /// arbitrary transformation orders yet. + int getLoopPipelineStage() const; +}; + +/// Default implementation of compile-time inherited methods to avoid infinite +/// recursion. +class TransformImpl : public Transform { +public: + TransformImpl(Kind K, SourceRange Loc, bool IsLegacy) + : Transform(K, Loc, IsLegacy) {} + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 1; } +}; + +class LoopDistributionTransform final : public TransformImpl { +private: + LoopDistributionTransform(SourceRange Loc, bool IsLegacy) + : TransformImpl(LoopDistributionKind, Loc, IsLegacy) {} + +public: + static bool classof(const LoopDistributionTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopDistributionKind; + } + + static LoopDistributionTransform *create(SourceRange Loc, bool IsLegacy) { + return new LoopDistributionTransform(Loc, IsLegacy); + } + + enum Input { + InputToDistribute, + }; + enum Followup { + FollowupAll, + }; +}; + +class LoopVectorizationTransform final : public TransformImpl { +private: + llvm::Optional EnableVectorization; + int VectorizeWidth; + llvm::Optional VectorizePredicateEnable; + + LoopVectorizationTransform(SourceRange Loc, + llvm::Optional EnableVectorization, + int VectorizeWidth, + llvm::Optional VectorizePredicateEnable) + : TransformImpl(LoopVectorizationKind, Loc, true), + EnableVectorization(EnableVectorization), + VectorizeWidth(VectorizeWidth), + 
VectorizePredicateEnable(VectorizePredicateEnable) {} + +public: + static bool classof(const LoopVectorizationTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopVectorizationKind; + } + + static LoopVectorizationTransform * + Create(SourceRange Loc, llvm::Optional EnableVectorization, + int VectorizeWidth, llvm::Optional VectorizePredicateEnable) { + assert(EnableVectorization.getValueOr(true)); + return new LoopVectorizationTransform( + Loc, EnableVectorization, VectorizeWidth, VectorizePredicateEnable); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 3; } + enum Input { + InputToVectorize, + }; + enum Followup { FollowupAll, FollowupVectorized, FollowupEpilogue }; + + llvm::Optional isVectorizationEnabled() const { + return EnableVectorization; + } + + int getWidth() const { return VectorizeWidth; } + + llvm::Optional isPredicateEnabled() const { + return VectorizePredicateEnable; + } +}; + +class LoopInterleavingTransform final : public TransformImpl { +private: + llvm::Optional EnableInterleaving; + int InterleaveCount; + + LoopInterleavingTransform(SourceRange Loc, + llvm::Optional EnableInterleaving, + int InterleaveCount) + : TransformImpl(LoopInterleavingKind, Loc, false), + EnableInterleaving(EnableInterleaving), + InterleaveCount(InterleaveCount) {} + +public: + static bool classof(const LoopInterleavingTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopInterleavingKind; + } + + static LoopInterleavingTransform * + Create(SourceRange Loc, llvm::Optional EnableInterleaving, + int InterleaveCount) { + assert(EnableInterleaving.getValueOr(true)); + return new LoopInterleavingTransform(Loc, EnableInterleaving, + InterleaveCount); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 3; } + enum Input { + InputToVectorize, + }; + enum Followup { FollowupAll, 
FollowupVectorized, FollowupEpilogue }; + + llvm::Optional isInterleavingEnabled() const { + return EnableInterleaving; + } + + int getInterleaveCount() const { return InterleaveCount; } +}; + +class LoopVectorizationInterleavingTransform final : public TransformImpl { +private: + bool AssumeSafety; + llvm::Optional EnableVectorization; + llvm::Optional EnableInterleaving; + int VectorizeWidth; + llvm::Optional VectorizePredicateEnable; + int InterleaveCount; + + LoopVectorizationInterleavingTransform( + SourceRange Loc, bool IsLegacy, bool AssumeSafety, + llvm::Optional EnableVectorization, + llvm::Optional EnableInterleaving, int VectorizeWidth, + llvm::Optional VectorizePredicateEnable, int InterleaveCount) + : TransformImpl(LoopVectorizationInterleavingKind, Loc, IsLegacy), + AssumeSafety(AssumeSafety), EnableVectorization(EnableVectorization), + EnableInterleaving(EnableInterleaving), VectorizeWidth(VectorizeWidth), + VectorizePredicateEnable(VectorizePredicateEnable), + InterleaveCount(InterleaveCount) {} + +public: + static bool classof(const LoopVectorizationInterleavingTransform *Trans) { + return true; + } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopVectorizationInterleavingKind; + } + + static LoopVectorizationInterleavingTransform * + create(SourceRange Loc, bool Legacy, bool AssumeSafety, + llvm::Optional EnableVectorization, + llvm::Optional EnableInterleaving, int VectorizeWidth, + llvm::Optional VectorizePredicateEnable, int InterleaveCount) { + assert(EnableVectorization.getValueOr(true) || + EnableInterleaving.getValueOr(true)); + return new LoopVectorizationInterleavingTransform( + Loc, Legacy, AssumeSafety, EnableVectorization, EnableInterleaving, + VectorizeWidth, VectorizePredicateEnable, InterleaveCount); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 3; } + enum Input { + InputToVectorize, + }; + enum Followup { FollowupAll, FollowupVectorized, FollowupEpilogue }; + 
bool isAssumeSafety() const { return AssumeSafety; } + llvm::Optional isVectorizationEnabled() const { + return EnableVectorization; + } + llvm::Optional isInterleavingEnabled() const { + return EnableInterleaving; + } + + int getWidth() const { return VectorizeWidth; } + llvm::Optional isPredicateEnabled() const { + return VectorizePredicateEnable; + } + int getInterleaveCount() const { return InterleaveCount; } +}; + +class LoopUnrollingTransform final : public TransformImpl { +private: + bool ImplicitEnable; + bool ExplicitEnable; + int Factor; // 0 from create(), -1 from createFull() (see isFull()), else the createPartial() factor. + + LoopUnrollingTransform(SourceRange Loc, bool IsLegacy, bool ImplicitEnable, + bool ExplicitEnable, int Factor) + : TransformImpl(LoopUnrollingKind, Loc, IsLegacy), + ImplicitEnable(ImplicitEnable), ExplicitEnable(ExplicitEnable), + Factor(Factor) {} + +public: + static bool classof(const LoopUnrollingTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopUnrollingKind; + } + + static LoopUnrollingTransform *create(SourceRange Loc, bool Legacy, + bool ImplicitEnable, + bool ExplicitEnable) { + return new LoopUnrollingTransform(Loc, Legacy, ImplicitEnable, + ExplicitEnable, 0); + } + static LoopUnrollingTransform *createFull(SourceRange Loc, bool Legacy, + bool ImplicitEnable, + bool ExplicitEnable) { + return new LoopUnrollingTransform(Loc, Legacy, ImplicitEnable, + ExplicitEnable, -1); + } + static LoopUnrollingTransform *createPartial(SourceRange Loc, bool Legacy, + bool ImplicitEnable, + bool ExplicitEnable, + int Factor) { + return new LoopUnrollingTransform(Loc, Legacy, ImplicitEnable, + ExplicitEnable, Factor); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 3; } + enum Input { + InputToUnroll, + }; + enum Followup { + FollowupAll, // if not full + FollowupUnrolled, // if not full + FollowupRemainder + }; + + bool 
hasFollowupRole(int i) const { return isFull(); } + int getDefaultSuccessor() const { + return 
isLegacy() ? FollowupAll : FollowupUnrolled; + } + + bool isImplicitEnable() const { return ImplicitEnable; } + + bool isExplicitEnable() const { return ExplicitEnable; } + + int getFactor() const { return Factor; } + + bool isFull() const { return Factor == -1; } +}; + +class LoopUnrollAndJamTransform final : public TransformImpl { +private: + bool ExplicitEnable; + int Factor; + + LoopUnrollAndJamTransform(SourceRange Loc, bool IsLegacy, bool ExplicitEnable, + int Factor) + : TransformImpl(LoopUnrollAndJamKind, Loc, IsLegacy), + ExplicitEnable(ExplicitEnable), Factor(Factor) {} + +public: + static bool classof(const LoopUnrollAndJamTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopUnrollAndJamKind; + } + + static LoopUnrollAndJamTransform *create(SourceRange Loc, bool Legacy, + bool ExplicitEnable) { + return new LoopUnrollAndJamTransform(Loc, Legacy, ExplicitEnable, 0); + } + static LoopUnrollAndJamTransform *createFull(SourceRange Loc, bool Legacy, + bool ExplicitEnable) { + return new LoopUnrollAndJamTransform(Loc, Legacy, ExplicitEnable, -1); + } + static LoopUnrollAndJamTransform * + createPartial(SourceRange Loc, bool Legacy, bool ExplicitEnable, int Factor) { + return new LoopUnrollAndJamTransform(Loc, Legacy, ExplicitEnable, Factor); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 3; } + enum Input { + InputOuter, + InputInner, + }; + enum Followup { FollowupAll, FollowupOuter, FollowupInner }; + + bool isExplicitEnable() const { return ExplicitEnable; } + + int getFactor() const { + assert(Factor >= 0); + return Factor; + } + + bool isFull() const { return Factor == -1; } +}; + +class LoopPipeliningTransform final : public TransformImpl { +private: + int InitiationInterval; + + LoopPipeliningTransform(SourceRange Loc, int InitiationInterval) + : TransformImpl(LoopPipeliningKind, Loc, true), + InitiationInterval(InitiationInterval) {} + +public: + 
static bool classof(const LoopPipeliningTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopPipeliningKind; + } + + static LoopPipeliningTransform *create(SourceRange Loc, + int InitiationInterval) { + return new LoopPipeliningTransform(Loc, InitiationInterval); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 0; } + enum Input { + InputToPipeline, + }; + + int getInitiationInterval() const { return InitiationInterval; } +}; + +class LoopAssumeParallelTransform final : public TransformImpl { +private: + explicit LoopAssumeParallelTransform(SourceRange Loc) + : TransformImpl(LoopAssumeParallelKind, Loc, false) {} + +public: + static bool classof(const LoopAssumeParallelTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopAssumeParallelKind; + } + + static LoopAssumeParallelTransform *create(SourceRange Loc) { + return new LoopAssumeParallelTransform(Loc); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 1; } + enum Input { + InputParallel, + }; + enum Followup { + FollowupParallel, + }; }; } // namespace clang Index: include/clang/Analysis/AnalysisTransform.h =================================================================== --- /dev/null +++ include/clang/Analysis/AnalysisTransform.h @@ -0,0 +1,1235 @@ +//===--- AnalysisTransform.h -----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Applies code transformations +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_ANALYSIS_ANALYSISTRANSFORM_H +#define LLVM_CLANG_ANALYSIS_ANALYSISTRANSFORM_H + +#include "clang/AST/OpenMPClause.h" +#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/AST/Stmt.h" +#include "clang/Basic/DiagnosticSema.h" +#include "llvm/ADT/SmallVector.h" + +namespace clang { +template class TransformedTreeBuilder; + +/// Represents an input of a code representation. +/// Current can reference the input code only by the AST node, but in the future +/// loops can also given identifiers to reference them. +class TransformInput { + const Stmt *StmtInput = nullptr; + + TransformInput(const Stmt *StmtInput) : StmtInput(StmtInput) {} + +public: + TransformInput() {} + + static TransformInput createByStmt(const Stmt *StmtInput) { + assert(StmtInput); + return TransformInput(StmtInput); + } + + bool isByStmt() const { return StmtInput; } + + const Stmt *getStmtInput() const { return StmtInput; } +}; + +/// Represents a transformation together with the input loops. +/// in the future it will also identify the generated loop. +struct NodeTransform { + Transform *Trans = nullptr; + llvm::SmallVector Inputs; + + NodeTransform() {} + NodeTransform(Transform *Trans, TransformInput TopLevelInput) : Trans(Trans) { + assert(Trans->getNumInputs() >= 1); + + Inputs.resize(Trans->getNumInputs()); + setInput(0, TopLevelInput); + }; + + void setInput(int Idx, TransformInput Input) { Inputs[Idx] = Input; } +}; + +/// This class represents a loop in a loop nest to which transformations are +/// applied to. 
It is intended to be instantiated for it specific purpose, that +/// is SemaTransformedTree for semantic analysis (consistency warnings and +/// errors) and CGTransformedTree for emitting IR. +template class TransformedTree { + template friend class TransformedTreeBuilder; + using NodeTy = Derived; + + Derived &getDerived() { return *static_cast(this); } + const Derived &getDerived() const { + return *static_cast(this); + } + +protected: + /// Is this the root node of the loop hierarchy? + bool IsRoot = false; + + /// Nested loops. + llvm::SmallVector Subloops; + + /// Origin of this loop. + /// @{ + + /// If not the result of the transformation, this is the loop statement that + /// this node represents. + Stmt *Original; + + Derived *BasedOn; + int FollowupRole; + int Stage; + + /// @} + + /// Things applied to this loop. + /// @{ + + /// If this is the primary transformation input. + Transform *TransformedBy = nullptr; + + // Primary/secondary transformation input + Derived *PrimaryInput = nullptr; + + llvm::SmallVector Followups; + Derived *Successor = nullptr; + int InputRole = -1; + + bool IsParallel = false; + /// @} + +protected: + TransformedTree(llvm::ArrayRef SubLoops, Derived *BasedOn, + clang::Stmt *Original, int FollowupRole, int Stage) + : Subloops(SubLoops.begin(), SubLoops.end()), Original(Original), + BasedOn(BasedOn), FollowupRole(FollowupRole), Stage(Stage) {} + +public: + ArrayRef getSubLoops() const { return Subloops; } + + Derived *getPrimaryInput() const { return PrimaryInput; } + Transform *getTransformedBy() const { return TransformedBy; } + + Stmt *getOriginal() const { return Original; } + Stmt *getInheritedOriginal() const { + if (Original) + return Original; + if (BasedOn) + return BasedOn->getInheritedOriginal(); + return nullptr; + } + + Derived *getBasedOn() const { return BasedOn; } + + void markParallel() { IsParallel = true; } + bool isParallel() const { return IsParallel; } + + bool isRoot() const { return IsRoot; } + + int 
getStage() const { return Stage; } + + Derived *getSuccessor() { return Successor; } + + Derived *getLatestSuccessor() { + // If the loop is not being consumed, this is the latest successor. + if (!isTransformationInput()) + return &getDerived(); + // It is possible for a loop consumed into a non-loop, such that it has no + // successor. + if (!Successor) + return nullptr; + return Successor->getLatestSuccessor(); + } + + const Derived *getLatestSuccessor() const { + return const_cast(this)->getLatestSuccessor(); + } + + bool isOriginal() const { return Original; } + + bool isTransformationInput() const { + bool Result = InputRole >= 0; + assert(Result == (PrimaryInput != nullptr)); + return Result; + } + + bool isTransformationFollowup() const { + bool Result = FollowupRole >= 0; + assert(Result == (BasedOn != nullptr)); + return Result; + } + + bool isPrimaryInput() const { + bool Result = (InputRole == 0); + assert(Result == (PrimaryInput == this)); + return PrimaryInput == this; + } + + void applyTransformation(Transform *Trans, + llvm::ArrayRef Followups, + Derived *Successor) { + assert(!isTransformationInput()); + + this->TransformedBy = Trans; + this->Followups.insert(this->Followups.end(), Followups.begin(), + Followups.end()); + this->Successor = Successor; + this->PrimaryInput = &getDerived(); + this->InputRole = 0; // for primary + +#ifndef NDEBUG + assert(isTransformationInput() && isPrimaryInput()); + for (NodeTy *S : Followups) { + assert(S->BasedOn == &getDerived()); + } +#endif + } + + void applySuccessors(Derived *PrimaryInput, int InputRole, + llvm::ArrayRef Followups, + Derived *Successor) { + assert(!isTransformationInput()); + assert(InputRole > 0); + + this->PrimaryInput = PrimaryInput; + this->Followups.insert(this->Followups.end(), Followups.begin(), + Followups.end()); + this->Successor = Successor; + this->InputRole = InputRole; + +#ifndef NDEBUG + assert(isTransformationInput() && !isPrimaryInput()); + for (NodeTy *S : Followups) { + 
assert(S->BasedOn == &getDerived()); + } +#endif + } +}; + +/// Constructs a loop nest from source and applies transformations on it. +template class TransformedTreeBuilder { + using BuilderTy = Derived; + + Derived &getDerived() { return *static_cast(this); } + const Derived &getDerived() const { + return *static_cast(this); + } + + ASTContext &ASTCtx; + llvm::SmallVectorImpl &AllNodes; + +private: + /// Build the original loop nest hierarchy from the AST. + void buildPhysicalLoopTree(Stmt *S, SmallVectorImpl &SubLoops, + llvm::DenseMap &StmtToTree) { + if (!S) + return; + + Stmt *Body; + switch (S->getStmtClass()) { + case Stmt::ForStmtClass: + Body = cast(S)->getBody(); + break; + case Stmt::WhileStmtClass: + Body = cast(S)->getBody(); + break; + case Stmt::DoStmtClass: + Body = cast(S)->getBody(); + break; + case Stmt::CXXForRangeStmtClass: + Body = cast(S)->getBody(); + break; + case Stmt::CapturedStmtClass: + buildPhysicalLoopTree(cast(S)->getCapturedStmt(), SubLoops, + StmtToTree); + return; + case Expr::LambdaExprClass: + // Call to getBody materializes its body, children() (which is called in + // the default case) does not. + buildPhysicalLoopTree(cast(S)->getBody(), SubLoops, + StmtToTree); + return; + case Expr::BlockExprClass: + buildPhysicalLoopTree(cast(S)->getBody(), SubLoops, + StmtToTree); + return; + default: + if (auto *O = dyn_cast(S)) { + if (!O->hasAssociatedStmt()) + return; + Stmt *Associated = O->getAssociatedStmt(); + buildPhysicalLoopTree(Associated, SubLoops, StmtToTree); + return; + } + + for (Stmt *Child : S->children()) + buildPhysicalLoopTree(Child, SubLoops, StmtToTree); + + return; + } + + SmallVector SubSubLoops; + buildPhysicalLoopTree(Body, SubSubLoops, StmtToTree); + + NodeTy *L = getDerived().createPhysical(SubSubLoops, S); + SubLoops.push_back(L); + assert(StmtToTree.count(S) == 0); + StmtToTree[S] = L; + + getDerived().applyOriginal(L); + } + + /// Collect all loop transformations in the function's AST. 
+ class CollectTransformationsVisitor + : public RecursiveASTVisitor { + + Derived &Builder; + llvm::DenseMap &StmtToTree; + + public: + CollectTransformationsVisitor(Derived &Builder, + llvm::DenseMap &StmtToTree) + : Builder(Builder), StmtToTree(StmtToTree) {} + + /// Transformations collected so far. + llvm::SmallVector Transforms; + + bool shouldTraversePostOrder() const { return true; } + + /// Read and apply LoopHint (#pragma clang loop) attributes. + void applyAttributed(const Stmt *S, ArrayRef Attrs, + NodeTy *L) { + + /// State of loop vectorization or unrolling. + enum LVEnableState { Unspecified, Enable, Disable, Full }; + + /// Value for llvm.loop.vectorize.enable metadata. + LVEnableState VectorizeEnable = Unspecified; + bool VectorizeAssumeSafety = false; + + /// Value for llvm.loop.unroll.* metadata (enable, disable, or full). + LVEnableState UnrollEnable = Unspecified; + + /// Value for llvm.loop.unroll_and_jam.* metadata (enable, disable, or + /// full). + LVEnableState UnrollAndJamEnable = Unspecified; + + /// Value for llvm.loop.vectorize.predicate metadata + LVEnableState VectorizePredicateEnable = Unspecified; + + /// Value for llvm.loop.vectorize.width metadata. + unsigned VectorizeWidth = 0; + + /// Value for llvm.loop.interleave.count metadata. + unsigned InterleaveCount = 0; + + /// llvm.unroll. + unsigned UnrollCount = 0; + + /// llvm.unroll. + unsigned UnrollAndJamCount = 0; + + /// Value for llvm.loop.distribute.enable metadata. + LVEnableState DistributeEnable = Unspecified; + + /// Value for llvm.loop.pipeline.disable metadata. + bool PipelineDisabled = false; + + /// Value for llvm.loop.pipeline.iicount metadata. 
+ unsigned PipelineInitiationInterval = 0; + + SourceLocation DistLoc; + SourceLocation VecLoc; + SourceLocation UnrollLoc; + SourceLocation UnrollAndJamLoc; + SourceLocation PipelineLoc; + for (const Attr *Attr : Attrs) { + const LoopHintAttr *LH = dyn_cast(Attr); + const OpenCLUnrollHintAttr *OpenCLHint = + dyn_cast(Attr); + + // Skip non loop hint attributes + if (!LH && !OpenCLHint) { + continue; + } + + LoopHintAttr::OptionType Option = LoopHintAttr::Unroll; + LoopHintAttr::LoopHintState State = LoopHintAttr::Disable; + unsigned ValueInt = 1; + // Translate opencl_unroll_hint attribute argument to + // equivalent LoopHintAttr enums. + // OpenCL v2.0 s6.11.5: + // 0 - enable unroll (no argument). + // 1 - disable unroll. + // other positive integer n - unroll by n. + if (OpenCLHint) { + ValueInt = OpenCLHint->getUnrollHint(); + if (ValueInt == 0) { + State = LoopHintAttr::Enable; + } else if (ValueInt != 1) { + Option = LoopHintAttr::UnrollCount; + State = LoopHintAttr::Numeric; + } + } else if (LH) { + Expr *ValueExpr = LH->getValue(); + if (ValueExpr) { + if (ValueExpr->isValueDependent()) + continue; // Ignore this attribute until it is instantiated. 
+ llvm::APSInt ValueAPS = + ValueExpr->EvaluateKnownConstInt(Builder.ASTCtx); + ValueInt = ValueAPS.getSExtValue(); + } + + Option = LH->getOption(); + State = LH->getState(); + } + + switch (Option) { + case LoopHintAttr::Vectorize: + case LoopHintAttr::Interleave: + case LoopHintAttr::VectorizeWidth: + case LoopHintAttr::InterleaveCount: + case LoopHintAttr::VectorizePredicate: + VecLoc = Attr->getLocation(); + break; + case LoopHintAttr::Unroll: + case LoopHintAttr::UnrollCount: + UnrollLoc = Attr->getLocation(); + break; + case LoopHintAttr::UnrollAndJam: + case LoopHintAttr::UnrollAndJamCount: + UnrollAndJamLoc = Attr->getLocation(); + break; + case LoopHintAttr::Distribute: + DistLoc = Attr->getLocation(); + break; + + case LoopHintAttr::PipelineDisabled: + case LoopHintAttr::PipelineInitiationInterval: + PipelineLoc = Attr->getLocation(); + break; + } + + switch (State) { + case LoopHintAttr::Disable: + switch (Option) { + case LoopHintAttr::Vectorize: + // Disable vectorization by specifying a width of 1. + VectorizeWidth = 1; + break; + case LoopHintAttr::Interleave: + // Disable interleaving by specifying a count of 1. 
+ InterleaveCount = 1; + break; + case LoopHintAttr::Unroll: + UnrollEnable = Disable; + break; + case LoopHintAttr::UnrollAndJam: + UnrollAndJamEnable = Disable; + break; + case LoopHintAttr::VectorizePredicate: + VectorizePredicateEnable = Disable; + break; + case LoopHintAttr::Distribute: + DistributeEnable = Disable; + break; + case LoopHintAttr::PipelineDisabled: + PipelineDisabled = true; + break; + case LoopHintAttr::UnrollCount: + case LoopHintAttr::UnrollAndJamCount: + case LoopHintAttr::VectorizeWidth: + case LoopHintAttr::InterleaveCount: + case LoopHintAttr::PipelineInitiationInterval: + llvm_unreachable("Options cannot be disabled."); + break; + } + break; + case LoopHintAttr::Enable: + switch (Option) { + case LoopHintAttr::Vectorize: + case LoopHintAttr::Interleave: + VectorizeEnable = Enable; + break; + case LoopHintAttr::Unroll: + UnrollEnable = Enable; + break; + case LoopHintAttr::UnrollAndJam: + UnrollAndJamEnable = Enable; + break; + case LoopHintAttr::VectorizePredicate: + VectorizePredicateEnable = Enable; + break; + case LoopHintAttr::Distribute: + DistributeEnable = Enable; + break; + case LoopHintAttr::UnrollCount: + case LoopHintAttr::UnrollAndJamCount: + case LoopHintAttr::VectorizeWidth: + case LoopHintAttr::InterleaveCount: + case LoopHintAttr::PipelineDisabled: + case LoopHintAttr::PipelineInitiationInterval: + llvm_unreachable("Options cannot enabled."); + break; + } + break; + case LoopHintAttr::AssumeSafety: + switch (Option) { + case LoopHintAttr::Vectorize: + case LoopHintAttr::Interleave: + // Apply "llvm.mem.parallel_loop_access" metadata to load/stores. 
+ VectorizeEnable = Enable; + VectorizeAssumeSafety = true; + break; + case LoopHintAttr::Unroll: + case LoopHintAttr::UnrollAndJam: + case LoopHintAttr::VectorizePredicate: + case LoopHintAttr::UnrollCount: + case LoopHintAttr::UnrollAndJamCount: + case LoopHintAttr::VectorizeWidth: + case LoopHintAttr::InterleaveCount: + case LoopHintAttr::Distribute: + case LoopHintAttr::PipelineDisabled: + case LoopHintAttr::PipelineInitiationInterval: + llvm_unreachable("Options cannot be used to assume mem safety."); + break; + } + break; + case LoopHintAttr::Full: + switch (Option) { + case LoopHintAttr::Unroll: + UnrollEnable = Full; + break; + case LoopHintAttr::UnrollAndJam: + UnrollAndJamEnable = Full; + break; + case LoopHintAttr::Vectorize: + case LoopHintAttr::Interleave: + case LoopHintAttr::UnrollCount: + case LoopHintAttr::UnrollAndJamCount: + case LoopHintAttr::VectorizeWidth: + case LoopHintAttr::InterleaveCount: + case LoopHintAttr::Distribute: + case LoopHintAttr::PipelineDisabled: + case LoopHintAttr::PipelineInitiationInterval: + case LoopHintAttr::VectorizePredicate: + llvm_unreachable("Options cannot be used with 'full' hint."); + break; + } + break; + case LoopHintAttr::Numeric: + switch (Option) { + case LoopHintAttr::VectorizeWidth: + VectorizeWidth = ValueInt; + break; + case LoopHintAttr::InterleaveCount: + InterleaveCount = ValueInt; + break; + case LoopHintAttr::UnrollCount: + UnrollCount = ValueInt; + break; + case LoopHintAttr::UnrollAndJamCount: + UnrollAndJamCount = ValueInt; + break; + case LoopHintAttr::PipelineInitiationInterval: + PipelineInitiationInterval = ValueInt; + break; + case LoopHintAttr::Unroll: + case LoopHintAttr::UnrollAndJam: + case LoopHintAttr::VectorizePredicate: + case LoopHintAttr::Vectorize: + case LoopHintAttr::Interleave: + case LoopHintAttr::Distribute: + case LoopHintAttr::PipelineDisabled: + llvm_unreachable("Options cannot be assigned a value."); + break; + } + break; + } + } + + bool VectorizeDisabled = + 
(VectorizeEnable == Disable || VectorizeWidth == 1); + bool InterleaveDisabled = (InterleaveCount == 1); + bool VectorizeInterleaveDisabled = + VectorizeDisabled && InterleaveDisabled && !VectorizeAssumeSafety; + + if (UnrollEnable == Disable) + Builder.disableUnroll(L); + + if (UnrollAndJamEnable == Disable) + Builder.disableUnrollAndJam(L); + + if (DistributeEnable == Disable) + Builder.disableDistribution(L); + + // If the LoopVectorize pass is completely disabled (ie. vectorize and + // interleave). + if (VectorizeInterleaveDisabled) + Builder.disableVectorizeInterleave(L); + + if (PipelineDisabled) + Builder.disablePipelining(L); + + if (UnrollEnable == Full) { + Transform *Trans = + LoopUnrollingTransform::createFull(UnrollLoc, true, false, true); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + + if (DistributeEnable == Enable) { + auto *Trans = LoopDistributionTransform::create(DistLoc, true); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + + if ((VectorizeEnable == Enable || VectorizeWidth > 0 || + VectorizePredicateEnable != Unspecified || InterleaveCount > 0) && + !VectorizeInterleaveDisabled) { + auto *Trans = LoopVectorizationInterleavingTransform::create( + VecLoc, true, VectorizeAssumeSafety, + VectorizeEnable != Unspecified + ? llvm::Optional(VectorizeEnable == Enable) + : None, + None, VectorizeWidth, + (VectorizePredicateEnable != Unspecified) + ? 
llvm::Optional(VectorizePredicateEnable == Enable) + : None, + InterleaveCount); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + + if (UnrollAndJamEnable == Enable || UnrollAndJamEnable == Full || + UnrollAndJamCount > 0) { + LoopUnrollAndJamTransform *Trans; + if (UnrollAndJamEnable == Full) + Trans = LoopUnrollAndJamTransform::createFull(UnrollAndJamLoc, true, + true); + else + Trans = LoopUnrollAndJamTransform::createPartial( + UnrollAndJamLoc, true, UnrollAndJamEnable == Enable, + UnrollAndJamCount); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + + if (UnrollEnable != Full && (UnrollEnable == Enable || UnrollCount > 0)) { + auto *Trans = LoopUnrollingTransform::createPartial( + UnrollLoc, true, false, UnrollEnable == Enable, UnrollCount); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + + if (PipelineInitiationInterval > 0) { + auto *Trans = LoopPipeliningTransform::create( + PipelineLoc, PipelineInitiationInterval); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + } + + bool VisitAttributedStmt(const AttributedStmt *S) { + const Stmt *LoopStmt = getAssociatedLoop(S); + + // Not every attributed statement is associated with a loop. + if (!LoopStmt) + return true; + + NodeTy *Node = StmtToTree.lookup(LoopStmt); + assert(Node && "We should have created a node for ever loop"); + applyAttributed(LoopStmt, S->getAttrs(), Node); + return true; + } + + bool + VisitTransformExecutableDirective(const TransformExecutableDirective *S) { + // Can happen in semantic analysis in non-instantiated templates. 
+ if (!S->getTransform()) + return false; + + const Stmt *TheLoop = getAssociatedLoop(S->getAssociated()); + Transforms.emplace_back(S->getTransform(), + TransformInput::createByStmt(TheLoop)); + return true; + } + + bool HandleOMPLoopClauses(OMPLoopDirective *Directive, bool HasTaskloop, + bool HasFor, bool HasSimd) { + const Stmt *TopLevel = getAssociatedLoop(Directive); + + bool IsMonotonic = true; + if (HasFor) { + if (auto *ScheduleClause = + Directive->getSingleClause()) { + OpenMPScheduleClauseKind Schedule = ScheduleClause->getScheduleKind(); + if (Schedule != OMPC_SCHEDULE_unknown && + Schedule != OMPC_SCHEDULE_static) + IsMonotonic = false; + + if (ScheduleClause->getFirstScheduleModifier() == + OMPC_SCHEDULE_MODIFIER_monotonic) + IsMonotonic = true; + if (ScheduleClause->getSecondScheduleModifier() == + OMPC_SCHEDULE_MODIFIER_monotonic) + IsMonotonic = true; + } + + if (auto *OrderedClause = + Directive->getSingleClause()) { + if (!OrderedClause->getNumForLoops()) + IsMonotonic = true; + } + } + + if (HasTaskloop) + IsMonotonic = false; + + int64_t SimdWidth = 0; + bool SimdImplicitParallel = false; + if (HasSimd) { + if (!HasFor) + SimdImplicitParallel = true; + + if (auto *LenClause = Directive->getSingleClause()) { + auto SafelenExpr = LenClause->getSafelen(); + if (!SafelenExpr->isValueDependent() && + SafelenExpr->isEvaluatable(Builder.ASTCtx)) { + llvm::APSInt ValueAPS = + SafelenExpr->EvaluateKnownConstInt(Builder.ASTCtx); + SimdWidth = ValueAPS.getSExtValue(); + } + + // In presence of finite 'safelen', it may be unsafe to mark all + // the memory instructions parallel, because loop-carried + // dependences of 'safelen' iterations are possible. 
+ SimdImplicitParallel = false; + IsMonotonic = true; + } + + if (auto LenClause = Directive->getSingleClause()) { + auto SimdlenExpr = LenClause->getSimdlen(); + if (!SimdlenExpr->isValueDependent() && + SimdlenExpr->isEvaluatable(Builder.ASTCtx)) { + llvm::APSInt ValueAPS = + SimdlenExpr->EvaluateKnownConstInt(Builder.ASTCtx); + SimdWidth = ValueAPS.getSExtValue(); + } + } + } + + bool IsImplicitParallel = !IsMonotonic || SimdImplicitParallel; + + if (HasSimd) { + auto *Trans = LoopVectorizationInterleavingTransform::create( + Directive->getSourceRange(), true, IsImplicitParallel, true, None, + SimdWidth, None, 0); + Transforms.emplace_back(Trans, TransformInput::createByStmt(TopLevel)); + } else if (IsImplicitParallel) { + auto *Assume = + LoopAssumeParallelTransform::create(Directive->getSourceRange()); + Transforms.emplace_back(Assume, TransformInput::createByStmt(TopLevel)); + } + + return true; + } + + bool VisitOMPForDirective(OMPForDirective *For) { + return HandleOMPLoopClauses(For, false, true, false); + } + + bool VisitOMPDistributeParallelForDirective( + OMPDistributeParallelForDirective *L) { + return HandleOMPLoopClauses(L, false, true, false); + } + + bool VisitOMPTeamsDistributeParallelForDirective( + OMPTeamsDistributeParallelForDirective *L) { + return HandleOMPLoopClauses(L, false, true, false); + } + + bool VisitOMPTargetTeamsDistributeParallelForDirective( + OMPTargetTeamsDistributeParallelForDirective *L) { + return HandleOMPLoopClauses(L, false, true, false); + } + + bool VisitOMPSimdDirective(OMPLoopDirective *Simd) { + return HandleOMPLoopClauses(Simd, false, false, true); + } + + bool VisitOMPForSimdDirective(OMPForSimdDirective *ForSimd) { + return HandleOMPLoopClauses(ForSimd, false, true, true); + } + + bool + VisitOMPParallelForSimdDirective(OMPParallelForSimdDirective *ForSimd) { + return HandleOMPLoopClauses(ForSimd, false, true, true); + } + + bool VisitOMPDistributeSimdDirective( + OMPDistributeSimdDirective *DistributeSimd) { + 
return HandleOMPLoopClauses(DistributeSimd, false, false, true); + } + + bool VisitOMPDistributeParallelForSimdDirective( + OMPDistributeParallelForSimdDirective *DistributeSimd) { + return HandleOMPLoopClauses(DistributeSimd, false, true, true); + } + + bool VisitOMPTeamsDistributeSimdDirective( + OMPTeamsDistributeSimdDirective *TeamsDistributeSimd) { + return HandleOMPLoopClauses(TeamsDistributeSimd, false, false, true); + } + + bool VisitOMPTeamsDistributeParallelForSimdDirective( + OMPTeamsDistributeParallelForSimdDirective + *TeamsDistributeParallelForSimd) { + return HandleOMPLoopClauses(TeamsDistributeParallelForSimd, false, true, + true); + } + + bool VisitOMPTargetTeamsDistributeSimdDirective( + OMPTargetTeamsDistributeSimdDirective *TargetTeamsDistributeSimd) { + return HandleOMPLoopClauses(TargetTeamsDistributeSimd, false, false, + true); + } + + bool VisitOMPTargetTeamsDistributeParallelForSimdDirective( + OMPTargetTeamsDistributeParallelForSimdDirective + *TargetTeamsDistributeParallelForSimd) { + return HandleOMPLoopClauses(TargetTeamsDistributeParallelForSimd, false, + true, true); + } + + bool VisitOMPTaskLoopSimdDirective( + OMPTaskLoopSimdDirective *TaskLoopSimdDirective) { + return HandleOMPLoopClauses(TaskLoopSimdDirective, true, false, true); + } + + bool + VisitOMPTargetSimdDirective(OMPTargetSimdDirective *TargetSimdDirective) { + return HandleOMPLoopClauses(TargetSimdDirective, false, false, true); + } + + bool VisitOMPTargetParallelForSimdDirective( + OMPTargetParallelForSimdDirective *TargetParallelForSimdDirective) { + return HandleOMPLoopClauses(TargetParallelForSimdDirective, false, true, + true); + } + }; + + /// Applies collected transformations to the loop nest representation. 
+ struct TransformApplicator { + Derived &Builder; + llvm::DenseMap> + TransByStmt; + + TransformApplicator(Derived &Builder) : Builder(Builder) {} + + void addNodeTransform(NodeTransform *NT) { + TransformInput &TopLevelInput = NT->Inputs[0]; + TransByStmt[TopLevelInput.getStmtInput()].push_back(NT); + } + + void checkStageOrder(ArrayRef PrevLoops, Transform *NewTrans) { + int NewStage = NewTrans->getLoopPipelineStage(); + if (NewStage == -1) + return; + + for (auto PrevLoop : PrevLoops) { + auto PrevStage = PrevLoop->getStage(); + if (PrevStage > NewStage) { + Builder.Diag(NewTrans->getBeginLoc(), + diag::warn_sema_transform_pass_order); + + // At most one warning per transformation. + return; + } + } + } + + NodeTy *applyTransform(Transform *Trans, NodeTy *MainLoop) { + switch (Trans->getKind()) { + case Transform::Kind::LoopUnrollingKind: + return applyUnrolling(cast(Trans), MainLoop); + case Transform::Kind::LoopUnrollAndJamKind: + return applyUnrollAndJam(cast(Trans), + MainLoop); + case Transform::Kind::LoopDistributionKind: + return applyDistribution(cast(Trans), + MainLoop); + case Transform::Kind::LoopVectorizationInterleavingKind: + return applyVectorizeInterleave( + cast(Trans), MainLoop); + case Transform::Kind::LoopVectorizationKind: + return applyVectorize(cast(Trans), + MainLoop); + case Transform::Kind::LoopInterleavingKind: + return applyInterleave(cast(Trans), + MainLoop); + case Transform::Kind::LoopPipeliningKind: + return applyPipelining(cast(Trans), MainLoop); + case Transform::Kind::LoopAssumeParallelKind: + return applyAssumeParallel(cast(Trans), + MainLoop); + default: + llvm_unreachable("unimplemented transformation"); + } + } + + void inheritLoopAttributes(NodeTy *Dst, NodeTy *Src, bool IsAll, + bool IsSuccessor) { + Builder.inheritLoopAttributes(Dst, Src, IsAll, IsSuccessor); + } + + NodeTy *applyUnrolling(LoopUnrollingTransform *Trans, NodeTy *MainLoop) { + checkStageOrder({MainLoop}, Trans); + int Stage = 
Trans->getLoopPipelineStage(); + + NodeTy *Successor = nullptr; + if (Trans->isFull()) { + // Full unrolling has no followup-loop. + MainLoop->applyTransformation(Trans, {}, nullptr); + } else { + NodeTy *All = + Builder.createFollowup(MainLoop->Subloops, MainLoop, + LoopUnrollingTransform::FollowupAll, Stage); + NodeTy *Unrolled = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopUnrollingTransform::FollowupUnrolled, Stage); + NodeTy *Remainder = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopUnrollingTransform::FollowupRemainder, Stage); + Successor = Trans->isLegacy() ? All : Unrolled; + + inheritLoopAttributes(All, MainLoop, true, All == Successor); + MainLoop->applyTransformation(Trans, {All, Unrolled, Remainder}, + Successor); + } + + Builder.applyUnroll(Trans, MainLoop); + return Successor; + } + + NodeTy *applyUnrollAndJam(LoopUnrollAndJamTransform *Trans, + NodeTy *MainLoop) { + // Search for the innermost loop that is being jammed. + NodeTy *Cur = MainLoop; + NodeTy *Inner = nullptr; + if (Cur->Subloops.size() == 1) { + Inner = Cur->Subloops[0]->getLatestSuccessor(); + } else if (!Trans->isLegacy()) { + Builder.Diag(Trans->getBeginLoc(), + diag::err_sema_transform_unrollandjam_expect_nested_loop); + return nullptr; + } + + if (!Trans->isLegacy() && !Inner) { + if (!Trans->isLegacy()) + Builder.Diag( + Trans->getBeginLoc(), + diag::err_sema_transform_unrollandjam_expect_nested_loop); + return nullptr; + } + + if (!Trans->isLegacy() && Inner->Subloops.size() != 0) { + if (!Trans->isLegacy()) + Builder.Diag(Trans->getBeginLoc(), + diag::err_sema_transform_unrollandjam_not_innermost); + return nullptr; + } + + int Stage = Trans->getLoopPipelineStage(); + + // Having no loop to jam does not make a lot of sense, but fixes + // regression tests. 
+ if (!Inner) { + checkStageOrder({MainLoop}, Trans); + + NodeTy *UnrolledOuter = Builder.createFollowup( + {}, MainLoop, LoopUnrollAndJamTransform::FollowupOuter, Stage); + inheritLoopAttributes(UnrolledOuter, MainLoop, true, false); + + MainLoop->applyTransformation(Trans, {UnrolledOuter}, UnrolledOuter); + Builder.applyUnrollAndJam(Trans, MainLoop, nullptr); + return UnrolledOuter; + } + + checkStageOrder({MainLoop, Inner}, Trans); + + NodeTy *TransformedInner = Builder.createFollowup( + Inner->Subloops, Inner, LoopUnrollAndJamTransform::FollowupInner, + Stage); + inheritLoopAttributes(TransformedInner, Inner, false, false); + + // TODO: Handle full unrolling + NodeTy *UnrolledOuter = Builder.createFollowup( + {Inner}, MainLoop, LoopUnrollAndJamTransform::FollowupOuter, Stage); + inheritLoopAttributes(UnrolledOuter, MainLoop, false, true); + + MainLoop->applyTransformation(Trans, {UnrolledOuter}, UnrolledOuter); + Inner->applySuccessors(MainLoop, LoopUnrollAndJamTransform::InputInner, + {TransformedInner}, TransformedInner); + + Builder.applyUnrollAndJam(Trans, MainLoop, Inner); + return UnrolledOuter; + } + + NodeTy *applyDistribution(LoopDistributionTransform *Trans, + NodeTy *MainLoop) { + checkStageOrder({MainLoop}, Trans); + int Stage = Trans->getLoopPipelineStage(); + + NodeTy *All = + Builder.createFollowup(MainLoop->Subloops, MainLoop, + LoopDistributionTransform::FollowupAll, Stage); + NodeTy *Successor = Trans->isLegacy() ? 
All : nullptr; + inheritLoopAttributes(All, MainLoop, true, Successor == All); + + MainLoop->applyTransformation(Trans, {All}, Successor); + Builder.applyDistribution(Trans, MainLoop); + return Successor; + } + + NodeTy * + applyVectorizeInterleave(LoopVectorizationInterleavingTransform *Trans, + NodeTy *MainLoop) { + checkStageOrder({MainLoop}, Trans); + int Stage = Trans->getLoopPipelineStage(); + + NodeTy *All = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopVectorizationInterleavingTransform::FollowupAll, Stage); + NodeTy *Vectorized = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopVectorizationInterleavingTransform::FollowupVectorized, Stage); + NodeTy *Epilogue = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopVectorizationInterleavingTransform::FollowupEpilogue, Stage); + + NodeTy *Successor = Trans->isLegacy() ? All : Vectorized; + if (Trans->isAssumeSafety() && !MainLoop->IsParallel) { + MainLoop->IsParallel = true; + Builder.markParallel(MainLoop); + } + + inheritLoopAttributes(All, MainLoop, true, All == Successor); + MainLoop->applyTransformation(Trans, {All, Vectorized, Epilogue}, + Successor); + Builder.applyVectorizeInterleave(Trans, MainLoop); + return Successor; + } + + NodeTy *applyVectorize(LoopVectorizationTransform *Trans, + NodeTy *MainLoop) { + auto *VecInterleaveTrans = LoopVectorizationInterleavingTransform::create( + Trans->getLoc(), false, false, true, false, Trans->getWidth(), + Trans->isPredicateEnabled(), 1); + return applyVectorizeInterleave(VecInterleaveTrans, MainLoop); + } + + NodeTy *applyInterleave(LoopInterleavingTransform *Trans, + NodeTy *MainLoop) { + auto *VecInterleaveTrans = LoopVectorizationInterleavingTransform::create( + Trans->getLoc(), false, false, false, true, 1, None, + Trans->getInterleaveCount()); + return applyVectorizeInterleave(VecInterleaveTrans, MainLoop); + } + + NodeTy *applyPipelining(LoopPipeliningTransform *Trans, NodeTy *MainLoop) { + 
checkStageOrder({MainLoop}, Trans); + MainLoop->applyTransformation(Trans, {}, nullptr); + Builder.applyPipelining(Trans, MainLoop); + return nullptr; + } + + NodeTy *applyAssumeParallel(LoopAssumeParallelTransform *Assumption, + NodeTy *MainLoop) { + if (MainLoop->isParallel()) + return MainLoop; + + MainLoop->IsParallel = true; + Builder.markParallel(MainLoop); + + return MainLoop; + } + + void applyOne(NodeTy *L, NodeTransform *NT) { + applyTransform(NT->Trans, L); + } + + void applyOne(NodeTy *L, const TransformExecutableDirective *D) { + Transform *Trans = D->getTransform(); + assert(Trans); + applyTransform(Trans, L); + } + + void traverseSubloops(NodeTy *L) { + // Transform subloops first. + for (NodeTy *SubL : L->getSubLoops()) { + SubL = SubL->getLatestSuccessor(); + if (!SubL) + continue; + traverse(SubL); + } + } + + bool applyTransform(NodeTy *L) { + if (L->isRoot()) + return false; + assert(L == L->getLatestSuccessor() && + "Loop must not have been consumed by another transformation"); + + Stmt *OrigStmt = L->getInheritedOriginal(); + auto TransformsOnStmt = TransByStmt.find(OrigStmt); + if (TransformsOnStmt != TransByStmt.end()) { + auto &List = TransformsOnStmt->second; + if (!List.empty()) { + NodeTransform *Trans = List.front(); + applyOne(L, Trans); + List.erase(List.begin()); + return true; + } + } + + return false; + } + + void traverse(NodeTy *L) { + do { + L = L->getLatestSuccessor(); + if (!L) + break; + traverseSubloops(L); + } while (applyTransform(L)); + } + }; + +protected: + TransformedTreeBuilder(ASTContext &ASTCtx, + llvm::SmallVectorImpl &AllNodes) + : ASTCtx(ASTCtx), AllNodes(AllNodes) {} + + NodeTy *createRoot(llvm::ArrayRef SubLoops) { + auto *Result = new NodeTy(SubLoops, nullptr, nullptr, -1, -1); + AllNodes.push_back(Result); + Result->IsRoot = true; + return Result; + } + + NodeTy *createPhysical(llvm::ArrayRef SubLoops, + clang::Stmt *Original) { + assert(Original); + auto *Result = new NodeTy(SubLoops, nullptr, Original, -1, 
0); + AllNodes.push_back(Result); + return Result; + } + + NodeTy *createFollowup(llvm::ArrayRef SubLoops, NodeTy *BasedOn, + int FollowupRole, int Stage) { + assert(BasedOn); + auto *Result = new NodeTy(SubLoops, BasedOn, nullptr, FollowupRole, Stage); + AllNodes.push_back(Result); + return Result; + } + +public: + void markParallel(NodeTy *L) { L->markParallel(); } + + NodeTy * + computeTransformedStructure(Stmt *Body, + llvm::DenseMap &StmtToTree) { + if (!Body) + return nullptr; + + // Create original tree. + SmallVector TopLevelLoops; + buildPhysicalLoopTree(Body, TopLevelLoops, StmtToTree); + NodeTy *Root = getDerived().createRoot(TopLevelLoops); + + // Collect all loop transformations. + CollectTransformationsVisitor Collector(getDerived(), StmtToTree); + Collector.TraverseStmt(Body); + auto &TransformList = Collector.Transforms; + + // Local function to apply every transformation that the predicate returns + // true to. This is to emulate LLVM's pass pipeline that would apply + // transformation in this order. + auto SelectiveApplicator = [this, &TransformList, Root](auto Pred) { + TransformApplicator OMPApplicator(getDerived()); + bool AnyActive = false; + for (NodeTransform &NT : TransformList) { + if (Pred(NT)) { + OMPApplicator.addNodeTransform(&NT); + AnyActive = true; + } + } + + // No traversal needed if no transformations to apply. + if (!AnyActive) + return; + + OMPApplicator.traverse(Root); + + // Report leftover transformations. + for (auto &P : OMPApplicator.TransByStmt) { + for (NodeTransform *NT : P.second) { + if (!NT->Trans->isLegacy()) + getDerived().Diag(NT->Trans->getBeginLoc(), + diag::err_sema_transform_missing_loop); + } + } + + // Remove applied transformations from list. + auto NewEnd = + std::remove_if(TransformList.begin(), TransformList.end(), Pred); + TransformList.erase(NewEnd, TransformList.end()); + }; + + // Apply full unrolling, loop distribution, vectorization/interleaving. 
+ SelectiveApplicator([](NodeTransform &NT) -> bool { + if (auto Unroll = dyn_cast(NT.Trans)) + return Unroll->isLegacy() && Unroll->isFull(); + if (auto Dist = dyn_cast(NT.Trans)) + return Dist->isLegacy(); + if (auto Vec = dyn_cast(NT.Trans)) + return Vec->isLegacy(); + return false; + }); + + // Apply unrollandjam. + // While all transformations in TransformList have already been added in the + // order in the pass pipeline (relative to transformations on the same + // loop), transformations on unrollandjam's inner loop are not ordered + // relative to the outer loop. + SelectiveApplicator([](NodeTransform &NT) -> bool { + if (auto Unroll = dyn_cast(NT.Trans)) + return Unroll->isLegacy(); + return false; + }); + + // Apply partial unrolling and software pipelining. + SelectiveApplicator([](NodeTransform &NT) -> bool { + if (auto Unroll = dyn_cast(NT.Trans)) + return Unroll->isLegacy(); + if (isa(NT.Trans)) + return true; + return false; + }); + + // Apply all others. + SelectiveApplicator([](NodeTransform &NT) -> bool { return true; }); + assert(TransformList.size() == 0 && "Must apply all transformations"); + + return Root; + } +}; + +} // namespace clang +#endif /* LLVM_CLANG_ANALYSIS_ANALYSISTRANSFORM_H */ Index: include/clang/Basic/DiagnosticSemaKinds.td =================================================================== --- include/clang/Basic/DiagnosticSemaKinds.td +++ include/clang/Basic/DiagnosticSemaKinds.td @@ -9994,4 +9994,31 @@ "__builtin_bit_cast %select{source|destination}0 type must be trivially copyable">; def err_bit_cast_type_size_mismatch : Error< "__builtin_bit_cast source size does not equal destination size (%0 vs %1)">; + +// Pragma transform support. 
+def err_sema_transform_unrollandjam_expect_nested_loop : Error< + "unroll-and-jam requires exactly one nested loop">; +def err_sema_transform_unrollandjam_not_innermost : Error< + "inner loop of unroll-and-jam is not the innermost">; +def err_sema_transform_missing_loop : Error< + "transformation did not find its loop to transform; it might have been consumed by another transformation">; +def err_sema_transform_unroll_full_or_partial : Error< + "unroll full/partial clauses specified multiple times">; +def err_sema_transform_unroll_factor_expect_int : Error< + "unroll partial clause expects an int">; +def err_sema_transform_unroll_partial_once : Error< + "unroll-and-jam partial clause can only be used once">; +def err_sema_transform_vectorize_width_once : Error< + "vectorize width clause can only be used once">; +def err_sema_transform_interleave_factor_once : Error< + "interleave factor clause can only be used once">; +def err_sema_transform_expected_loop : Error< + "cannot find loop to transform">; +def err_sema_transform_unroll_partial_ge_two : Error< + "unroll factor must be at least two">; +def err_sema_transform_interleave_factor_ge_two : Error< + "interleave factor must be at least two">; +def warn_sema_transform_pass_order : Warning< + "the LLVM pass structure currently is not able to apply the transformations in this order">, InGroup; + } // end of sema component. Index: include/clang/Basic/StmtNodes.td =================================================================== --- include/clang/Basic/StmtNodes.td +++ include/clang/Basic/StmtNodes.td @@ -210,6 +210,9 @@ // OpenCL Extensions. def AsTypeExpr : DStmt; +// Transform Directives. +def TransformExecutableDirective : Stmt; + // OpenMP Directives.
def OMPExecutableDirective : Stmt<1>; def OMPLoopDirective : DStmt; Index: include/clang/Sema/Sema.h =================================================================== --- include/clang/Sema/Sema.h +++ include/clang/Sema/Sema.h @@ -11549,6 +11549,10 @@ TransformClause *ActOnPartialClause(SourceRange Loc, Expr *Factor); TransformClause *ActOnWidthClause(SourceRange Loc, Expr *Width); TransformClause *ActOnFactorClause(SourceRange Loc, Expr *Factor); + + /// Build the transformed loop structure of \p FD's body and diagnose + /// transformations that cannot be applied. No-op for a null/invalid \p FD. + void HandleLoopTransformations(FunctionDecl *FD); }; /// RAII object that enters a new expression evaluation context. Index: include/clang/Sema/SemaTransform.h =================================================================== --- /dev/null +++ include/clang/Sema/SemaTransform.h @@ -0,0 +1,83 @@ +//===---- SemaTransform.h ------------------------------------- -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Semantic analysis for code transformations.
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_SEMA_SEMATRANSFORM_H +#define LLVM_CLANG_SEMA_SEMATRANSFORM_H + +#include "clang/Analysis/AnalysisTransform.h" +#include "clang/Sema/Sema.h" + +namespace clang { +class Sema; + +/// Node of the transformed loop tree built during semantic analysis. It adds +/// no state of its own on top of the TransformedTree base. +class SemaTransformedTree : public TransformedTree { + friend class TransformedTree; + using BaseTy = TransformedTree; + using NodeTy = SemaTransformedTree; + + BaseTy &getBase() { return *this; } + const BaseTy &getBase() const { return *this; } + +public: + SemaTransformedTree(llvm::ArrayRef SubLoops, NodeTy *BasedOn, + clang::Stmt *Original, int FollowupRole, int Stage) + : TransformedTree(SubLoops, BasedOn, Original, FollowupRole, Stage) {} +}; + +/// Builds the transformed loop tree for Sema. The applyX/disableX and +/// inheritLoopAttributes hooks are no-ops: Sema only uses the tree structure +/// and the diagnostics reported through Diag(). +class SemaTransformedTreeBuilder + : public TransformedTreeBuilder { + using NodeTy = SemaTransformedTree; + + Sema &Sem; + +public: + SemaTransformedTreeBuilder(ASTContext &ASTCtx, + llvm::SmallVectorImpl &AllNodes, + Sema &Sem) + : TransformedTreeBuilder(ASTCtx, AllNodes), Sem(Sem) {} + + auto Diag(SourceLocation Loc, unsigned DiagID) { + return Sem.Diag(Loc, DiagID); + } + + void applyOriginal(SemaTransformedTree *L) {} + + void disableDistribution(SemaTransformedTree *L) {} + void disableVectorizeInterleave(SemaTransformedTree *L) {} + void disableUnrollAndJam(SemaTransformedTree *L) {} + void disableUnroll(SemaTransformedTree *L) {} + void disablePipelining(SemaTransformedTree *L) {} + + void applyDistribution(LoopDistributionTransform *Trans, + SemaTransformedTree *InputLoop) {} + void applyVectorizeInterleave(LoopVectorizationInterleavingTransform *Trans, + SemaTransformedTree *MainLoop) {} + void applyUnrollAndJam(LoopUnrollAndJamTransform *Trans, + SemaTransformedTree *OuterLoop, + SemaTransformedTree *InnerLoop) {} + void applyUnroll(LoopUnrollingTransform *Trans, + SemaTransformedTree *OriginalLoop) {} + void applyPipelining(LoopPipeliningTransform *Trans, + SemaTransformedTree *MainLoop) {} + + void
inheritLoopAttributes(SemaTransformedTree *Dst, SemaTransformedTree *Src, + bool IsAll, bool IsSuccessor) {} +}; + +} // namespace clang +#endif /* LLVM_CLANG_SEMA_SEMATRANSFORM_H */ Index: lib/AST/ASTTypeTraits.cpp =================================================================== --- lib/AST/ASTTypeTraits.cpp +++ lib/AST/ASTTypeTraits.cpp @@ -41,6 +41,13 @@ {NKI_None, "OMPClause"}, #define OPENMP_CLAUSE(TextualSpelling, Class) {NKI_OMPClause, #Class}, #include "clang/Basic/OpenMPKinds.def" + {NKI_TransformClause, "TransformClause"}, +#define TRANSFORM_CLAUSE(Keyword, Name) {NKI_##Name##Clause, #Name "Clause"}, +#include "clang/AST/TransformKinds.def" + {NKI_Transform, "Transform"}, +#define TRANSFORM_DIRECTIVE(Keyword, Name) \ + {NKI_##Name##Transform, #Name "Transform"}, +#include "clang/AST/TransformKinds.def" }; bool ASTNodeKind::isBaseOf(ASTNodeKind Other, unsigned *Distance) const { @@ -124,6 +131,31 @@ llvm_unreachable("invalid stmt kind"); } +ASTNodeKind ASTNodeKind::getFromNode(const Transform &T) { + switch (T.getKind()) { + // Map each directive to its own node kind (not the NKI_Transform base). +#define TRANSFORM_DIRECTIVE(Keyword, Name) \ + case Transform::Kind::Name##Kind: \ + return ASTNodeKind(NKI_##Name##Transform); +#include "clang/AST/TransformKinds.def" + case Transform::Kind::UnknownKind: + llvm_unreachable("unexpected transform kind"); + } + llvm_unreachable("invalid transform kind"); +} + +ASTNodeKind ASTNodeKind::getFromNode(const TransformClause &C) { + switch (C.getKind()) { +#define TRANSFORM_CLAUSE(Keyword, Name) \ + case TransformClause::Kind::Name##Kind: \ + return ASTNodeKind(NKI_##Name##Clause); +#include "clang/AST/TransformKinds.def" + case TransformClause::Kind::UnknownKind: + llvm_unreachable("unexpected transform clause kind"); + } + llvm_unreachable("invalid transform clause kind"); +} + void DynTypedNode::print(llvm::raw_ostream &OS, const PrintingPolicy &PP) const { if (const TemplateArgument *TA = get()) Index: lib/AST/JSONNodeDumper.cpp =================================================================== --- 
lib/AST/JSONNodeDumper.cpp +++ lib/AST/JSONNodeDumper.cpp @@ -167,6 +167,10 @@ void JSONNodeDumper::Visit(const OMPClause *C) {} +void JSONNodeDumper::Visit(const TransformClause *C) {} + +void JSONNodeDumper::Visit(const Transform *T) {} + void JSONNodeDumper::Visit(const BlockDecl::Capture &C) { JOS.attribute("kind", "Capture"); attributeOnlyIfTrue("byref", C.isByRef()); Index: lib/AST/StmtPrinter.cpp =================================================================== --- lib/AST/StmtPrinter.cpp +++ lib/AST/StmtPrinter.cpp @@ -30,6 +30,7 @@ #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtTransform.h" #include "clang/AST/StmtVisitor.h" #include "clang/AST/TemplateBase.h" #include "clang/AST/Type.h" @@ -635,6 +636,19 @@ if (Policy.IncludeNewlines) OS << NL; } +void StmtPrinter::VisitTransformExecutableDirective( + TransformExecutableDirective *Node) { + Indent() << "#pragma clang transform " + << Transform::getTransformDirectiveKeyword( + Node->getTransformKind()); + for (TransformClause *Clause : Node->clauses()) { + OS << ' '; + Clause->print(OS, Policy); + } + OS << NL; + PrintStmt(Node->getAssociated()); +} + //===----------------------------------------------------------------------===// // OpenMP directives printing methods //===----------------------------------------------------------------------===// Index: lib/AST/StmtProfile.cpp =================================================================== --- lib/AST/StmtProfile.cpp +++ lib/AST/StmtProfile.cpp @@ -771,6 +771,11 @@ } } +void StmtProfiler::VisitTransformExecutableDirective( + const TransformExecutableDirective *S) { + VisitStmt(S); +} + void StmtProfiler::VisitOMPExecutableDirective(const OMPExecutableDirective *S) { VisitStmt(S); Index: lib/AST/StmtTransform.cpp =================================================================== --- lib/AST/StmtTransform.cpp +++ lib/AST/StmtTransform.cpp @@ -17,6 +17,32 @@ using namespace
clang; +TransformExecutableDirective *TransformExecutableDirective::create( + ASTContext &Ctx, SourceRange Range, Stmt *Associated, Transform *Trans, + ArrayRef Clauses, Transform::Kind TransKind) { + void *Mem = Ctx.Allocate(totalSizeToAlloc(Clauses.size())); + return new (Mem) TransformExecutableDirective(Range, Associated, Trans, + Clauses, TransKind); +} + +TransformExecutableDirective * +TransformExecutableDirective::createEmpty(ASTContext &Ctx, + unsigned NumClauses) { + void *Mem = Ctx.Allocate(totalSizeToAlloc(NumClauses)); + return new (Mem) TransformExecutableDirective(NumClauses); +} + +llvm::StringRef TransformClause::getClauseName(Kind K) { + assert(K >= UnknownKind); + assert(K <= LastKind); + static const char *const Names[LastKind + 1] = { + "Unknown", +#define TRANSFORM_CLAUSE(Keyword, Name) #Name, +#include "clang/AST/TransformKinds.def" + }; + return Names[K]; +} + bool TransformClause::isValidForTransform(Transform::Kind TransformKind, TransformClause::Kind ClauseKind) { switch (TransformKind) { @@ -44,6 +70,56 @@ return TransformClause::UnknownKind; } +TransformClause::child_range TransformClause::children() { + switch (getKind()) { + case UnknownKind: + llvm_unreachable("Unknown child"); +#define TRANSFORM_CLAUSE(Keyword, Name) \ + case TransformClause::Kind::Name##Kind: \ + return static_cast(this)->children(); +#include "clang/AST/TransformKinds.def" + } + llvm_unreachable("Unhandled clause kind"); +} + +void TransformClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + assert(getKind() > UnknownKind); + assert(getKind() <= LastKind); + static const decltype(&TransformClause::print) PrintFuncs[LastKind] = { +#define TRANSFORM_CLAUSE(Keyword, Name) \ + static_cast(&Name##Clause::print), +#include "clang/AST/TransformKinds.def" + }; + (this->*PrintFuncs[getKind() - 1])(OS, Policy); // PrintFuncs has no slot for UnknownKind. +} + +void FullClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + OS << "full"; +} + +void
PartialClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + OS << "partial("; + Factor->printPretty(OS, nullptr, Policy, 0); + OS << ')'; +} + +void WidthClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + OS << "width("; + Width->printPretty(OS, nullptr, Policy, 0); + OS << ')'; +} + +void FactorClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + OS << "factor("; + Factor->printPretty(OS, nullptr, Policy, 0); + OS << ')'; +} + const Stmt *clang::getAssociatedLoop(const Stmt *S) { switch (S->getStmtClass()) { case Stmt::ForStmtClass: @@ -55,6 +131,9 @@ return getAssociatedLoop(cast(S)->getCapturedStmt()); case Stmt::AttributedStmtClass: return getAssociatedLoop(cast(S)->getSubStmt()); + case Stmt::TransformExecutableDirectiveClass: + return getAssociatedLoop( + cast(S)->getAssociated()); default: if (auto LD = dyn_cast(S)) return getAssociatedLoop(LD->getAssociatedStmt()); Index: lib/AST/TextNodeDumper.cpp =================================================================== --- lib/AST/TextNodeDumper.cpp +++ lib/AST/TextNodeDumper.cpp @@ -320,6 +320,43 @@ OS << " "; } +void TextNodeDumper::VisitTransformExecutableDirective( + const TransformExecutableDirective *S) { + if (S) + AddChild([=] { Visit(S->getTransform()); }); +} + +void TextNodeDumper::Visit(const TransformClause *C) { + if (!C) { + ColorScope Color(OS, ShowColors, NullColor); + OS << "<<>> TransformClause"; + return; + } + { + ColorScope Color(OS, ShowColors, AttrColor); + StringRef ClauseName = TransformClause::getClauseName(C->getKind()); + OS << ClauseName << "Clause"; + } + dumpPointer(C); + dumpSourceRange(C->getLoc()); +} + +void TextNodeDumper::Visit(const Transform *T) { + if (!T) { + ColorScope Color(OS, ShowColors, NullColor); + OS << "<<>> Transform"; + return; + } + { + ColorScope Color(OS, ShowColors, AttrColor); + StringRef TransformName = + Transform::getTransformDirectiveName(T->getKind()); + OS <<
TransformName << "Transform"; + } + dumpPointer(T); + dumpSourceRange(T->getLoc()); +} + void TextNodeDumper::Visit(const GenericSelectionExpr::ConstAssociation &A) { const TypeSourceInfo *TSI = A.getTypeSourceInfo(); if (TSI) { Index: lib/AST/Transform.cpp =================================================================== --- lib/AST/Transform.cpp +++ lib/AST/Transform.cpp @@ -46,3 +46,46 @@ }; return Keywords[K]; } + +int Transform::getLoopPipelineStage() const { + switch (getKind()) { + case Transform::Kind::LoopDistributionKind: + return 1; + case Transform::Kind::LoopUnrollingKind: + return cast(this)->isFull() ? 0 : 4; // Full unrolling is applied first, partial unrolling last. + case Transform::Kind::LoopInterleavingKind: + case Transform::Kind::LoopVectorizationKind: + case Transform::Kind::LoopVectorizationInterleavingKind: + return 2; + case Transform::Kind::LoopUnrollAndJamKind: + return 3; + default: + return -1; + } +} + +int Transform::getNumInputs() const { + assert(getKind() > UnknownKind); + assert(getKind() <= LastKind); + static const decltype( + &Transform::getNumInputs) GetNumInputFuncs[Transform::Kind::LastKind] = { +#define TRANSFORM_DIRECTIVE(Keyword, Name) \ + static_cast( \ + &Name##Transform::getNumInputs), +#include "clang/AST/TransformKinds.def" + }; + return (this->*GetNumInputFuncs[getKind() - 1])(); +} + +int Transform::getNumFollowups() const { + assert(getKind() > UnknownKind); + assert(getKind() <= LastKind); + static const decltype(&Transform::getNumFollowups) + GetNumFollowupFuncs[Transform::Kind::LastKind] = { +#define TRANSFORM_DIRECTIVE(Keyword, Name) \ + static_cast( \ + &Name##Transform::getNumFollowups), +#include "clang/AST/TransformKinds.def" + }; + return (this->*GetNumFollowupFuncs[getKind() - 1])(); +} Index: lib/CodeGen/CGStmt.cpp =================================================================== --- lib/CodeGen/CGStmt.cpp +++ lib/CodeGen/CGStmt.cpp @@ -346,6 +346,9 @@ EmitOMPTargetTeamsDistributeSimdDirective( cast(*S)); break; + case
Stmt::TransformExecutableDirectiveClass: + ErrorUnsupported(S, "transform directive"); // Codegen is not implemented yet; report an error instead of reaching llvm_unreachable. + break; } } Index: lib/Sema/SemaDecl.cpp =================================================================== --- lib/Sema/SemaDecl.cpp +++ lib/Sema/SemaDecl.cpp @@ -14101,6 +14101,8 @@ "Leftover expressions for odr-use checking"); } + HandleLoopTransformations(FD); + if (!IsInstantiation) PopDeclContext(); Index: lib/Sema/SemaTransform.cpp =================================================================== --- lib/Sema/SemaTransform.cpp +++ lib/Sema/SemaTransform.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "clang/Sema/SemaTransform.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/AST/StmtTransform.h" #include "clang/AST/Transform.h" @@ -29,34 +30,242 @@ Sema::ActOnTransform(Transform::Kind Kind, llvm::ArrayRef Clauses, SourceRange Loc) { - // TOOD: implement - return TransformError(); + switch (Kind) { + case Transform::LoopUnrollingKind: { + FullClause *Full = nullptr; + PartialClause *Partial = nullptr; + for (TransformClause *Clause : Clauses) { + switch (Clause->getKind()) { + case TransformClause::FullKind: + if (Full || Partial) { + Diag(Clause->getBeginLoc(), + diag::err_sema_transform_unroll_full_or_partial); + continue; + } + Full = cast(Clause); + break; + case TransformClause::PartialKind: + if (Full || Partial) { + Diag(Clause->getBeginLoc(), + diag::err_sema_transform_unroll_full_or_partial); + continue; + } + Partial = cast(Clause); + break; + default: + llvm_unreachable("Clause not supported by unroll"); + } + } + + assert(!Full || !Partial); + + if (Full) + return LoopUnrollingTransform::createFull(Loc, false, true, true); + else if (Partial) { + Expr *FactorExpr = Partial->getFactor(); + if (isTemplateDependent(FactorExpr)) +
return {}; + Expr::EvalResult FactorRes; + FactorExpr->EvaluateAsInt(FactorRes, Context); + if (!FactorRes.Val.isInt()) + return TransformError( + Diag(FactorExpr->getExprLoc(), + diag::err_sema_transform_unroll_factor_expect_int)); + llvm::APSInt FactorVal = FactorRes.Val.getInt(); + int64_t FactorInt = FactorVal.getSExtValue(); + return LoopUnrollingTransform::createPartial(Loc, false, true, true, + FactorInt); + } + + return LoopUnrollingTransform::create(Loc, false, true, true); + } + + case Transform::LoopUnrollAndJamKind: { + PartialClause *Partial = nullptr; + for (TransformClause *Clause : Clauses) { + switch (Clause->getKind()) { + case TransformClause::PartialKind: + if (Partial) { + Diag(Clause->getBeginLoc(), + diag::err_sema_transform_unroll_partial_once); + continue; + } + Partial = cast(Clause); + break; + default: + llvm_unreachable("Clause not supported by unroll-and-jam"); + } + } + + if (Partial) { + Expr *FactorExpr = Partial->getFactor(); + if (isTemplateDependent(FactorExpr)) + return {}; + Expr::EvalResult FactorRes; + // Do not call getInt() on a result that did not fold to an integer. + if (!FactorExpr->EvaluateAsInt(FactorRes, Context) || + !FactorRes.Val.isInt()) + return TransformError( + Diag(FactorExpr->getExprLoc(), + diag::err_sema_transform_unroll_factor_expect_int)); + llvm::APSInt FactorVal = FactorRes.Val.getInt(); + int64_t FactorInt = FactorVal.getSExtValue(); + return LoopUnrollAndJamTransform::createPartial(Loc, false, true, + FactorInt); + } + + return LoopUnrollAndJamTransform::create(Loc, false, true); + } + + case clang::Transform::LoopDistributionKind: + assert(Clauses.size() == 0 && "distribute has no clauses"); + return LoopDistributionTransform::create(Loc, false); + + case clang::Transform::LoopVectorizationKind: { + WidthClause *Width = nullptr; + for (TransformClause *Clause : Clauses) { + switch (Clause->getKind()) { + case TransformClause::WidthKind: + if (Width) { + Diag(Clause->getBeginLoc(), + diag::err_sema_transform_vectorize_width_once); + continue; + } + Width = cast(Clause); + break; + + default: + llvm_unreachable("Clause not supported by vectorize"); + } + } + + int64_t WidthInt = 0; + if (Width) { + if
(isTemplateDependent(Width->getWidth())) + return {}; + + Expr::EvalResult WidthRes; + // Do not call getInt() on a result that did not fold to an integer. + if (!Width->getWidth()->EvaluateAsInt(WidthRes, Context) || + !WidthRes.Val.isInt()) { + Diag(Width->getWidth()->getExprLoc(), diag::err_expr_not_ice) + << getLangOpts().CPlusPlus; + return TransformError(); + } + llvm::APSInt WidthVal = WidthRes.Val.getInt(); + WidthInt = WidthVal.getSExtValue(); + } + + return LoopVectorizationTransform::Create(Loc, true, WidthInt, None); + } + + case clang::Transform::LoopInterleavingKind: { + FactorClause *Factor = nullptr; + + for (TransformClause *Clause : Clauses) { + switch (Clause->getKind()) { + case TransformClause::FactorKind: + if (Factor) { + Diag(Clause->getBeginLoc(), + diag::err_sema_transform_interleave_factor_once); + continue; + } + Factor = cast(Clause); + break; + default: + llvm_unreachable("Clause not supported by interleave"); + } + } + + int64_t FactorInt = 0; + if (Factor) { + if (isTemplateDependent(Factor->getFactor())) + return {}; + + Expr::EvalResult FactorRes; + // Do not call getInt() on a result that did not fold to an integer. + if (!Factor->getFactor()->EvaluateAsInt(FactorRes, Context) || + !FactorRes.Val.isInt()) { + Diag(Factor->getFactor()->getExprLoc(), diag::err_expr_not_ice) + << getLangOpts().CPlusPlus; + return TransformError(); + } + llvm::APSInt FactorVal = FactorRes.Val.getInt(); + FactorInt = FactorVal.getSExtValue(); + } + + return LoopInterleavingTransform::Create(Loc, true, FactorInt); + } + default: + llvm_unreachable("unimplemented"); + } } StmtResult Sema::ActOnLoopTransformDirective(Transform::Kind Kind, Transform *Trans, llvm::ArrayRef Clauses, Stmt *AStmt, SourceRange Loc) { - // TOOD: implement - return StmtError(); + const Stmt *Loop = getAssociatedLoop(AStmt); + if (!Loop) + return StmtError( + Diag(Loc.getBegin(), diag::err_sema_transform_expected_loop)); + + if (!Trans) { + TransformResult TransRes = ActOnTransform(Kind, Clauses, Loc); + Trans =
TransRes.isUsable() ? TransRes.get() : nullptr; + } + + return TransformExecutableDirective::create(Context, Loc, AStmt, Trans, + Clauses, Kind); } TransformClause *Sema::ActOnFullClause(SourceRange Loc) { - // TOOD: implement - return nullptr; + return FullClause::create(Context, Loc); } TransformClause *Sema::ActOnPartialClause(SourceRange Loc, Expr *Factor) { - // TOOD: implement - return nullptr; + // Factors that do not fold to an integer are diagnosed by ActOnTransform. + Expr::EvalResult FactorRes; + if (!Factor->isValueDependent() && + Factor->EvaluateAsInt(FactorRes, getASTContext())) { + llvm::APSInt ValueAPS = FactorRes.Val.getInt(); + if (!ValueAPS.sge(2)) { + Diag(Loc.getBegin(), diag::err_sema_transform_unroll_partial_ge_two); + return nullptr; + } + } + return PartialClause::create(Context, Loc, Factor); } TransformClause *Sema::ActOnWidthClause(SourceRange Loc, Expr *Width) { - // TOOD: implement - return nullptr; + return WidthClause::create(Context, Loc, Width); } TransformClause *Sema::ActOnFactorClause(SourceRange Loc, Expr *Factor) { - // TOOD: implement - return nullptr; + // Factors that do not fold to an integer are diagnosed by ActOnTransform. + Expr::EvalResult FactorRes; + if (!Factor->isValueDependent() && + Factor->EvaluateAsInt(FactorRes, getASTContext())) { + llvm::APSInt ValueAPS = FactorRes.Val.getInt(); + if (!ValueAPS.sge(2)) { + Diag(Loc.getBegin(), diag::err_sema_transform_interleave_factor_ge_two); + return nullptr; + } + } + return FactorClause::create(Context, Loc, Factor); +} + +void Sema::HandleLoopTransformations(FunctionDecl *FD) { + if (!FD || FD->isInvalidDecl()) + return; + + // Note: this is called on a template-code and the instantiated code.
+ llvm::DenseMap StmtToTree; + llvm::SmallVector AllNodes; + SemaTransformedTreeBuilder Builder(getASTContext(), AllNodes, *this); + Builder.computeTransformedStructure(FD->getBody(), StmtToTree); + + // The builder allocates every tree node into AllNodes; release them here. + for (auto *N : AllNodes) + delete N; } Index: lib/Sema/TreeTransform.h =================================================================== --- lib/Sema/TreeTransform.h +++ lib/Sema/TreeTransform.h @@ -26,6 +26,7 @@ #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtTransform.h" #include "clang/Sema/Designator.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/Ownership.h" @@ -341,6 +342,51 @@ /// \returns the transformed statement. StmtResult TransformStmt(Stmt *S, StmtDiscardKind SDK = SDK_Discarded); + Transform *TransformTransform(Transform *T) { return T; } + + TransformClause *TransformTransformClause(TransformClause *S); + + TransformClause *TransformFullClause(FullClause *C) { + return getDerived().RebuildFullClause(C->getLoc()); + } + + TransformClause *RebuildFullClause(SourceRange Loc) { + return getSema().ActOnFullClause(Loc); + } + + TransformClause *TransformPartialClause(PartialClause *C) { + ExprResult E = getDerived().TransformExpr(C->getFactor()); + if (E.isInvalid()) + return nullptr; + return getDerived().RebuildPartialClause(C->getLoc(), E.get()); + } + + TransformClause *RebuildPartialClause(SourceRange Loc, Expr *Factor) { + return getSema().ActOnPartialClause(Loc, Factor); + } + + TransformClause *TransformWidthClause(WidthClause *C) { + ExprResult E = getDerived().TransformExpr(C->getWidth()); + if (E.isInvalid()) + return nullptr; + return getDerived().RebuildWidthClause(C->getLoc(), E.get()); + } + + TransformClause *RebuildWidthClause(SourceRange Loc, Expr *Width) { + return getSema().ActOnWidthClause(Loc, Width); + } + + TransformClause *TransformFactorClause(FactorClause *C) { + ExprResult E = getDerived().TransformExpr(C->getFactor()); + if (E.isInvalid()) + return nullptr; +
return getDerived().RebuildFactorClause(C->getLoc(), E.get()); + } + + TransformClause *RebuildFactorClause(SourceRange Loc, Expr *Factor) { + return getSema().ActOnFactorClause(Loc, Factor); + } + /// Transform the given statement. /// /// By default, this routine transforms a statement by delegating to the @@ -1503,6 +1549,17 @@ return getSema().BuildObjCAtThrowStmt(AtLoc, Operand); } + StmtResult + RebuildTransformExecutableDirective(Transform::Kind Kind, Transform *Trans, + llvm::ArrayRef Clauses, + Stmt *AStmt, SourceRange Loc) { + StmtResult Result = + getSema().ActOnLoopTransformDirective(Kind, Trans, Clauses, AStmt, Loc); + assert(!Result.isUsable() || + isa(Result.get())); + return Result; + } + /// Build a new OpenMP executable directive. /// /// By default, performs semantic analysis to build the new statement. @@ -7853,6 +7910,47 @@ return S; } +template +TransformClause * +TreeTransform::TransformTransformClause(TransformClause *S) { + if (!S) + return S; + + switch (S->getKind()) { +#define TRANSFORM_CLAUSE(Keyword, Name) \ + case TransformClause::Kind::Name##Kind: \ + return getDerived().Transform##Name##Clause(cast(S)); +#include "clang/AST/TransformKinds.def" + case TransformClause::Kind::UnknownKind: + llvm_unreachable("Should not be unknown"); + } + + return S; +} + +template +StmtResult TreeTransform::TransformTransformExecutableDirective( + TransformExecutableDirective *D) { + llvm::SmallVector TClauses; + for (TransformClause *C : D->clauses()) { + TransformClause *TClause = getDerived().TransformTransformClause(C); + // A null clause means its transformation failed; do not pass it to Sema. + if (!TClause) + return StmtError(); + TClauses.push_back(TClause); + } + + Stmt *AStmt = D->getAssociated(); + StmtResult TBody = getDerived().TransformStmt(AStmt); + if (!TBody.isUsable()) + return StmtError(); + + Transform *Trans = getDerived().TransformTransform(D->getTransform()); + StmtResult TDirective = getDerived().RebuildTransformExecutableDirective( + D->getTransformKind(), Trans, TClauses, TBody.get(), D->getLoc()); + return TDirective; +} +
//===----------------------------------------------------------------------===// // OpenMP directive transformation //===----------------------------------------------------------------------===// Index: lib/Serialization/ASTReaderStmt.cpp =================================================================== --- lib/Serialization/ASTReaderStmt.cpp +++ lib/Serialization/ASTReaderStmt.cpp @@ -2000,6 +2000,15 @@ E->SrcExpr = Record.readSubExpr(); } +//===----------------------------------------------------------------------===// +// Transformation Directives. +//===----------------------------------------------------------------------===// + +void ASTStmtReader::VisitTransformExecutableDirective( + TransformExecutableDirective *D) { + llvm_unreachable("not implemented"); +} + //===----------------------------------------------------------------------===// // OpenMP Directives. //===----------------------------------------------------------------------===// Index: lib/Serialization/ASTWriterStmt.cpp =================================================================== --- lib/Serialization/ASTWriterStmt.cpp +++ lib/Serialization/ASTWriterStmt.cpp @@ -1943,6 +1943,14 @@ Record.AddSourceLocation(S->getLeaveLoc()); Code = serialization::STMT_SEH_LEAVE; } +//===----------------------------------------------------------------------===// +// Transformation Directives. +//===----------------------------------------------------------------------===// + +void ASTStmtWriter::VisitTransformExecutableDirective( + TransformExecutableDirective *D) { + llvm_unreachable("not implemented"); +} //===----------------------------------------------------------------------===// // OpenMP Directives. 
Index: test/AST/ast-dump-transform-unroll.c
===================================================================
--- /dev/null
+++ test/AST/ast-dump-transform-unroll.c
@@ -0,0 +1,27 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fopenmp -ast-dump %s | FileCheck %s
+
+void unroll_full(int n) {
+#pragma clang transform unroll full
+  for (int i = 0; i < 4; i+=1)
+    ;
+}
+
+// CHECK-LABEL: FunctionDecl {{.*}} unroll_full
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: LoopUnrollingTransform
+// CHECK-NEXT: FullClause
+// CHECK-NEXT: ForStmt
+
+
+void unroll_partial(int n) {
+#pragma clang transform unroll partial(4)
+  for (int i = 0; i < n; i+=1)
+    ;
+}
+
+// CHECK-LABEL: FunctionDecl {{.*}} unroll_partial
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: LoopUnrollingTransform
+// CHECK-NEXT: PartialClause
+// CHECK-NEXT: IntegerLiteral
+// CHECK-NEXT: ForStmt
Index: test/AST/ast-print-pragma-transform-distribute.cpp
===================================================================
--- /dev/null
+++ test/AST/ast-print-pragma-transform-distribute.cpp
@@ -0,0 +1,9 @@
+// RUN: %clang_cc1 -ast-print %s -o - | FileCheck %s
+
+void distribute(int *List, int Length) {
+// CHECK: #pragma clang transform distribute
+#pragma clang transform distribute
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    List[i] = i;
+}
Index: test/AST/ast-print-pragma-transform-interleave.cpp
===================================================================
--- /dev/null
+++ test/AST/ast-print-pragma-transform-interleave.cpp
@@ -0,0 +1,9 @@
+// RUN: %clang_cc1 -ast-print %s -o - | FileCheck %s
+
+void interleave(int *List, int Length) {
+// CHECK: #pragma clang transform interleave factor(4)
+#pragma clang transform interleave factor(4)
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    List[i] = i;
+}
Index: test/AST/ast-print-pragma-transform-unroll.cpp
===================================================================
--- /dev/null
+++ test/AST/ast-print-pragma-transform-unroll.cpp
@@ -0,0 +1,15 @@
+// RUN: %clang_cc1 -ast-print %s -o - | FileCheck %s
+
+void unroll(int *List, int Length) {
+// CHECK: #pragma clang transform unroll partial(4)
+#pragma clang transform unroll partial(4)
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    List[i] = i;
+
+// CHECK: #pragma clang transform unroll full
+#pragma clang transform unroll full
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    List[i] = i;
+}
Index: test/AST/ast-print-pragma-transform-unrollandjam.cpp
===================================================================
--- /dev/null
+++ test/AST/ast-print-pragma-transform-unrollandjam.cpp
@@ -0,0 +1,10 @@
+// RUN: %clang_cc1 -ast-print %s -o - | FileCheck %s
+
+void unrollandjam(int *List, int Length) {
+// CHECK: #pragma clang transform unrollandjam partial(4)
+#pragma clang transform unrollandjam partial(4)
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    for (int j = 0; j < Length; j += 1)
+      List[i] += j;
+}
Index: test/AST/ast-print-pragma-transform-vectorize.cpp
===================================================================
--- /dev/null
+++ test/AST/ast-print-pragma-transform-vectorize.cpp
@@ -0,0 +1,9 @@
+// RUN: %clang_cc1 -ast-print %s -o - | FileCheck %s
+
+void vectorize(int *List, int Length) {
+// CHECK: #pragma clang transform vectorize width(4)
+#pragma clang transform vectorize width(4)
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    List[i] = i;
+}
Index: test/SemaCXX/pragma-transform-interleave.cpp
===================================================================
--- /dev/null
+++ test/SemaCXX/pragma-transform-interleave.cpp
@@ -0,0 +1,8 @@
+// RUN: %clang_cc1 -std=c++11 -fsyntax-only -verify %s
+
+void interleave(int *List, int Length, int Value) {
+/* expected-error@+1 {{interleave factor clause can only be used once}} */
+#pragma clang transform interleave factor(4) factor(4)
+  for (int i = 0; i < Length; i++)
+    List[i] = Value;
+}
Index: test/SemaCXX/pragma-transform-unroll.cpp
===================================================================
--- /dev/null
+++ test/SemaCXX/pragma-transform-unroll.cpp
@@ -0,0 +1,36 @@
+// RUN: %clang_cc1 -std=c++11 -fsyntax-only -verify %s
+
+void unroll(int *List, int Length, int Value) {
+/* expected-error@+1 {{unroll factor must be at least two}} */
+#pragma clang transform unroll partial(0)
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{unroll full/partial clauses specified multiple times}} */
+#pragma clang transform unroll full full
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{unroll full/partial clauses specified multiple times}} */
+#pragma clang transform unroll partial(4) partial(4)
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{unroll full/partial clauses specified multiple times}} */
+#pragma clang transform unroll full partial(4)
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{transformation did not find its loop to transform}} */
+#pragma clang transform unroll
+#pragma clang transform unroll full
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+int f = 4;
+  /* expected-error@+1 {{unroll partial clause expects an int}} */
+#pragma clang transform unroll partial(f)
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+
+}
Index: test/SemaCXX/pragma-transform-unrollandjam.cpp
===================================================================
--- /dev/null
+++ test/SemaCXX/pragma-transform-unrollandjam.cpp
@@ -0,0 +1,30 @@
+// RUN: %clang_cc1 -std=c++11 -fsyntax-only -verify %s
+
+void unrollandjam(int *List, int Length, int Value) {
+/* expected-error@+1 {{unroll-and-jam partial clause can only be used once}} */
+#pragma clang transform unrollandjam partial(4) partial(4)
+  for (int i = 0; i < Length; i++)
+    for (int j = 0; j < Length; j++)
+      List[i] += j*Value;
+
+/* expected-error@+1 {{unroll-and-jam requires exactly one nested loop}} */
+#pragma clang transform unrollandjam
+  for (int i = 0; i < Length; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{unroll-and-jam requires exactly one nested loop}} */
+#pragma clang transform unrollandjam
+  for (int i = 0; i < Length; i++) {
+    for (int j = 0; j < Length; j++)
+      List[i] += j*Value;
+    for (int j = 0; j < Length; j++)
+      List[i] += j*Value;
+  }
+
+/* expected-error@+1 {{inner loop of unroll-and-jam is not the innermost}} */
+#pragma clang transform unrollandjam
+  for (int i = 0; i < Length; i++)
+    for (int j = 0; j < Length; j++)
+      for (int k = 0; k < Length; k++)
+        List[i] += k + j*Value;
+}
Index: test/SemaCXX/pragma-transform-vectorize.cpp
===================================================================
--- /dev/null
+++ test/SemaCXX/pragma-transform-vectorize.cpp
@@ -0,0 +1,8 @@
+// RUN: %clang_cc1 -std=c++11 -fsyntax-only -verify %s
+
+void vectorize(int *List, int Length, int Value) {
+/* expected-error@+1 {{vectorize width clause can only be used once}} */
+#pragma clang transform vectorize width(4) width(4)
+  for (int i = 0; i < Length; i++)
+    List[i] = Value;
+}
Index: test/SemaCXX/pragma-transform-wrongorder.cpp
===================================================================
--- /dev/null
+++ test/SemaCXX/pragma-transform-wrongorder.cpp
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -std=c++11 -fsyntax-only -verify %s
+
+void wrongorder(int *List, int Length, int Value) {
+
+/* expected-warning@+1 {{the LLVM pass structure currently is not able to apply the transformations in this order}} */
+#pragma clang transform distribute
+#pragma clang transform vectorize
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-warning@+1 {{the LLVM pass structure currently is not able to apply the transformations in this order}} */
+#pragma clang transform vectorize
+#pragma clang transform unrollandjam
+  for (int i = 0; i < 8; i++)
+    for (int j = 0; j < 8; j++)
+      List[i] += j;
+
+}
Index: tools/libclang/CXCursor.cpp
===================================================================
--- tools/libclang/CXCursor.cpp
+++ tools/libclang/CXCursor.cpp
@@ -725,6 +725,9 @@
     break;
   case Stmt::BuiltinBitCastExprClass:
     K = CXCursor_BuiltinBitCastExpr;
+    break;
+  case Stmt::TransformExecutableDirectiveClass:
+    K = CXCursor_UnexposedStmt; // No dedicated cursor kind yet; must not crash.
   }
 
   CXCursor C = { K, 0, { Parent, S, TU } };