diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h
--- a/clang/include/clang/AST/ASTNodeTraverser.h
+++ b/clang/include/clang/AST/ASTNodeTraverser.h
@@ -49,6 +49,7 @@
   void Visit(const OMPClause *C);
   void Visit(const BlockDecl::Capture &C);
   void Visit(const GenericSelectionExpr::ConstAssociation &A);
+  void Visit(const TransformClause *C);
 };
 */
 template <typename Derived, typename NodeDelegateType>
@@ -205,6 +206,16 @@
     });
   }
 
+  /// Dump a transform clause as a child node, then recurse into the
+  /// statements the clause reports through children().
+  void Visit(const TransformClause *C) {
+    getNodeDelegate().AddChild([=] {
+      getNodeDelegate().Visit(C);
+      for (const auto *S : C->children())
+        Visit(S);
+    });
+  }
+
   void Visit(const ast_type_traits::DynTypedNode &N) {
     // FIXME: Improve this with a switch or a visitor pattern.
     if (const auto *D = N.get<Decl>())
@@ -221,6 +232,8 @@
       Visit(C);
     else if (const auto *T = N.get<TemplateArgument>())
       Visit(*T);
+    else if (const auto *C = N.get<TransformClause>())
+      Visit(C);
   }
 
   void dumpDeclContext(const DeclContext *DC) {
@@ -603,6 +616,13 @@
       Visit(C);
   }
 
+  /// Visit every clause attached to a transform directive.
+  void
+  VisitTransformExecutableDirective(const TransformExecutableDirective *Node) {
+    for (const auto *C : Node->clauses())
+      Visit(C);
+  }
+
   void VisitInitListExpr(const InitListExpr *ILE) {
     if (auto *Filler = ILE->getArrayFiller()) {
       Visit(Filler, "array_filler");
diff --git a/clang/include/clang/AST/ASTTypeTraits.h b/clang/include/clang/AST/ASTTypeTraits.h
--- a/clang/include/clang/AST/ASTTypeTraits.h
+++ b/clang/include/clang/AST/ASTTypeTraits.h
@@ -20,6 +20,7 @@
 #include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/OpenMPClause.h"
 #include "clang/AST/Stmt.h"
+#include "clang/AST/StmtTransform.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/Basic/LLVM.h"
@@ -71,6 +72,7 @@
   static ASTNodeKind getFromNode(const Stmt &S);
   static ASTNodeKind getFromNode(const Type &T);
   static ASTNodeKind getFromNode(const OMPClause &C);
+  static ASTNodeKind getFromNode(const TransformClause &C);
   /// \}
 
   /// Returns \c true if \c this and \c Other represent the same kind.
@@ -152,6 +154,9 @@
     NKI_OMPClause,
 #define OPENMP_CLAUSE(TextualSpelling, Class) NKI_##Class,
 #include "clang/Basic/OpenMPKinds.def"
+    NKI_TransformClause,
+#define TRANSFORM_CLAUSE(Keyword, Name) NKI_##Name##Clause,
+#include "clang/AST/TransformClauseKinds.def"
     NKI_NumberOfKinds
   };
 
@@ -208,6 +213,8 @@
 #include "clang/AST/TypeNodes.inc"
 #define OPENMP_CLAUSE(TextualSpelling, Class) KIND_TO_KIND_ID(Class)
 #include "clang/Basic/OpenMPKinds.def"
+#define TRANSFORM_CLAUSE(Keyword, Name) KIND_TO_KIND_ID(Name##Clause)
+#include "clang/AST/TransformClauseKinds.def"
 #undef KIND_TO_KIND_ID
 
 inline raw_ostream &operator<<(raw_ostream &OS, ASTNodeKind K) {
diff --git a/clang/include/clang/AST/JSONNodeDumper.h b/clang/include/clang/AST/JSONNodeDumper.h
--- a/clang/include/clang/AST/JSONNodeDumper.h
+++ b/clang/include/clang/AST/JSONNodeDumper.h
@@ -198,6 +198,8 @@
   void Visit(const OMPClause *C);
   void Visit(const BlockDecl::Capture &C);
   void Visit(const GenericSelectionExpr::ConstAssociation &A);
+  void Visit(const TransformClause *C);
+  void Visit(const Transform *T);
 
   void VisitTypedefType(const TypedefType *TT);
   void VisitFunctionType(const FunctionType *T);
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -33,6 +33,7 @@
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtObjC.h"
 #include "clang/AST/StmtOpenMP.h"
+#include "clang/AST/StmtTransform.h"
 #include "clang/AST/TemplateBase.h"
 #include "clang/AST/TemplateName.h"
 #include "clang/AST/Type.h"
@@ -536,6 +537,12 @@
   bool VisitOMPClauseWithPreInit(OMPClauseWithPreInit *Node);
   bool VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *Node);
 
+  /// Traverse a clause of a '#pragma clang transform' directive.
+  bool TraverseTransformClause(TransformClause *C);
+#define TRANSFORM_CLAUSE(Keyword, Name) \
+  bool Visit##Name##Clause(Name##Clause *C);
+#include "clang/AST/TransformClauseKinds.def"
+
   bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue);
   bool PostVisitStmt(Stmt *S);
 };
@@ -2691,6 +2698,12 @@
 // Traverse OpenCL: AsType, Convert.
 DEF_TRAVERSE_STMT(AsTypeExpr, {})
 
+// Transform directives.
+DEF_TRAVERSE_STMT(TransformExecutableDirective, {
+  for (auto *C : S->clauses())
+    TRY_TO(TraverseTransformClause(C));
+})
+
 // OpenMP directives.
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::TraverseOMPExecutableDirective(
@@ -2862,6 +2875,24 @@
 DEF_TRAVERSE_STMT(OMPTargetTeamsDistributeSimdDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
+// Transform clauses: dispatch to the Visit<Name>Clause method that matches the
+// kind reported by the clause. A null clause is accepted and skipped.
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::TraverseTransformClause(TransformClause *C) {
+  if (!C)
+    return true;
+  switch (C->getKind()) {
+  case TransformClause::Kind::UnknownKind:
+    llvm_unreachable("Cannot process unknown clause");
+#define TRANSFORM_CLAUSE(Keyword, Name) \
+  case TransformClause::Kind::Name##Kind: \
+    TRY_TO(Visit##Name##Clause(static_cast<Name##Clause *>(C))); \
+    break;
+#include "clang/AST/TransformClauseKinds.def"
+  }
+  return true;
+}
+
 // OpenMP clauses.
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::TraverseOMPClause(OMPClause *C) {
@@ -3359,6 +3390,30 @@
   return true;
 }
 
+// Per-kind transform clause visitors: traverse the argument expression, if any.
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitFullClause(FullClause *C) {
+  return true;
+}
+
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitPartialClause(PartialClause *C) {
+  TRY_TO(TraverseStmt(C->getFactor()));
+  return true;
+}
+
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitWidthClause(WidthClause *C) {
+  TRY_TO(TraverseStmt(C->getWidth()));
+  return true;
+}
+
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitFactorClause(FactorClause *C) {
+  TRY_TO(TraverseStmt(C->getFactor()));
+  return true;
+}
+
 // FIXME: look at the following tricky-seeming exprs to see if we
 // need to recurse on anything.
These are ones that have methods
 // returning decls or qualtypes or nestednamespecifier -- though I'm
diff --git a/clang/include/clang/AST/StmtTransform.h b/clang/include/clang/AST/StmtTransform.h
--- a/clang/include/clang/AST/StmtTransform.h
+++ b/clang/include/clang/AST/StmtTransform.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CLANG_AST_STMTTRANSFROM_H
 #define LLVM_CLANG_AST_STMTTRANSFROM_H
 
+#include "clang/AST/Expr.h"
 #include "clang/AST/Stmt.h"
 #include "llvm/Support/raw_ostream.h"
 #include "clang/Basic/Transform.h"
@@ -34,7 +35,170 @@
   static Kind getClauseKind(Transform::Kind TransformKind, llvm::StringRef Str);
   static llvm::StringRef getClauseKeyword(TransformClause::Kind ClauseKind);
 
-  // TODO: implement
+private:
+  /// Kind tag; the derived classes' classof() compare against it.
+  Kind ClauseKind;
+  /// Source range recorded for this clause.
+  SourceRange LocRange;
+
+protected:
+  TransformClause(Kind K, SourceRange Range) : ClauseKind(K), LocRange(Range) {}
+  /// Construct a clause whose source range is left default-constructed.
+  explicit TransformClause(Kind K) : ClauseKind(K) {}
+
+public:
+  Kind getKind() const { return ClauseKind; }
+
+  SourceRange getRange() const { return LocRange; }
+  SourceLocation getBeginLoc() const { return LocRange.getBegin(); }
+  SourceLocation getEndLoc() const { return LocRange.getEnd(); }
+  void setLoc(SourceRange L) { LocRange = L; }
+  void setLoc(SourceLocation BeginLoc, SourceLocation EndLoc) {
+    LocRange = SourceRange(BeginLoc, EndLoc);
+  }
+
+  using child_iterator = Stmt::child_iterator;
+  using const_child_iterator = Stmt::const_child_iterator;
+  using child_range = Stmt::child_range;
+  using const_child_range = Stmt::const_child_range;
+
+  /// Sub-statements of this clause (declared here, defined out of line).
+  child_range children();
+  /// Const view of children(); forwards to the non-const overload.
+  const_child_range children() const {
+    auto Children = const_cast<TransformClause *>(this)->children();
+    return const_child_range(Children.begin(), Children.end());
+  }
+
+  static llvm::StringRef getClauseName(Kind K);
+
+  void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const;
+};
+
+/// Common base of the concrete clause classes. Supplies a children() that
+/// reports no sub-statements and a print() placeholder that must not be
+/// reached.
+class TransformClauseImpl : public TransformClause {
+protected:
+  TransformClauseImpl(Kind ClauseKind, SourceRange Range)
+      : TransformClause(ClauseKind, Range) {}
+  explicit TransformClauseImpl(Kind ClauseKind) : TransformClause(ClauseKind) {}
+
+public:
+  void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const {
+    llvm_unreachable("implement the print function");
+  }
+
+  child_range children() {
+    // By default, clauses have no children.
+    return child_range(child_iterator(), child_iterator());
+  }
+};
+
+/// Clause of kind FullKind; it carries no argument.
+class FullClause final : public TransformClauseImpl {
+private:
+  explicit FullClause(SourceRange Range)
+      : TransformClauseImpl(TransformClause::FullKind, Range) {}
+  FullClause() : TransformClauseImpl(TransformClause::FullKind) {}
+
+public:
+  static bool classof(const FullClause *T) { return true; }
+  static bool classof(const TransformClause *T) {
+    return T->getKind() == FullKind;
+  }
+
+  static FullClause *create(ASTContext &Context, SourceRange Range) {
+    return new (Context) FullClause(Range);
+  }
+  static FullClause *createEmpty(ASTContext &Context) {
+    return new (Context) FullClause();
+  }
+
+  void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const;
+};
+
+/// Clause of kind PartialKind; it carries one factor expression.
+class PartialClause final : public TransformClauseImpl {
+private:
+  /// The factor argument, stored as a Stmt so that children() can expose it.
+  /// Null until set when the clause was made with createEmpty().
+  Stmt *Factor = nullptr;
+
+  PartialClause(SourceRange Range, Expr *Factor)
+      : TransformClauseImpl(TransformClause::PartialKind, Range),
+        Factor(Factor) {}
+  PartialClause() : TransformClauseImpl(TransformClause::PartialKind) {}
+
+public:
+  static bool classof(const PartialClause *T) { return true; }
+  static bool classof(const TransformClause *T) {
+    return T->getKind() == PartialKind;
+  }
+
+  static PartialClause *create(ASTContext &Context, SourceRange Range,
+                               Expr *Factor) {
+    return new (Context) PartialClause(Range, Factor);
+  }
+  static PartialClause *createEmpty(ASTContext &Context) {
+    return new (Context) PartialClause();
+  }
+
+  child_range children() { return child_range(&Factor, &Factor + 1); }
+
+  /// Returns the factor expression, or null if none has been set yet.
+  Expr *getFactor() const { return cast_or_null<Expr>(Factor); }
+  void setFactor(Expr *E) { Factor = E; }
+
+  void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const;
+};
+
+/// Clause of kind WidthKind; it carries one width expression.
+class WidthClause final : public TransformClauseImpl {
+private:
+  /// The width argument, stored as a Stmt so that children() can expose it.
+  /// Null until set when the clause was made with createEmpty().
+  Stmt *Width = nullptr;
+
+  WidthClause(SourceRange Range, Expr *Width)
+      : TransformClauseImpl(TransformClause::WidthKind, Range), Width(Width) {}
+  WidthClause() : TransformClauseImpl(TransformClause::WidthKind) {}
+
+public:
+  static bool classof(const WidthClause *T) { return true; }
+  static bool classof(const TransformClause *T) {
+    return T->getKind() == WidthKind;
+  }
+
+  static WidthClause *create(ASTContext &Context, SourceRange Range,
+                             Expr *Width) {
+    return new (Context) WidthClause(Range, Width);
+  }
+  static WidthClause *createEmpty(ASTContext &Context) {
+    return new (Context) WidthClause();
+  }
+
+  child_range children() { return child_range(&Width, &Width + 1); }
+
+  /// Returns the width expression, or null if none has been set yet.
+  Expr *getWidth() const { return cast_or_null<Expr>(Width); }
+  void setWidth(Expr *E) { Width = E; }
+
+  void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const;
+};
+
+/// Clause of kind FactorKind; it carries one factor expression.
+class FactorClause final : public TransformClauseImpl {
+private:
+  /// The factor argument, stored as a Stmt so that children() can expose it.
+  /// Null until set when the clause was made with createEmpty().
+  Stmt *Factor = nullptr;
+
+  FactorClause(SourceRange Range, Expr *Factor)
+      : TransformClauseImpl(TransformClause::FactorKind, Range),
+        Factor(Factor) {}
+  FactorClause() : TransformClauseImpl(TransformClause::FactorKind) {}
+
+public:
+  static bool classof(const FactorClause *T) { return true; }
+  static bool classof(const TransformClause *T) {
+    return T->getKind() == FactorKind;
+  }
+
+  static FactorClause *create(ASTContext &Context, SourceRange Range,
+                              Expr *Factor) {
+    return new (Context) FactorClause(Range, Factor);
+  }
+  static FactorClause *createEmpty(ASTContext &Context) {
+    return new (Context) FactorClause();
+  }
+
+  child_range children() { return child_range(&Factor, &Factor + 1); }
+
+  /// Returns the factor expression, or null if none has been set yet.
+  Expr *getFactor() const { return cast_or_null<Expr>(Factor); }
+  void setFactor(Expr *E) { Factor = E; }
+
+  void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const;
 };
 
 /// Represents
@@ -42,8 +206,104 @@
 /// #pragma clang transform
 ///
 /// in the AST.
-class TransformExecutableDirective final {
-  // TODO: implement
+class TransformExecutableDirective final
+    : public Stmt,
+      private llvm::TrailingObjects<TransformExecutableDirective,
+                                    TransformClause *> {
+public:
+  friend TransformClause;
+  friend TrailingObjects;
+
+private:
+  /// Source range recorded for the directive.
+  SourceRange LocRange;
+  /// The statement the directive is attached to; the only entry of children().
+  Stmt *Associated = nullptr;
+  Transform::Kind TransKind = Transform::Kind::UnknownKind;
+  /// Number of TransformClause pointers kept as trailing objects.
+  unsigned NumClauses;
+
+protected:
+  explicit TransformExecutableDirective(SourceRange LocRange, Stmt *Associated,
+                                        ArrayRef<TransformClause *> Clauses,
+                                        Transform::Kind TransKind)
+      : Stmt(Stmt::TransformExecutableDirectiveClass), LocRange(LocRange),
+        Associated(Associated), TransKind(TransKind),
+        NumClauses(Clauses.size()) {
+    setClauses(Clauses);
+  }
+  /// Construct a directive with room for \p NumClauses clauses; the clause
+  /// slots themselves are not written by this constructor.
+  explicit TransformExecutableDirective(unsigned NumClauses)
+      : Stmt(Stmt::StmtClass::TransformExecutableDirectiveClass),
+        NumClauses(NumClauses) {}
+
+  size_t numTrailingObjects(OverloadToken<TransformClause *>) const {
+    return NumClauses;
+  }
+
+public:
+  static bool classof(const TransformExecutableDirective *T) { return true; }
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == Stmt::TransformExecutableDirectiveClass;
+  }
+
+  static TransformExecutableDirective *
+  create(ASTContext &Ctx, SourceRange Range, Stmt *Associated,
+         ArrayRef<TransformClause *> Clauses, Transform::Kind TransKind);
+  static TransformExecutableDirective *createEmpty(ASTContext &Ctx,
+                                                   unsigned NumClauses);
+
+  SourceRange getRange() const { return LocRange; }
+  SourceLocation getBeginLoc() const { return LocRange.getBegin(); }
+  SourceLocation getEndLoc() const { return LocRange.getEnd(); }
+  void setRange(SourceRange Loc) { LocRange = Loc; }
+  void setRange(SourceLocation BeginLoc, SourceLocation EndLoc) {
+    LocRange = SourceRange(BeginLoc, EndLoc);
+  }
+
+  Stmt *getAssociated() const { return Associated; }
+  void setAssociated(Stmt *S) { Associated = S; }
+
+  Transform::Kind getTransformKind() const { return TransKind; }
+
+  child_range children() { return child_range(&Associated, &Associated + 1); }
+  const_child_range children() const {
+    return const_child_range(&Associated, &Associated + 1);
+  }
+
+  unsigned getNumClauses() const { return NumClauses; }
+  MutableArrayRef<TransformClause *> clauses() {
+    return llvm::makeMutableArrayRef(getTrailingObjects<TransformClause *>(),
+                                     NumClauses);
+  }
+  ArrayRef<TransformClause *> clauses() const {
+    return llvm::makeArrayRef(getTrailingObjects<TransformClause *>(),
+                              NumClauses);
+  }
+
+  /// Overwrite all clause slots; \p List must hold exactly NumClauses entries.
+  void setClauses(llvm::ArrayRef<TransformClause *> List) {
+    assert(List.size() == NumClauses);
+    for (auto p : llvm::zip_first(List, clauses()))
+      std::get<1>(p) = std::get<0>(p);
+  }
+
+  /// Range over the clauses whose getKind() equals \p Kind.
+  auto getClausesOfKind(TransformClause::Kind Kind) const {
+    return llvm::make_filter_range(clauses(), [Kind](TransformClause *Clause) {
+      return Clause->getKind() == Kind;
+    });
+  }
+
+  /// Range over the clauses that are a SpecificClause, already cast to it.
+  template <typename SpecificClause> auto getClausesOfKind() const {
+    return llvm::map_range(
+        llvm::make_filter_range(clauses(),
+                                [](TransformClause *Clause) {
+                                  return isa<SpecificClause>(Clause);
+                                }),
+        [](TransformClause *Clause) { return cast<SpecificClause>(Clause); });
+  }
+
+  /// First clause that is a SpecificClause, or nullptr if there is none.
+  template <typename SpecificClause>
+  SpecificClause *getFirstClauseOfKind() const {
+    auto Range = getClausesOfKind<SpecificClause>();
+    if (Range.begin() == Range.end())
+      return nullptr;
+    return *Range.begin();
+  }
 };
 
 const Stmt *getAssociatedLoop(const Stmt *S);
diff --git a/clang/include/clang/AST/TextNodeDumper.h b/clang/include/clang/AST/TextNodeDumper.h
--- a/clang/include/clang/AST/TextNodeDumper.h
+++ b/clang/include/clang/AST/TextNodeDumper.h
@@ -172,6 +172,10 @@
 
   void Visit(const OMPClause *C);
 
+  void VisitTransformExecutableDirective(const TransformExecutableDirective *S);
+
+  void Visit(const TransformClause *C);
+
   void Visit(const BlockDecl::Capture &C);
 
   void Visit(const GenericSelectionExpr::ConstAssociation &A);
diff --git a/clang/include/clang/Analysis/AnalysisTransform.h b/clang/include/clang/Analysis/AnalysisTransform.h
new file mode 100644
--- /dev/null
+++ b/clang/include/clang/Analysis/AnalysisTransform.h
@@ -0,0 +1,191 @@
+//===---- AnalysisTransform.h --------------------------------- -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with
LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Extract the transformation to apply from a #pragma clang transform AST node.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_ANALYSIS_ANALYSISTRANSFORM_H
+#define LLVM_CLANG_ANALYSIS_ANALYSISTRANSFORM_H
+
+#include "clang/AST/StmtTransform.h"
+#include "clang/Basic/DiagnosticSema.h"
+#include "clang/Basic/Transform.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include <cassert>
+#include <memory>
+
+namespace clang {
+
+/// Returns true if \p E is value-, type- or instantiation-dependent, or
+/// contains an unexpanded parameter pack.
+///
+/// Declared inline (not static) because this is a header: every including
+/// translation unit would otherwise get its own internal-linkage copy.
+inline bool isTemplateDependent(Expr *E) {
+  return E->isValueDependent() || E->isTypeDependent() ||
+         E->isInstantiationDependent() || E->containsUnexpandedParameterPack();
+}
+
+/// Extract which transformation to apply from a TransformExecutableDirective
+/// and its clauses.
+///
+/// CRTP base: \p Derived must provide Diag(SourceLocation, unsigned) returning
+/// an object that accepts diagnostic arguments through operator<<.
+template <typename Derived> struct ExtractTransform {
+  Derived &getDerived() { return *static_cast<Derived *>(this); }
+  const Derived &getDerived() const {
+    return *static_cast<const Derived *>(this);
+  }
+
+  ASTContext &ASTCtx;
+  const TransformExecutableDirective *Directive;
+  /// Set once any error has been reported through DiagError().
+  bool AnyError = false;
+  /// Set when a clause argument could not be evaluated because it is
+  /// template-dependent.
+  bool TemplateDependent = false;
+  ExtractTransform(ASTContext &ASTCtx,
+                   const TransformExecutableDirective *Directive)
+      : ASTCtx(ASTCtx), Directive(Directive) {
+    assert(Directive);
+  }
+
+  /// Record that an error occurred and forward to the derived class' Diag().
+  auto DiagError(SourceLocation Loc, unsigned DiagID) {
+    AnyError = true;
+    return getDerived().Diag(Loc, DiagID);
+  }
+
+  /// Returns the only clause of type \p ClauseTy, or nullptr if there is none.
+  /// A second clause of that type is diagnosed and the first one is returned.
+  template <typename ClauseTy> ClauseTy *assumeSingleClause() {
+    ClauseTy *Result = nullptr;
+    for (ClauseTy *C : Directive->getClausesOfKind<ClauseTy>()) {
+      if (Result) {
+        DiagError(C->getBeginLoc(), diag::err_sema_transform_clause_one_max)
+            << TransformClause::getClauseKeyword(C->getKind()) << C->getRange();
+        break;
+      }
+
+      Result = C;
+    }
+
+    return Result;
+  }
+
+  /// Evaluate \p E as an integer constant that is at least \p MinVal.
+  ///
+  /// Returns None if \p E is template-dependent (TemplateDependent is set and
+  /// nothing is diagnosed), if it does not evaluate to an integer, or if the
+  /// value is below \p MinVal; the latter two cases are diagnosed.
+  llvm::Optional<int64_t> evalIntArg(Expr *E, int MinVal) {
+    if (isTemplateDependent(E)) {
+      TemplateDependent = true;
+      return None;
+    }
+    Expr::EvalResult Res;
+    // Check the evaluation result itself instead of relying only on the state
+    // Res was left in.
+    if (!E->EvaluateAsInt(Res, ASTCtx) || !Res.Val.isInt()) {
+      DiagError(E->getExprLoc(),
+                diag::err_sema_transform_clause_arg_expect_int);
+      return None;
+    }
+    llvm::APSInt Val = Res.Val.getInt();
+    int64_t Int = Val.getSExtValue();
+    if (Int < MinVal) {
+      DiagError(E->getExprLoc(), diag::err_sema_transform_clause_arg_min_val)
+          << MinVal << SourceRange(E->getBeginLoc(), E->getEndLoc());
+      return None;
+    }
+    return Int;
+  }
+
+  /// Debug-build check that every clause on the directive has a kind listed in
+  /// \p ClauseKinds. Does nothing when NDEBUG is defined.
+  void allowedClauses(ArrayRef<TransformClause::Kind> ClauseKinds) {
+#ifndef NDEBUG
+    for (TransformClause *C : Directive->clauses()) {
+      assert(llvm::find(ClauseKinds, C->getKind()) != ClauseKinds.end() &&
+             "Parser must have rejected unknown clause");
+    }
+#endif
+  }
+
+  /// Take ownership of a raw Transform pointer.
+  static std::unique_ptr<Transform> wrap(Transform *Trans) {
+    return std::unique_ptr<Transform>(Trans);
+  }
+
+  /// Build the Transform described by the directive and its clauses.
+  /// Returns nullptr if an error was diagnosed or a clause argument could not
+  /// be evaluated.
+  std::unique_ptr<Transform> createTransform() {
+    Transform::Kind Kind = Directive->getTransformKind();
+    SourceRange Loc = Directive->getRange();
+
+    switch (Kind) {
+    case Transform::LoopUnrollingKind: {
+      allowedClauses({TransformClause::FullKind, TransformClause::PartialKind});
+      FullClause *Full = assumeSingleClause<FullClause>();
+      PartialClause *Partial = assumeSingleClause<PartialClause>();
+      if (Full && Partial)
+        DiagError(Full->getBeginLoc(),
+                  diag::err_sema_transform_unroll_full_or_partial);
+
+      if (AnyError)
+        return nullptr;
+
+      if (Full)
+        return wrap(LoopUnrollingTransform::createFull(Loc, false, true, true));
+
+      if (Partial) {
+        llvm::Optional<int64_t> Factor = evalIntArg(Partial->getFactor(), 2);
+        if (AnyError || !Factor.hasValue())
+          return nullptr;
+        return wrap(LoopUnrollingTransform::createPartial(
+            Loc, false, true, true, Factor.getValue()));
+      }
+
+      return wrap(LoopUnrollingTransform::create(Loc, false, true, true));
+    }
+
+    case Transform::LoopUnrollAndJamKind: {
+      allowedClauses({TransformClause::PartialKind});
+      PartialClause *Partial = assumeSingleClause<PartialClause>();
+
+      if (AnyError)
+        return nullptr;
+
+      if (Partial) {
+        llvm::Optional<int64_t> Factor = evalIntArg(Partial->getFactor(), 2);
+        if (AnyError || !Factor.hasValue())
+          return nullptr;
+        return wrap(LoopUnrollAndJamTransform::createPartial(
+            Loc, false, true, Factor.getValue()));
+      }
+
+      return wrap(LoopUnrollAndJamTransform::create(Loc, false, true));
+    }
+
+    case clang::Transform::LoopDistributionKind:
+      allowedClauses({});
+      return wrap(LoopDistributionTransform::create(Loc, false));
+
+    case clang::Transform::LoopVectorizationKind: {
+      allowedClauses({TransformClause::WidthKind});
+      WidthClause *Width = assumeSingleClause<WidthClause>();
+      if (AnyError)
+        return nullptr;
+
+      // 0 is passed on when no width clause is present.
+      int64_t Simdlen = 0;
+      if (Width) {
+        llvm::Optional<int64_t> WidthInt = evalIntArg(Width->getWidth(), 2);
+        if (AnyError || !WidthInt.hasValue())
+          return nullptr;
+        Simdlen = WidthInt.getValue();
+      }
+
+      return wrap(LoopVectorizationTransform::Create(Loc, true, Simdlen, None));
+    }
+
+    case clang::Transform::LoopInterleavingKind: {
+      allowedClauses({TransformClause::FactorKind});
+      FactorClause *Factor = assumeSingleClause<FactorClause>();
+      if (AnyError)
+        return nullptr;
+
+      // 0 is passed on when no factor clause is present.
+      int64_t InterleaveFactor = 0;
+      if (Factor) {
+        llvm::Optional<int64_t> FactorInt = evalIntArg(Factor->getFactor(), 2);
+        if (AnyError || !FactorInt.hasValue())
+          return nullptr;
+        InterleaveFactor = FactorInt.getValue();
+      }
+
+      return wrap(
+          LoopInterleavingTransform::Create(Loc, true, InterleaveFactor));
+    }
+    default:
+      llvm_unreachable("unimplemented");
+    }
+  }
+};
+
+} // namespace clang
+#endif /* LLVM_CLANG_ANALYSIS_ANALYSISTRANSFORM_H */
diff --git a/clang/include/clang/Analysis/TransformedTree.h b/clang/include/clang/Analysis/TransformedTree.h
new file mode 100644
--- /dev/null
+++ b/clang/include/clang/Analysis/TransformedTree.h
@@ -0,0 +1,1316 @@
+//===--- TransformedTree.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Applies code transformations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_ANALYSIS_TRANSFORMEDTREE_H
+#define LLVM_CLANG_ANALYSIS_TRANSFORMEDTREE_H
+
+#include "clang/AST/OpenMPClause.h"
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/AST/Stmt.h"
+#include "clang/AST/StmtOpenMP.h"
+#include "clang/Analysis/AnalysisTransform.h"
+#include "clang/Basic/DiagnosticSema.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace clang {
+template <typename Derived, typename NodeTy> class TransformedTreeBuilder;
+
+/// ExtractTransform instantiation that discards every diagnostic; errors are
+/// still recorded in AnyError by the base class.
+struct DefaultExtractTransform : ExtractTransform<DefaultExtractTransform> {
+  DefaultExtractTransform(ASTContext &ASTCtx,
+                          const TransformExecutableDirective *Directive)
+      : ExtractTransform(ASTCtx, Directive) {}
+
+  // Ignore any diagnostic and its arguments.
+  struct DummyDiag {
+    template <typename T> DummyDiag operator<<(const T &) const { return {}; }
+  };
+  DummyDiag Diag(SourceLocation Loc, unsigned DiagID) { return {}; }
+};
+
+/// Represents an input of a code transformation.
+/// Currently, the input code can only be referenced by its AST node; in the
+/// future, loops may also be given identifiers to reference them.
+class TransformInput {
+  const Stmt *StmtInput = nullptr;
+
+  explicit TransformInput(const Stmt *StmtInput) : StmtInput(StmtInput) {}
+
+public:
+  /// Construct an input that references nothing (isByStmt() is false).
+  TransformInput() = default;
+
+  /// Construct an input that references \p StmtInput, which must not be null.
+  static TransformInput createByStmt(const Stmt *StmtInput) {
+    assert(StmtInput);
+    return TransformInput(StmtInput);
+  }
+
+  bool isByStmt() const { return StmtInput != nullptr; }
+
+  const Stmt *getStmtInput() const { return StmtInput; }
+};
+
+/// Represents a transformation together with the input loops.
+/// In the future it will also identify the generated loop.
+struct NodeTransform {
+  Transform *Trans = nullptr;
+  // NOTE(review): the inline element count of this SmallVector was lost from
+  // the patch text; 2 is a placeholder to be confirmed.
+  llvm::SmallVector<TransformInput, 2> Inputs;
+
+  NodeTransform() {}
+  /// Create the record for \p Trans with one input slot per
+  /// Trans->getNumInputs() and \p TopLevelInput stored in slot 0.
+  NodeTransform(Transform *Trans, TransformInput TopLevelInput) : Trans(Trans) {
+    assert(Trans->getNumInputs() >= 1);
+
+    Inputs.resize(Trans->getNumInputs());
+    setInput(0, TopLevelInput);
+  }
+
+  void setInput(int Idx, TransformInput Input) { Inputs[Idx] = Input; }
+};
+
+/// This class represents a loop in a loop nest to which transformations are
+/// applied. It is intended to be instantiated for its specific purpose, that
+/// is SemaTransformedTree for semantic analysis (consistency warnings and
+/// errors) and CGTransformedTree for emitting IR.
+template <typename Derived> class TransformedTree {
+  template <typename, typename> friend class TransformedTreeBuilder;
+  using NodeTy = Derived;
+
+  Derived &getDerived() { return *static_cast<Derived *>(this); }
+  const Derived &getDerived() const {
+    return *static_cast<const Derived *>(this);
+  }
+
+protected:
+  /// Is this the root node of the loop hierarchy?
+  bool IsRoot = false;
+
+  /// Nested loops.
+  // NOTE(review): inline element count lost from the patch text; 2 is a
+  // placeholder to be confirmed.
+  llvm::SmallVector<Derived *, 2> Subloops;
+
+  /// Origin of this loop.
+  /// @{
+
+  /// If not the result of the transformation, this is the loop statement that
+  /// this node represents.
+  Stmt *Original;
+
+  /// The loop this one was derived from; the debug checks require it to be
+  /// non-null exactly when FollowupRole >= 0.
+  Derived *BasedOn;
+  int FollowupRole;
+
+  /// @}
+
+  /// Things applied to this loop.
+  /// @{
+
+  /// The transformation consuming this loop; assigned only by
+  /// applyTransformation(), i.e. on the primary input.
+  Transform *TransformedBy = nullptr;
+
+  // Primary/secondary transformation input
+  Derived *PrimaryInput = nullptr;
+
+  // NOTE(review): inline element count lost from the patch text; 2 is a
+  // placeholder to be confirmed.
+  llvm::SmallVector<Derived *, 2> Followups;
+  Derived *Successor = nullptr;
+  /// -1 while the loop is not a transformation input, 0 for the primary input
+  /// and a positive value for any other input.
+  int InputRole = -1;
+
+  bool IsParallel = false;
+  bool HasLegacyDisable = false;
+  /// @}
+
+protected:
+  TransformedTree(llvm::ArrayRef<Derived *> SubLoops, Derived *BasedOn,
+                  clang::Stmt *Original, int FollowupRole)
+      : Subloops(SubLoops.begin(), SubLoops.end()), Original(Original),
+        BasedOn(BasedOn), FollowupRole(FollowupRole) {}
+
+public:
+  ArrayRef<Derived *> getSubLoops() const { return Subloops; }
+
+  Derived *getPrimaryInput() const { return PrimaryInput; }
+  Transform *getTransformedBy() const { return TransformedBy; }
+
+  /// Return the transformation that generated this loop. Return nullptr if not
+  /// the result of any transformation, i.e. it is an original loop.
+  Transform *getSourceTransformation() const {
+    assert(!BasedOn == isOriginal() &&
+           "Non-original loops must be based on some other loop");
+    if (isOriginal())
+      return nullptr;
+
+    assert(BasedOn);
+    assert(BasedOn->isTransformationInput());
+    Transform *Result = BasedOn->PrimaryInput->TransformedBy;
+    assert(Result &&
+           "Non-original loops must have a generating transformation");
+    return Result;
+  }
+
+  Stmt *getOriginal() const { return Original; }
+  /// Return this loop's original statement or, failing that, the one of the
+  /// closest loop in the BasedOn chain that has one; nullptr if none does.
+  Stmt *getInheritedOriginal() const {
+    if (Original)
+      return Original;
+    if (BasedOn)
+      return BasedOn->getInheritedOriginal();
+    return nullptr;
+  }
+
+  Derived *getBasedOn() const { return BasedOn; }
+
+  void markParallel() { IsParallel = true; }
+  bool isParallel() const { return IsParallel; }
+
+  bool isRoot() const { return IsRoot; }
+
+  Derived *getSuccessor() { return Successor; }
+
+  Derived *getLatestSuccessor() {
+    // If the loop is not being consumed, this is the latest successor.
+    if (!isTransformationInput())
+      return &getDerived();
+    // It is possible for a loop consumed into a non-loop, such that it has no
+    // successor.
+    if (!Successor)
+      return nullptr;
+    return Successor->getLatestSuccessor();
+  }
+
+  const Derived *getLatestSuccessor() const {
+    return const_cast<TransformedTree *>(this)->getLatestSuccessor();
+  }
+
+  bool isOriginal() const { return Original != nullptr; }
+
+  bool isTransformationInput() const {
+    bool Result = InputRole >= 0;
+    assert(Result == (PrimaryInput != nullptr));
+    return Result;
+  }
+
+  bool isTransformationFollowup() const {
+    bool Result = FollowupRole >= 0;
+    assert(Result == (BasedOn != nullptr));
+    return Result;
+  }
+
+  bool isPrimaryInput() const {
+    bool Result = (InputRole == 0);
+    assert(Result == (PrimaryInput == this));
+    // Return Result like the sibling predicates do; it is asserted equal to
+    // the pointer comparison and is otherwise unused when NDEBUG is defined.
+    return Result;
+  }
+
+  /// Mark this loop as the primary input (role 0) of \p Trans and record the
+  /// loops that follow from it. Must not already be a transformation input.
+  void applyTransformation(Transform *Trans,
+                           llvm::ArrayRef<NodeTy *> Followups,
+                           Derived *Successor) {
+    assert(!isTransformationInput());
+
+    this->TransformedBy = Trans;
+    this->Followups.insert(this->Followups.end(), Followups.begin(),
+                           Followups.end());
+    this->Successor = Successor;
+    this->PrimaryInput = &getDerived();
+    this->InputRole = 0; // for primary
+
+#ifndef NDEBUG
+    assert(isTransformationInput() && isPrimaryInput());
+    for (NodeTy *S : Followups) {
+      assert(S->BasedOn == &getDerived());
+    }
+#endif
+  }
+
+  /// Mark this loop as a non-primary input (\p InputRole > 0) of the
+  /// transformation whose primary input is \p PrimaryInput. Must not already
+  /// be a transformation input.
+  void applySuccessors(Derived *PrimaryInput, int InputRole,
+                       llvm::ArrayRef<NodeTy *> Followups,
+                       Derived *Successor) {
+    assert(!isTransformationInput());
+    assert(InputRole > 0);
+
+    this->PrimaryInput = PrimaryInput;
+    this->Followups.insert(this->Followups.end(), Followups.begin(),
+                           Followups.end());
+    this->Successor = Successor;
+    this->InputRole = InputRole;
+
+#ifndef NDEBUG
+    assert(isTransformationInput() && !isPrimaryInput());
+    for (NodeTy *S : Followups) {
+      assert(S->BasedOn == &getDerived());
+    }
+#endif
+  }
+};
+
+/// Constructs a loop nest from source and applies transformations on it.
+template class TransformedTreeBuilder { + using BuilderTy = Derived; + + Derived &getDerived() { return *static_cast(this); } + const Derived &getDerived() const { + return *static_cast(this); + } + + ASTContext &ASTCtx; + llvm::SmallVectorImpl &AllNodes; + llvm::SmallVectorImpl &AllTransforms; + +private: + /// Build the original loop nest hierarchy from the AST. + void buildPhysicalLoopTree(Stmt *S, SmallVectorImpl &SubLoops, + llvm::DenseMap &StmtToTree) { + if (!S) + return; + + Stmt *Body; + switch (S->getStmtClass()) { + case Stmt::ForStmtClass: + Body = cast(S)->getBody(); + break; + case Stmt::WhileStmtClass: + Body = cast(S)->getBody(); + break; + case Stmt::DoStmtClass: + Body = cast(S)->getBody(); + break; + case Stmt::CXXForRangeStmtClass: + Body = cast(S)->getBody(); + break; + case Stmt::CapturedStmtClass: + buildPhysicalLoopTree(cast(S)->getCapturedStmt(), SubLoops, + StmtToTree); + return; + case Expr::LambdaExprClass: + // Call to getBody materializes its body, children() (which is called in + // the default case) does not. + buildPhysicalLoopTree(cast(S)->getBody(), SubLoops, + StmtToTree); + return; + case Expr::BlockExprClass: + buildPhysicalLoopTree(cast(S)->getBody(), SubLoops, + StmtToTree); + return; + default: + if (auto *O = dyn_cast(S)) { + if (!O->hasAssociatedStmt()) + return; + Stmt *Associated = O->getAssociatedStmt(); + buildPhysicalLoopTree(Associated, SubLoops, StmtToTree); + return; + } + + for (Stmt *Child : S->children()) + buildPhysicalLoopTree(Child, SubLoops, StmtToTree); + + return; + } + + SmallVector SubSubLoops; + buildPhysicalLoopTree(Body, SubSubLoops, StmtToTree); + + NodeTy *L = getDerived().createPhysical(SubSubLoops, S); + SubLoops.push_back(L); + assert(StmtToTree.count(S) == 0); + StmtToTree[S] = L; + + getDerived().applyOriginal(L); + } + + /// Collect all loop transformations in the function's AST. 
+ class CollectTransformationsVisitor + : public RecursiveASTVisitor { + + Derived &Builder; + llvm::DenseMap &StmtToTree; + + public: + CollectTransformationsVisitor(Derived &Builder, + llvm::DenseMap &StmtToTree) + : Builder(Builder), StmtToTree(StmtToTree) {} + + /// Transformations collected so far. + llvm::SmallVector Transforms; + + bool shouldTraversePostOrder() const { return true; } + + /// Read and apply LoopHint (#pragma clang loop) attributes. + void applyAttributed(const Stmt *S, ArrayRef Attrs, + NodeTy *L) { + + /// State of loop vectorization or unrolling. + enum LVEnableState { Unspecified, Enable, Disable, Full }; + + /// Value for llvm.loop.vectorize.enable metadata. + LVEnableState VectorizeEnable = Unspecified; + bool VectorizeAssumeSafety = false; + + /// Value for llvm.loop.unroll.* metadata (enable, disable, or full). + LVEnableState UnrollEnable = Unspecified; + + /// Value for llvm.loop.unroll_and_jam.* metadata (enable, disable, or + /// full). + LVEnableState UnrollAndJamEnable = Unspecified; + + /// Value for llvm.loop.vectorize.predicate metadata + LVEnableState VectorizePredicateEnable = Unspecified; + + /// Value for llvm.loop.vectorize.width metadata. + unsigned VectorizeWidth = 0; + + /// Value for llvm.loop.interleave.count metadata. + unsigned InterleaveCount = 0; + + /// llvm.unroll. + unsigned UnrollCount = 0; + + /// llvm.unroll. + unsigned UnrollAndJamCount = 0; + + /// Value for llvm.loop.distribute.enable metadata. + LVEnableState DistributeEnable = Unspecified; + + /// Value for llvm.loop.pipeline.disable metadata. + bool PipelineDisabled = false; + + /// Value for llvm.loop.pipeline.iicount metadata. 
+ unsigned PipelineInitiationInterval = 0; + + SourceLocation DistLoc; + SourceLocation VecLoc; + SourceLocation UnrollLoc; + SourceLocation UnrollAndJamLoc; + SourceLocation PipelineLoc; + for (const Attr *Attr : Attrs) { + const LoopHintAttr *LH = dyn_cast(Attr); + const OpenCLUnrollHintAttr *OpenCLHint = + dyn_cast(Attr); + + // Skip non loop hint attributes + if (!LH && !OpenCLHint) { + continue; + } + + LoopHintAttr::OptionType Option = LoopHintAttr::Unroll; + LoopHintAttr::LoopHintState State = LoopHintAttr::Disable; + unsigned ValueInt = 1; + // Translate opencl_unroll_hint attribute argument to + // equivalent LoopHintAttr enums. + // OpenCL v2.0 s6.11.5: + // 0 - enable unroll (no argument). + // 1 - disable unroll. + // other positive integer n - unroll by n. + if (OpenCLHint) { + ValueInt = OpenCLHint->getUnrollHint(); + if (ValueInt == 0) { + State = LoopHintAttr::Enable; + } else if (ValueInt != 1) { + Option = LoopHintAttr::UnrollCount; + State = LoopHintAttr::Numeric; + } + } else if (LH) { + Expr *ValueExpr = LH->getValue(); + if (ValueExpr) { + if (ValueExpr->isValueDependent()) + continue; // Ignore this attribute until it is instantiated. 
+ llvm::APSInt ValueAPS = + ValueExpr->EvaluateKnownConstInt(Builder.ASTCtx); + ValueInt = ValueAPS.getSExtValue(); + } + + Option = LH->getOption(); + State = LH->getState(); + } + + switch (Option) { + case LoopHintAttr::Vectorize: + case LoopHintAttr::Interleave: + case LoopHintAttr::VectorizeWidth: + case LoopHintAttr::InterleaveCount: + case LoopHintAttr::VectorizePredicate: + VecLoc = Attr->getLocation(); + break; + case LoopHintAttr::Unroll: + case LoopHintAttr::UnrollCount: + UnrollLoc = Attr->getLocation(); + break; + case LoopHintAttr::UnrollAndJam: + case LoopHintAttr::UnrollAndJamCount: + UnrollAndJamLoc = Attr->getLocation(); + break; + case LoopHintAttr::Distribute: + DistLoc = Attr->getLocation(); + break; + + case LoopHintAttr::PipelineDisabled: + case LoopHintAttr::PipelineInitiationInterval: + PipelineLoc = Attr->getLocation(); + break; + } + + switch (State) { + case LoopHintAttr::Disable: + switch (Option) { + case LoopHintAttr::Vectorize: + // Disable vectorization by specifying a width of 1. + VectorizeWidth = 1; + break; + case LoopHintAttr::Interleave: + // Disable interleaving by specifying a count of 1. 
+ InterleaveCount = 1; + break; + case LoopHintAttr::Unroll: + UnrollEnable = Disable; + break; + case LoopHintAttr::UnrollAndJam: + UnrollAndJamEnable = Disable; + break; + case LoopHintAttr::VectorizePredicate: + VectorizePredicateEnable = Disable; + break; + case LoopHintAttr::Distribute: + DistributeEnable = Disable; + break; + case LoopHintAttr::PipelineDisabled: + PipelineDisabled = true; + break; + case LoopHintAttr::UnrollCount: + case LoopHintAttr::UnrollAndJamCount: + case LoopHintAttr::VectorizeWidth: + case LoopHintAttr::InterleaveCount: + case LoopHintAttr::PipelineInitiationInterval: + llvm_unreachable("Options cannot be disabled."); + break; + } + break; + case LoopHintAttr::Enable: + switch (Option) { + case LoopHintAttr::Vectorize: + case LoopHintAttr::Interleave: + VectorizeEnable = Enable; + break; + case LoopHintAttr::Unroll: + UnrollEnable = Enable; + break; + case LoopHintAttr::UnrollAndJam: + UnrollAndJamEnable = Enable; + break; + case LoopHintAttr::VectorizePredicate: + VectorizePredicateEnable = Enable; + break; + case LoopHintAttr::Distribute: + DistributeEnable = Enable; + break; + case LoopHintAttr::UnrollCount: + case LoopHintAttr::UnrollAndJamCount: + case LoopHintAttr::VectorizeWidth: + case LoopHintAttr::InterleaveCount: + case LoopHintAttr::PipelineDisabled: + case LoopHintAttr::PipelineInitiationInterval: + llvm_unreachable("Options cannot enabled."); + break; + } + break; + case LoopHintAttr::AssumeSafety: + switch (Option) { + case LoopHintAttr::Vectorize: + case LoopHintAttr::Interleave: + // Apply "llvm.mem.parallel_loop_access" metadata to load/stores. 
+ VectorizeEnable = Enable; + VectorizeAssumeSafety = true; + break; + case LoopHintAttr::Unroll: + case LoopHintAttr::UnrollAndJam: + case LoopHintAttr::VectorizePredicate: + case LoopHintAttr::UnrollCount: + case LoopHintAttr::UnrollAndJamCount: + case LoopHintAttr::VectorizeWidth: + case LoopHintAttr::InterleaveCount: + case LoopHintAttr::Distribute: + case LoopHintAttr::PipelineDisabled: + case LoopHintAttr::PipelineInitiationInterval: + llvm_unreachable("Options cannot be used to assume mem safety."); + break; + } + break; + case LoopHintAttr::Full: + switch (Option) { + case LoopHintAttr::Unroll: + UnrollEnable = Full; + break; + case LoopHintAttr::UnrollAndJam: + UnrollAndJamEnable = Full; + break; + case LoopHintAttr::Vectorize: + case LoopHintAttr::Interleave: + case LoopHintAttr::UnrollCount: + case LoopHintAttr::UnrollAndJamCount: + case LoopHintAttr::VectorizeWidth: + case LoopHintAttr::InterleaveCount: + case LoopHintAttr::Distribute: + case LoopHintAttr::PipelineDisabled: + case LoopHintAttr::PipelineInitiationInterval: + case LoopHintAttr::VectorizePredicate: + llvm_unreachable("Options cannot be used with 'full' hint."); + break; + } + break; + case LoopHintAttr::Numeric: + switch (Option) { + case LoopHintAttr::VectorizeWidth: + VectorizeWidth = ValueInt; + break; + case LoopHintAttr::InterleaveCount: + InterleaveCount = ValueInt; + break; + case LoopHintAttr::UnrollCount: + UnrollCount = ValueInt; + break; + case LoopHintAttr::UnrollAndJamCount: + UnrollAndJamCount = ValueInt; + break; + case LoopHintAttr::PipelineInitiationInterval: + PipelineInitiationInterval = ValueInt; + break; + case LoopHintAttr::Unroll: + case LoopHintAttr::UnrollAndJam: + case LoopHintAttr::VectorizePredicate: + case LoopHintAttr::Vectorize: + case LoopHintAttr::Interleave: + case LoopHintAttr::Distribute: + case LoopHintAttr::PipelineDisabled: + llvm_unreachable("Options cannot be assigned a value."); + break; + } + break; + } + } + + bool VectorizeDisabled = + 
(VectorizeEnable == Disable || VectorizeWidth == 1); + bool InterleaveDisabled = (InterleaveCount == 1); + bool VectorizeInterleaveDisabled = + VectorizeDisabled && InterleaveDisabled && !VectorizeAssumeSafety; + + if (UnrollEnable == Disable) { + L->HasLegacyDisable = true; + Builder.disableUnroll(L); + } + + if (UnrollAndJamEnable == Disable) { + L->HasLegacyDisable = true; + Builder.disableUnrollAndJam(L); + } + + if (DistributeEnable == Disable) { + L->HasLegacyDisable = true; + Builder.disableDistribution(L); + } + + // Both vectorization and interleaving are disabled (and no AssumeSafety + // hint was given), i.e. the LoopVectorize pass is disabled completely. + if (VectorizeInterleaveDisabled) { + L->HasLegacyDisable = true; + Builder.disableVectorizeInterleave(L); + } + + if (PipelineDisabled) { + L->HasLegacyDisable = true; + Builder.disablePipelining(L); + } + + // Collect the requested transformations; each one takes S as its input. + if (UnrollEnable == Full) { + Transform *Trans = + LoopUnrollingTransform::createFull(UnrollLoc, true, false, true); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + + if (DistributeEnable == Enable) { + auto *Trans = LoopDistributionTransform::create(DistLoc, true); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + + if ((VectorizeEnable == Enable || VectorizeWidth > 0 || + VectorizePredicateEnable != Unspecified || InterleaveCount > 0) && + !VectorizeInterleaveDisabled) { + auto *Trans = LoopVectorizationInterleavingTransform::create( + VecLoc, true, VectorizeAssumeSafety, + VectorizeEnable != Unspecified + ? llvm::Optional(VectorizeEnable == Enable) + : None, + None, VectorizeWidth, + (VectorizePredicateEnable != Unspecified) + ? 
llvm::Optional(VectorizePredicateEnable == Enable) + : None, + InterleaveCount); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + + if (UnrollAndJamEnable == Enable || UnrollAndJamEnable == Full || + UnrollAndJamCount > 0) { + LoopUnrollAndJamTransform *Trans; + if (UnrollAndJamEnable == Full) + Trans = LoopUnrollAndJamTransform::createFull(UnrollAndJamLoc, true, + true); + else + Trans = LoopUnrollAndJamTransform::createPartial( + UnrollAndJamLoc, true, UnrollAndJamEnable == Enable, + UnrollAndJamCount); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + + if (UnrollEnable != Full && (UnrollEnable == Enable || UnrollCount > 0)) { + auto *Trans = LoopUnrollingTransform::createPartial( + UnrollLoc, true, false, UnrollEnable == Enable, UnrollCount); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + + if (PipelineInitiationInterval > 0) { + auto *Trans = LoopPipeliningTransform::create( + PipelineLoc, PipelineInitiationInterval); + Transforms.emplace_back(Trans, TransformInput::createByStmt(S)); + } + } + + bool VisitAttributedStmt(const AttributedStmt *S) { + const Stmt *LoopStmt = getAssociatedLoop(S); + + // Not every attributed statement is associated with a loop. + if (!LoopStmt) + return true; + + NodeTy *Node = StmtToTree.lookup(LoopStmt); + assert(Node && "We should have created a node for ever loop"); + applyAttributed(LoopStmt, S->getAttrs(), Node); + return true; + } + + bool + VisitTransformExecutableDirective(const TransformExecutableDirective *S) { + // This ExtractTransform does not emit any diagnostics. Diagnostics should + // have been emitted in Sema::ActOnTransformExecutableDirective. + DefaultExtractTransform ExtractTransform(Builder.ASTCtx, S); + std::unique_ptr Trans = ExtractTransform.createTransform(); + + // We might not get a transform in non-instantiated templates or + // inconsistent clauses. 
+ if (Trans) { + const Stmt *TheLoop = getAssociatedLoop(S->getAssociated()); + Transforms.emplace_back(Trans.get(), + TransformInput::createByStmt(TheLoop)); + Builder.AllTransforms.push_back(Trans.release()); + } + + return true; + } + + bool handleOMPLoopClauses(OMPLoopDirective *Directive, bool HasTaskloop, + bool HasFor, bool HasSimd) { + assert((!HasTaskloop || !HasFor) && + "taskloop and for are mutually exclusive"); + const Stmt *TopLevel = getAssociatedLoop(Directive); + + bool IsMonotonic = true; + if (HasFor) { + if (auto *ScheduleClause = + Directive->getSingleClause()) { + OpenMPScheduleClauseKind Schedule = ScheduleClause->getScheduleKind(); + if (Schedule != OMPC_SCHEDULE_unknown && + Schedule != OMPC_SCHEDULE_static) + IsMonotonic = false; + + if (ScheduleClause->getFirstScheduleModifier() == + OMPC_SCHEDULE_MODIFIER_monotonic) + IsMonotonic = true; + if (ScheduleClause->getSecondScheduleModifier() == + OMPC_SCHEDULE_MODIFIER_monotonic) + IsMonotonic = true; + } + + if (auto *OrderedClause = + Directive->getSingleClause()) { + if (!OrderedClause->getNumForLoops()) + IsMonotonic = true; + } + } + + if (HasTaskloop) + IsMonotonic = false; + + int64_t SimdWidth = 0; + bool SimdImplicitParallel = false; + if (HasSimd) { + if (!HasFor) + SimdImplicitParallel = true; + + if (auto *LenClause = Directive->getSingleClause()) { + auto SafelenExpr = LenClause->getSafelen(); + if (!SafelenExpr->isValueDependent() && + SafelenExpr->isEvaluatable(Builder.ASTCtx)) { + llvm::APSInt ValueAPS = + SafelenExpr->EvaluateKnownConstInt(Builder.ASTCtx); + SimdWidth = ValueAPS.getSExtValue(); + } + + // In presence of finite 'safelen', it may be unsafe to mark all + // the memory instructions parallel, because loop-carried + // dependences of 'safelen' iterations are possible. 
+ SimdImplicitParallel = false; + IsMonotonic = true; + } + + if (auto LenClause = Directive->getSingleClause()) { + auto SimdlenExpr = LenClause->getSimdlen(); + if (!SimdlenExpr->isValueDependent() && + SimdlenExpr->isEvaluatable(Builder.ASTCtx)) { + llvm::APSInt ValueAPS = + SimdlenExpr->EvaluateKnownConstInt(Builder.ASTCtx); + SimdWidth = ValueAPS.getSExtValue(); + } + } + } + + bool IsImplicitParallel = !IsMonotonic || SimdImplicitParallel; + + if (HasSimd) { + auto *Trans = LoopVectorizationInterleavingTransform::create( + Directive->getSourceRange(), true, IsImplicitParallel, true, None, + SimdWidth, None, 0); + Transforms.emplace_back(Trans, TransformInput::createByStmt(TopLevel)); + } else if (IsImplicitParallel) { + auto *Assume = + LoopAssumeParallelTransform::create(Directive->getSourceRange()); + Transforms.emplace_back(Assume, TransformInput::createByStmt(TopLevel)); + } + + return true; + } + + bool VisitOMPForDirective(OMPForDirective *For) { + return handleOMPLoopClauses(For, false, true, false); + } + + bool VisitOMPDistributeParallelForDirective( + OMPDistributeParallelForDirective *L) { + return handleOMPLoopClauses(L, false, true, false); + } + + bool VisitOMPTeamsDistributeParallelForDirective( + OMPTeamsDistributeParallelForDirective *L) { + return handleOMPLoopClauses(L, false, true, false); + } + + bool VisitOMPTargetTeamsDistributeParallelForDirective( + OMPTargetTeamsDistributeParallelForDirective *L) { + return handleOMPLoopClauses(L, false, true, false); + } + + bool VisitOMPSimdDirective(OMPLoopDirective *Simd) { + return handleOMPLoopClauses(Simd, false, false, true); + } + + bool VisitOMPForSimdDirective(OMPForSimdDirective *ForSimd) { + return handleOMPLoopClauses(ForSimd, false, true, true); + } + + bool + VisitOMPParallelForSimdDirective(OMPParallelForSimdDirective *ForSimd) { + return handleOMPLoopClauses(ForSimd, false, true, true); + } + + bool VisitOMPDistributeSimdDirective( + OMPDistributeSimdDirective *DistributeSimd) { + 
return handleOMPLoopClauses(DistributeSimd, false, false, true); + } + + bool VisitOMPDistributeParallelForSimdDirective( + OMPDistributeParallelForSimdDirective *DistributeSimd) { + return handleOMPLoopClauses(DistributeSimd, false, true, true); + } + + bool VisitOMPTeamsDistributeSimdDirective( + OMPTeamsDistributeSimdDirective *TeamsDistributeSimd) { + return handleOMPLoopClauses(TeamsDistributeSimd, false, false, true); + } + + bool VisitOMPTeamsDistributeParallelForSimdDirective( + OMPTeamsDistributeParallelForSimdDirective + *TeamsDistributeParallelForSimd) { + return handleOMPLoopClauses(TeamsDistributeParallelForSimd, false, true, + true); + } + + bool VisitOMPTargetTeamsDistributeSimdDirective( + OMPTargetTeamsDistributeSimdDirective *TargetTeamsDistributeSimd) { + return handleOMPLoopClauses(TargetTeamsDistributeSimd, false, false, + true); + } + + bool VisitOMPTargetTeamsDistributeParallelForSimdDirective( + OMPTargetTeamsDistributeParallelForSimdDirective + *TargetTeamsDistributeParallelForSimd) { + return handleOMPLoopClauses(TargetTeamsDistributeParallelForSimd, false, + true, true); + } + + bool VisitOMPTaskLoopSimdDirective( + OMPTaskLoopSimdDirective *TaskLoopSimdDirective) { + return handleOMPLoopClauses(TaskLoopSimdDirective, true, false, true); + } + + bool + VisitOMPTargetSimdDirective(OMPTargetSimdDirective *TargetSimdDirective) { + return handleOMPLoopClauses(TargetSimdDirective, false, false, true); + } + + bool VisitOMPTargetParallelForSimdDirective( + OMPTargetParallelForSimdDirective *TargetParallelForSimdDirective) { + return handleOMPLoopClauses(TargetParallelForSimdDirective, false, true, + true); + } + + bool VisitOMPMasterTaskLoopDirective( + OMPMasterTaskLoopDirective *MasterTaskloop) { + return handleOMPLoopClauses(MasterTaskloop, true, false, false); + } + + bool VisitOMPMasterTaskLoopSimdDirective( + OMPMasterTaskLoopSimdDirective *MasterTaskloopSimd) { + return handleOMPLoopClauses(MasterTaskloopSimd, true, false, true); + } + + 
bool VisitOMPParallelMasterTaskLoopDirective( + OMPParallelMasterTaskLoopDirective *ParallelMasterTaskloop) { + return handleOMPLoopClauses(ParallelMasterTaskloop, true, false, false); + } + + bool VisitOMPParallelMasterTaskLoopSimdDirective( + OMPParallelMasterTaskLoopSimdDirective *ParallelMasterTaskloopSimd) { + return handleOMPLoopClauses(ParallelMasterTaskloopSimd, true, false, + true); + } + }; + + /// Applies collected transformations to the loop nest representation. + struct TransformApplicator { + Derived &Builder; + llvm::DenseMap> + TransByStmt; + + TransformApplicator(Derived &Builder) : Builder(Builder) {} + + void addNodeTransform(NodeTransform *NT) { + TransformInput &TopLevelInput = NT->Inputs[0]; + TransByStmt[TopLevelInput.getStmtInput()].push_back(NT); + } + + void checkStageOrder(ArrayRef PrevLoops, Transform *NewTrans) { + for (NodeTy *PrevLoop : PrevLoops) { + // Cannot combine legacy disable pragmas (e.g. #pragma clang loop + // unroll(disable)) and new transformations (#pragma clang transform). + if (!NewTrans->isLegacy() && PrevLoop->HasLegacyDisable) { + Builder.Diag(NewTrans->getBeginLoc(), + diag::err_sema_transform_legacy_mix); + return; + } + + Transform *PrevSourceTrans = PrevLoop->getSourceTransformation(); + if (!PrevSourceTrans) + continue; + + // Cannot combine legacy constructs (#pragma clang loop, ...) with new + // ones (#pragma clang transform). + if (NewTrans->isLegacy() != PrevSourceTrans->isLegacy()) { + Builder.Diag(NewTrans->getBeginLoc(), + diag::err_sema_transform_legacy_mix); + return; + } + + int PrevStage = PrevSourceTrans->getLoopPipelineStage(); + int NewStage = NewTrans->getLoopPipelineStage(); + if (PrevStage >= 0 && NewStage >= 0 && PrevStage > NewStage) { + Builder.Diag(NewTrans->getBeginLoc(), + diag::warn_sema_transform_pass_order); + + // At most one warning per transformation. 
+ return; + } + } + } + + NodeTy *applyTransform(Transform *Trans, NodeTy *MainLoop) { + switch (Trans->getKind()) { + case Transform::Kind::LoopUnrollingKind: + return applyUnrolling(cast(Trans), MainLoop); + case Transform::Kind::LoopUnrollAndJamKind: + return applyUnrollAndJam(cast(Trans), + MainLoop); + case Transform::Kind::LoopDistributionKind: + return applyDistribution(cast(Trans), + MainLoop); + case Transform::Kind::LoopVectorizationInterleavingKind: + return applyVectorizeInterleave( + cast(Trans), MainLoop); + case Transform::Kind::LoopVectorizationKind: + return applyVectorize(cast(Trans), + MainLoop); + case Transform::Kind::LoopInterleavingKind: + return applyInterleave(cast(Trans), + MainLoop); + case Transform::Kind::LoopPipeliningKind: + return applyPipelining(cast(Trans), MainLoop); + case Transform::Kind::LoopAssumeParallelKind: + return applyAssumeParallel(cast(Trans), + MainLoop); + default: + llvm_unreachable("unimplemented transformation"); + } + } + + void inheritLoopAttributes(NodeTy *Dst, NodeTy *Src, bool IsAll, + bool IsSuccessor) { + Builder.inheritLoopAttributes(Dst, Src, IsAll, IsSuccessor); + } + + NodeTy *applyUnrolling(LoopUnrollingTransform *Trans, NodeTy *MainLoop) { + checkStageOrder({MainLoop}, Trans); + + NodeTy *Successor = nullptr; + if (Trans->isFull()) { + // Full unrolling has no followup-loop. + MainLoop->applyTransformation(Trans, {}, nullptr); + } else { + NodeTy *All = Builder.createFollowup( + MainLoop->Subloops, MainLoop, LoopUnrollingTransform::FollowupAll); + NodeTy *Unrolled = + Builder.createFollowup(MainLoop->Subloops, MainLoop, + LoopUnrollingTransform::FollowupUnrolled); + NodeTy *Remainder = + Builder.createFollowup(MainLoop->Subloops, MainLoop, + LoopUnrollingTransform::FollowupRemainder); + Successor = Trans->isLegacy() ? 
All : Unrolled; + + inheritLoopAttributes(All, MainLoop, true, All == Successor); + MainLoop->applyTransformation(Trans, {All, Unrolled, Remainder}, + Successor); + } + + Builder.applyUnroll(Trans, MainLoop); + return Successor; + } + + NodeTy *applyUnrollAndJam(LoopUnrollAndJamTransform *Trans, + NodeTy *MainLoop) { + // Search for the innermost loop that is being jammed. + NodeTy *Cur = MainLoop; + NodeTy *Inner = nullptr; + if (Cur->Subloops.size() == 1) { + Inner = Cur->Subloops[0]->getLatestSuccessor(); + } else if (!Trans->isLegacy()) { + Builder.Diag(Trans->getBeginLoc(), + diag::err_sema_transform_unrollandjam_expect_nested_loop); + return nullptr; + } + + if (!Trans->isLegacy() && !Inner) { + if (!Trans->isLegacy()) + Builder.Diag( + Trans->getBeginLoc(), + diag::err_sema_transform_unrollandjam_expect_nested_loop); + return nullptr; + } + + if (!Trans->isLegacy() && Inner->Subloops.size() != 0) { + if (!Trans->isLegacy()) + Builder.Diag(Trans->getBeginLoc(), + diag::err_sema_transform_unrollandjam_not_innermost); + return nullptr; + } + + // Having no loop to jam does not make a lot of sense, but fixes + // regression tests. 
+ if (!Inner) { + checkStageOrder({MainLoop}, Trans); + + NodeTy *UnrolledOuter = Builder.createFollowup( + {}, MainLoop, LoopUnrollAndJamTransform::FollowupOuter); + inheritLoopAttributes(UnrolledOuter, MainLoop, true, false); + + MainLoop->applyTransformation(Trans, {UnrolledOuter}, UnrolledOuter); + Builder.applyUnrollAndJam(Trans, MainLoop, nullptr); + return UnrolledOuter; + } + + checkStageOrder({MainLoop, Inner}, Trans); + + NodeTy *TransformedInner = Builder.createFollowup( + Inner->Subloops, Inner, LoopUnrollAndJamTransform::FollowupInner); + inheritLoopAttributes(TransformedInner, Inner, false, false); + + // TODO: Handle full unrolling + NodeTy *UnrolledOuter = Builder.createFollowup( + {Inner}, MainLoop, LoopUnrollAndJamTransform::FollowupOuter); + inheritLoopAttributes(UnrolledOuter, MainLoop, false, true); + + MainLoop->applyTransformation(Trans, {UnrolledOuter}, UnrolledOuter); + Inner->applySuccessors(MainLoop, LoopUnrollAndJamTransform::InputInner, + {TransformedInner}, TransformedInner); + + Builder.applyUnrollAndJam(Trans, MainLoop, Inner); + return UnrolledOuter; + } + + NodeTy *applyDistribution(LoopDistributionTransform *Trans, + NodeTy *MainLoop) { + checkStageOrder({MainLoop}, Trans); + + NodeTy *All = Builder.createFollowup( + MainLoop->Subloops, MainLoop, LoopDistributionTransform::FollowupAll); + NodeTy *Successor = Trans->isLegacy() ? 
All : nullptr; + inheritLoopAttributes(All, MainLoop, true, Successor == All); + + MainLoop->applyTransformation(Trans, {All}, Successor); + Builder.applyDistribution(Trans, MainLoop); + return Successor; + } + + NodeTy * + applyVectorizeInterleave(LoopVectorizationInterleavingTransform *Trans, + NodeTy *MainLoop) { + checkStageOrder({MainLoop}, Trans); + + NodeTy *All = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopVectorizationInterleavingTransform::FollowupAll); + NodeTy *Vectorized = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopVectorizationInterleavingTransform::FollowupVectorized); + NodeTy *Epilogue = Builder.createFollowup( + MainLoop->Subloops, MainLoop, + LoopVectorizationInterleavingTransform::FollowupEpilogue); + + NodeTy *Successor = Trans->isLegacy() ? All : Vectorized; + if (Trans->isAssumeSafety() && !MainLoop->IsParallel) { + MainLoop->IsParallel = true; + Builder.markParallel(MainLoop); + } + + inheritLoopAttributes(All, MainLoop, true, All == Successor); + MainLoop->applyTransformation(Trans, {All, Vectorized, Epilogue}, + Successor); + Builder.applyVectorizeInterleave(Trans, MainLoop); + return Successor; + } + + NodeTy *applyVectorize(LoopVectorizationTransform *Trans, + NodeTy *MainLoop) { + auto *VecInterleaveTrans = LoopVectorizationInterleavingTransform::create( + Trans->getRange(), false, false, true, false, Trans->getWidth(), + Trans->isPredicateEnabled(), 1); + return applyVectorizeInterleave(VecInterleaveTrans, MainLoop); + } + + NodeTy *applyInterleave(LoopInterleavingTransform *Trans, + NodeTy *MainLoop) { + auto *VecInterleaveTrans = LoopVectorizationInterleavingTransform::create( + Trans->getRange(), false, false, false, true, 1, None, + Trans->getInterleaveCount()); + return applyVectorizeInterleave(VecInterleaveTrans, MainLoop); + } + + NodeTy *applyPipelining(LoopPipeliningTransform *Trans, NodeTy *MainLoop) { + checkStageOrder({MainLoop}, Trans); + MainLoop->applyTransformation(Trans, {}, 
nullptr); + Builder.applyPipelining(Trans, MainLoop); + return nullptr; + } + + NodeTy *applyAssumeParallel(LoopAssumeParallelTransform *Assumption, + NodeTy *MainLoop) { + if (MainLoop->isParallel()) + return MainLoop; + + MainLoop->IsParallel = true; + Builder.markParallel(MainLoop); + + return MainLoop; + } + + void applyOne(NodeTy *L, NodeTransform *NT) { + applyTransform(NT->Trans, L); + } + +#if 0 + void applyOne(NodeTy *L, const TransformExecutableDirective *D) { + Transform *Trans = D->getTransform(); + assert(Trans); + applyTransform(Trans, L); + } +#endif + + void traverseSubloops(NodeTy *L) { + // Transform subloops first. + for (NodeTy *SubL : L->getSubLoops()) { + SubL = SubL->getLatestSuccessor(); + if (!SubL) + continue; + traverse(SubL); + } + } + + bool applyTransform(NodeTy *L) { + if (L->isRoot()) + return false; + assert(L == L->getLatestSuccessor() && + "Loop must not have been consumed by another transformation"); + + Stmt *OrigStmt = L->getInheritedOriginal(); + auto TransformsOnStmt = TransByStmt.find(OrigStmt); + if (TransformsOnStmt != TransByStmt.end()) { + auto &List = TransformsOnStmt->second; + if (!List.empty()) { + NodeTransform *Trans = List.front(); + applyOne(L, Trans); + List.erase(List.begin()); + return true; + } + } + + return false; + } + + void traverse(NodeTy *L) { + do { + L = L->getLatestSuccessor(); + if (!L) + break; + traverseSubloops(L); + } while (applyTransform(L)); + } + }; + +protected: + TransformedTreeBuilder(ASTContext &ASTCtx, + llvm::SmallVectorImpl &AllNodes, + llvm::SmallVectorImpl &AllTransforms) + : ASTCtx(ASTCtx), AllNodes(AllNodes), AllTransforms(AllTransforms) {} + + NodeTy *createRoot(llvm::ArrayRef SubLoops) { + auto *Result = new NodeTy(SubLoops, nullptr, nullptr, -1); + AllNodes.push_back(Result); + Result->IsRoot = true; + return Result; + } + + NodeTy *createPhysical(llvm::ArrayRef SubLoops, + clang::Stmt *Original) { + assert(Original); + auto *Result = new NodeTy(SubLoops, nullptr, Original, 
-1); + AllNodes.push_back(Result); + return Result; + } + + NodeTy *createFollowup(llvm::ArrayRef SubLoops, NodeTy *BasedOn, + int FollowupRole) { + assert(BasedOn); + auto *Result = new NodeTy(SubLoops, BasedOn, nullptr, FollowupRole); + AllNodes.push_back(Result); + return Result; + } + +public: + void markParallel(NodeTy *L) { L->markParallel(); } + + NodeTy * + computeTransformedStructure(Stmt *Body, + llvm::DenseMap &StmtToTree) { + if (!Body) + return nullptr; + + // Create original tree. + SmallVector TopLevelLoops; + buildPhysicalLoopTree(Body, TopLevelLoops, StmtToTree); + NodeTy *Root = getDerived().createRoot(TopLevelLoops); + + // Collect all loop transformations. + CollectTransformationsVisitor Collector(getDerived(), StmtToTree); + Collector.TraverseStmt(Body); + auto &TransformList = Collector.Transforms; + + // Local function to apply every transformation that the predicate returns + // true to. This is to emulate LLVM's pass pipeline that would apply + // transformation in this order. + auto SelectiveApplicator = [this, &TransformList, Root](auto Pred) { + TransformApplicator OMPApplicator(getDerived()); + bool AnyActive = false; + for (NodeTransform &NT : TransformList) { + if (Pred(NT)) { + OMPApplicator.addNodeTransform(&NT); + AnyActive = true; + } + } + + // No traversal needed if no transformations to apply. + if (!AnyActive) + return; + + OMPApplicator.traverse(Root); + + // Report leftover transformations. + for (auto &P : OMPApplicator.TransByStmt) { + for (NodeTransform *NT : P.second) { + if (!NT->Trans->isLegacy()) + getDerived().Diag(NT->Trans->getBeginLoc(), + diag::err_sema_transform_missing_loop); + } + } + + // Remove applied transformations from list. + auto NewEnd = + std::remove_if(TransformList.begin(), TransformList.end(), Pred); + TransformList.erase(NewEnd, TransformList.end()); + }; + + // Apply full unrolling, loop distribution, vectorization/interleaving. 
+ SelectiveApplicator([](NodeTransform &NT) -> bool { + if (auto Unroll = dyn_cast(NT.Trans)) + return Unroll->isLegacy() && Unroll->isFull(); + if (auto Dist = dyn_cast(NT.Trans)) + return Dist->isLegacy(); + if (auto Vec = dyn_cast(NT.Trans)) + return Vec->isLegacy(); + return false; + }); + + // Apply unrollandjam. + // While all transformations in TransformList have already been added in the + // order in the pass pipeline (relative to transformations on the same + // loop), transformations on unrollandjam's inner loop are not ordered + // relative to the outer loop. + SelectiveApplicator([](NodeTransform &NT) -> bool { + if (auto Unroll = dyn_cast(NT.Trans)) + return Unroll->isLegacy(); + return false; + }); + + // Apply partial unrolling and software pipelining. + SelectiveApplicator([](NodeTransform &NT) -> bool { + if (auto Unroll = dyn_cast(NT.Trans)) + return Unroll->isLegacy(); + if (isa(NT.Trans)) + return true; + return false; + }); + + // Apply all others. + SelectiveApplicator([](NodeTransform &NT) -> bool { return true; }); + assert(TransformList.size() == 0 && "Must apply all transformations"); + + return Root; + } +}; + +} // namespace clang +#endif /* LLVM_CLANG_ANALYSIS_TRANSFORMEDTREE_H */ diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -10070,4 +10070,29 @@ "__builtin_bit_cast %select{source|destination}0 type must be trivially copyable">; def err_bit_cast_type_size_mismatch : Error< "__builtin_bit_cast source size does not equal destination size (%0 vs %1)">; + +// Pragma transform support. 
+def err_sema_transform_expected_loop : Error< + "cannot find loop to transform">; +def err_sema_transform_unroll_full_or_partial : Error< + "the full and partial clauses are mutually exclusive">; +def err_sema_transform_unrollandjam_expect_nested_loop : Error< + "unroll-and-jam requires exactly one nested loop">; +def err_sema_transform_unrollandjam_not_innermost : Error< + "inner loop of unroll-and-jam is not the innermost">; +def err_sema_transform_missing_loop : Error< + "transformation did not find its loop to transform; it might have been consumed by another transformation">; +def err_sema_transform_clause_one_max : Error< + "the %0 clause can be specified at most once">; +def err_sema_transform_clause_arg_expect_int : Error< + "clause expects an integer argument">; +def err_sema_transform_clause_arg_min_val : Error< + "clause argument must be at least %0">; +def err_sema_transform_legacy_mix : Error< + "cannot combine #pragma clang transform with other transformations">; + +def warn_sema_transform_pass_order : Warning< + "the LLVM pass structure currently is not able to apply the transformations in this order">, + InGroup; + } // end of sema component. diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -209,6 +209,9 @@ // OpenCL Extensions. def AsTypeExpr : StmtNode; +// Transform Directives. +def TransformExecutableDirective : StmtNode; + // OpenMP Directives. 
def OMPExecutableDirective : StmtNode; def OMPLoopDirective : StmtNode; diff --git a/clang/include/clang/Basic/Transform.h b/clang/include/clang/Basic/Transform.h --- a/clang/include/clang/Basic/Transform.h +++ b/clang/include/clang/Basic/Transform.h @@ -33,6 +33,402 @@ static Kind getTransformDirectiveKind(llvm::StringRef Str); static llvm::StringRef getTransformDirectiveKeyword(Kind K); static llvm::StringRef getTransformDirectiveName(Kind K); + +private: + Kind TransformKind; + SourceRange LocRange; + bool IsLegacy; + +public: + Transform(Kind K, SourceRange LocRange, bool IsLegacy) + : TransformKind(K), LocRange(LocRange), IsLegacy(IsLegacy) {} + + Kind getKind() const { return TransformKind; } + static bool classof(const Transform *Trans) { return true; } + + /// Source location of the code transformation directive. + /// @{ + SourceRange getRange() const { return LocRange; } + SourceLocation getBeginLoc() const { return LocRange.getBegin(); } + SourceLocation getEndLoc() const { return LocRange.getEnd(); } + void setRange(SourceRange L) { LocRange = L; } + void setRange(SourceLocation BeginLoc, SourceLocation EndLoc) { + LocRange = SourceRange(BeginLoc, EndLoc); + } + /// @} + + /// Non-legacy directives originate from a #pragma clang transform directive. + /// Legacy transformations are introduced by #pragma clang loop, + /// #pragma omp simd, #pragma unroll, etc. + /// Differences include: + /// * Legacy directives transformation execution order is defined by the + /// compiler. + /// * Some warnings that clang historically did not warn about are disabled. + /// * Some differences of the emitted loop metadata for compatibility. + bool isLegacy() const { return IsLegacy; } + + /// Each transformation defines how many loops it consumes and generates. + /// Users of this class can store arrays holding the information regarding the + /// loops, such as pointer to the AST node or the loop name. The index in this + /// array is its "role". 
+ /// @{ + int getNumInputs() const; + int getNumFollowups() const; + /// @} + + /// The "all" follow-up role is a meta output whose attributes are added to + /// all generated loops. + bool isAllRole(int R) const { return R == 0; } + + /// Used to warn users that the current LLVM pass pipeline cannot apply + /// arbitrary transformation orders yet. + int getLoopPipelineStage() const; +}; + +/// Default implementation of compile-time inherited methods to avoid infinite +/// recursion. +class TransformImpl : public Transform { +public: + TransformImpl(Kind K, SourceRange Loc, bool IsLegacy) + : Transform(K, Loc, IsLegacy) {} + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 1; } +}; + +class LoopDistributionTransform final : public TransformImpl { +private: + LoopDistributionTransform(SourceRange Loc, bool IsLegacy) + : TransformImpl(LoopDistributionKind, Loc, IsLegacy) {} + +public: + static bool classof(const LoopDistributionTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopDistributionKind; + } + + static LoopDistributionTransform *create(SourceRange Loc, bool IsLegacy) { + return new LoopDistributionTransform(Loc, IsLegacy); + } + + enum Input { + InputToDistribute, + }; + enum Followup { + FollowupAll, + }; +}; + +class LoopVectorizationTransform final : public TransformImpl { +private: + llvm::Optional EnableVectorization; + int VectorizeWidth; + llvm::Optional VectorizePredicateEnable; + + LoopVectorizationTransform(SourceRange Loc, + llvm::Optional EnableVectorization, + int VectorizeWidth, + llvm::Optional VectorizePredicateEnable) + : TransformImpl(LoopVectorizationKind, Loc, true), + EnableVectorization(EnableVectorization), + VectorizeWidth(VectorizeWidth), + VectorizePredicateEnable(VectorizePredicateEnable) {} + +public: + static bool classof(const LoopVectorizationTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return
Trans->getKind() == LoopVectorizationKind; + } + + static LoopVectorizationTransform * + Create(SourceRange Loc, llvm::Optional EnableVectorization, + int VectorizeWidth, llvm::Optional VectorizePredicateEnable) { + assert(EnableVectorization.getValueOr(true)); + return new LoopVectorizationTransform( + Loc, EnableVectorization, VectorizeWidth, VectorizePredicateEnable); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 3; } + enum Input { + InputToVectorize, + }; + enum Followup { FollowupAll, FollowupVectorized, FollowupEpilogue }; + + llvm::Optional isVectorizationEnabled() const { + return EnableVectorization; + } + + int getWidth() const { return VectorizeWidth; } + + llvm::Optional isPredicateEnabled() const { + return VectorizePredicateEnable; + } +}; + +class LoopInterleavingTransform final : public TransformImpl { +private: + llvm::Optional EnableInterleaving; + int InterleaveCount; + + LoopInterleavingTransform(SourceRange Loc, + llvm::Optional EnableInterleaving, + int InterleaveCount) + : TransformImpl(LoopInterleavingKind, Loc, false), + EnableInterleaving(EnableInterleaving), + InterleaveCount(InterleaveCount) {} + +public: + static bool classof(const LoopInterleavingTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopInterleavingKind; + } + + static LoopInterleavingTransform * + Create(SourceRange Loc, llvm::Optional EnableInterleaving, + int InterleaveCount) { + assert(EnableInterleaving.getValueOr(true)); + return new LoopInterleavingTransform(Loc, EnableInterleaving, + InterleaveCount); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 3; } + enum Input { + InputToVectorize, + }; + enum Followup { FollowupAll, FollowupVectorized, FollowupEpilogue }; + + llvm::Optional isInterleavingEnabled() const { + return EnableInterleaving; + } + + int getInterleaveCount() const { return InterleaveCount; } +}; + +class 
LoopVectorizationInterleavingTransform final : public TransformImpl { +private: + bool AssumeSafety; + llvm::Optional EnableVectorization; + llvm::Optional EnableInterleaving; + int VectorizeWidth; + llvm::Optional VectorizePredicateEnable; + int InterleaveCount; + + LoopVectorizationInterleavingTransform( + SourceRange Loc, bool IsLegacy, bool AssumeSafety, + llvm::Optional EnableVectorization, + llvm::Optional EnableInterleaving, int VectorizeWidth, + llvm::Optional VectorizePredicateEnable, int InterleaveCount) + : TransformImpl(LoopVectorizationInterleavingKind, Loc, IsLegacy), + AssumeSafety(AssumeSafety), EnableVectorization(EnableVectorization), + EnableInterleaving(EnableInterleaving), VectorizeWidth(VectorizeWidth), + VectorizePredicateEnable(VectorizePredicateEnable), + InterleaveCount(InterleaveCount) {} + +public: + static bool classof(const LoopVectorizationInterleavingTransform *Trans) { + return true; + } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopVectorizationInterleavingKind; + } + + static LoopVectorizationInterleavingTransform * + create(SourceRange Loc, bool Legacy, bool AssumeSafety, + llvm::Optional EnableVectorization, + llvm::Optional EnableInterleaving, int VectorizeWidth, + llvm::Optional VectorizePredicateEnable, int InterleaveCount) { + assert(EnableVectorization.getValueOr(true) || + EnableInterleaving.getValueOr(true)); + return new LoopVectorizationInterleavingTransform( + Loc, Legacy, AssumeSafety, EnableVectorization, EnableInterleaving, + VectorizeWidth, VectorizePredicateEnable, InterleaveCount); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 3; } + enum Input { + InputToVectorize, + }; + enum Followup { FollowupAll, FollowupVectorized, FollowupEpilogue }; + + bool isAssumeSafety() const { return AssumeSafety; } + llvm::Optional isVectorizationEnabled() const { + return EnableVectorization; + } + llvm::Optional isInterleavingEnabled() const { + return 
EnableInterleaving; + } + + int getWidth() const { return VectorizeWidth; } + llvm::Optional isPredicateEnabled() const { + return VectorizePredicateEnable; + } + int getInterleaveCount() const { return InterleaveCount; } +}; + +class LoopUnrollingTransform final : public TransformImpl { +private: + bool ImplicitEnable; + bool ExplicitEnable; + int64_t Factor; + + LoopUnrollingTransform(SourceRange Loc, bool IsLegacy, bool ImplicitEnable, + bool ExplicitEnable, int Factor) + : TransformImpl(LoopUnrollingKind, Loc, IsLegacy), + ImplicitEnable(ImplicitEnable), ExplicitEnable(ExplicitEnable), + Factor(Factor) {} + +public: + static bool classof(const LoopUnrollingTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopUnrollingKind; + } + + static LoopUnrollingTransform *create(SourceRange Loc, bool Legacy, + bool ImplicitEnable, + bool ExplicitEnable) { + return new LoopUnrollingTransform(Loc, Legacy, ImplicitEnable, + ExplicitEnable, 0); + } + static LoopUnrollingTransform *createFull(SourceRange Loc, bool Legacy, + bool ImplicitEnable, + bool ExplicitEnable) { + return new LoopUnrollingTransform(Loc, Legacy, ImplicitEnable, + ExplicitEnable, -1); + } + static LoopUnrollingTransform *createPartial(SourceRange Loc, bool Legacy, + bool ImplicitEnable, + bool ExplicitEnable, + int Factor) { + return new LoopUnrollingTransform(Loc, Legacy, ImplicitEnable, + ExplicitEnable, Factor); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 3; } + enum Input { + InputToUnroll, + }; + enum Followup { + FollowupAll, // if not full + FollowupUnrolled, // if not full + FollowupRemainder + }; + + bool hasFollowupRole(int i) { return isFull(); } + int getDefaultSuccessor() const { + return isLegacy() ? 
FollowupAll : FollowupUnrolled; + } + + bool isImplicitEnable() const { return ImplicitEnable; } + + bool isExplicitEnable() const { return ExplicitEnable; } + + int getFactor() const { return Factor; } + + bool isFull() const { return Factor == -1; } +}; + +class LoopUnrollAndJamTransform final : public TransformImpl { +private: + bool ExplicitEnable; + int Factor; + + LoopUnrollAndJamTransform(SourceRange Loc, bool IsLegacy, bool ExplicitEnable, + int Factor) + : TransformImpl(LoopUnrollAndJamKind, Loc, IsLegacy), + ExplicitEnable(ExplicitEnable), Factor(Factor) {} + +public: + static bool classof(const LoopUnrollAndJamTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopUnrollAndJamKind; + } + + static LoopUnrollAndJamTransform *create(SourceRange Loc, bool Legacy, + bool ExplicitEnable) { + return new LoopUnrollAndJamTransform(Loc, Legacy, ExplicitEnable, 0); + } + static LoopUnrollAndJamTransform *createFull(SourceRange Loc, bool Legacy, + bool ExplicitEnable) { + return new LoopUnrollAndJamTransform(Loc, Legacy, ExplicitEnable, -1); + } + static LoopUnrollAndJamTransform * + createPartial(SourceRange Loc, bool Legacy, bool ExplicitEnable, int Factor) { + return new LoopUnrollAndJamTransform(Loc, Legacy, ExplicitEnable, Factor); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 3; } + enum Input { + InputOuter, + InputInner, + }; + enum Followup { FollowupAll, FollowupOuter, FollowupInner }; + + bool isExplicitEnable() const { return ExplicitEnable; } + + int getFactor() const { + assert(Factor >= 0); + return Factor; + } + + bool isFull() const { return Factor == -1; } +}; + +class LoopPipeliningTransform final : public TransformImpl { +private: + int InitiationInterval; + + LoopPipeliningTransform(SourceRange Loc, int InitiationInterval) + : TransformImpl(LoopPipeliningKind, Loc, true), + InitiationInterval(InitiationInterval) {} + +public: + static bool 
classof(const LoopPipeliningTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopPipeliningKind; + } + + static LoopPipeliningTransform *create(SourceRange Loc, + int InitiationInterval) { + return new LoopPipeliningTransform(Loc, InitiationInterval); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 0; } + enum Input { + InputToPipeline, + }; + + int getInitiationInterval() const { return InitiationInterval; } +}; + +class LoopAssumeParallelTransform final : public TransformImpl { +private: + LoopAssumeParallelTransform(SourceRange Loc) + : TransformImpl(LoopAssumeParallelKind, Loc, false) {} + +public: + static bool classof(const LoopAssumeParallelTransform *Trans) { return true; } + static bool classof(const Transform *Trans) { + return Trans->getKind() == LoopAssumeParallelKind; + } + + static LoopAssumeParallelTransform *create(SourceRange Loc) { + return new LoopAssumeParallelTransform(Loc); + } + + int getNumInputs() const { return 1; } + int getNumFollowups() const { return 1; } + enum Input { + InputParallel, + }; + enum Followup { + FollowupParallel, + }; }; } // namespace clang diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -11641,6 +11641,8 @@ TransformClause *ActOnPartialClause(SourceRange Loc, Expr *Factor); TransformClause *ActOnWidthClause(SourceRange Loc, Expr *Width); TransformClause *ActOnFactorClause(SourceRange Loc, Expr *Factor); + + void HandleLoopTransformations(FunctionDecl *FD); }; /// RAII object that enters a new expression evaluation context. 
diff --git a/clang/include/clang/Sema/SemaTransform.h b/clang/include/clang/Sema/SemaTransform.h new file mode 100644 --- /dev/null +++ b/clang/include/clang/Sema/SemaTransform.h @@ -0,0 +1,80 @@ +//===---- SemaTransform.h ------------------------------------- -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Semantic analysis for code transformations. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_SEMA_SEMATRANSFORM_H +#define LLVM_CLANG_SEMA_SEMATRANSFORM_H + +#include "clang/Analysis/AnalysisTransform.h" +#include "clang/Analysis/TransformedTree.h" +#include "clang/Sema/Sema.h" + +namespace clang { +class Sema; + +class SemaTransformedTree : public TransformedTree { + friend class TransformedTree; + using BaseTy = TransformedTree; + using NodeTy = SemaTransformedTree; + + BaseTy &getBase() { return *this; } + const BaseTy &getBase() const { return *this; } + +public: + SemaTransformedTree(llvm::ArrayRef SubLoops, NodeTy *BasedOn, + clang::Stmt *Original, int FollowupRole) + : TransformedTree(SubLoops, BasedOn, Original, FollowupRole) {} +}; + +class SemaTransformedTreeBuilder + : public TransformedTreeBuilder { + using NodeTy = SemaTransformedTree; + + Sema &Sem; + +public: + SemaTransformedTreeBuilder(ASTContext &ASTCtx, + llvm::SmallVectorImpl &AllNodes, + llvm::SmallVectorImpl &AllTransforms, + Sema &Sem) + : TransformedTreeBuilder(ASTCtx, AllNodes, AllTransforms), Sem(Sem) {} + + auto Diag(SourceLocation Loc, unsigned DiagID) { + return Sem.Diag(Loc, DiagID); + } + + void applyOriginal(SemaTransformedTree *L) {} + + void disableDistribution(SemaTransformedTree *L) {} + void disableVectorizeInterleave(SemaTransformedTree *L) {} + void 
disableUnrollAndJam(SemaTransformedTree *L) {} + void disableUnroll(SemaTransformedTree *L) {} + void disablePipelining(SemaTransformedTree *L) {} + + void applyDistribution(LoopDistributionTransform *Trans, + SemaTransformedTree *InputLoop) {} + void applyVectorizeInterleave(LoopVectorizationInterleavingTransform *Trans, + SemaTransformedTree *MainLoop) {} + void applyUnrollAndJam(LoopUnrollAndJamTransform *Trans, + SemaTransformedTree *OuterLoop, + SemaTransformedTree *InnerLoop) {} + void applyUnroll(LoopUnrollingTransform *Trans, + SemaTransformedTree *OriginalLoop) {} + void applyPipelining(LoopPipeliningTransform *Trans, + SemaTransformedTree *MainLoop) {} + + void inheritLoopAttributes(SemaTransformedTree *Dst, SemaTransformedTree *Src, + bool IsAll, bool IsSuccessor) {} +}; + +} // namespace clang +#endif /* LLVM_CLANG_SEMA_SEMATRANSFORM_H */ diff --git a/clang/lib/AST/ASTTypeTraits.cpp b/clang/lib/AST/ASTTypeTraits.cpp --- a/clang/lib/AST/ASTTypeTraits.cpp +++ b/clang/lib/AST/ASTTypeTraits.cpp @@ -41,6 +41,9 @@ {NKI_None, "OMPClause"}, #define OPENMP_CLAUSE(TextualSpelling, Class) {NKI_OMPClause, #Class}, #include "clang/Basic/OpenMPKinds.def" + {NKI_TransformClause, "TransformClause"}, +#define TRANSFORM_CLAUSE(Keyword, Name) {NKI_##Name##Clause, #Name "Clause"}, +#include "clang/AST/TransformClauseKinds.def" }; bool ASTNodeKind::isBaseOf(ASTNodeKind Other, unsigned *Distance) const { @@ -124,6 +127,20 @@ llvm_unreachable("invalid stmt kind"); } + + +ASTNodeKind ASTNodeKind::getFromNode(const TransformClause &C) { + switch (C.getKind()) { +#define TRANSFORM_CLAUSE(Keyword, Name) \ + case TransformClause::Kind ::Name##Kind: \ + return ASTNodeKind(NKI_##Name##Clause); +#include "clang/AST/TransformClauseKinds.def" + case TransformClause::Kind::UnknownKind: + llvm_unreachable("unexpected transform kind"); + } + llvm_unreachable("invalid transform kind"); +} + void DynTypedNode::print(llvm::raw_ostream &OS, const PrintingPolicy &PP) const { if (const 
TemplateArgument *TA = get()) diff --git a/clang/lib/AST/JSONNodeDumper.cpp b/clang/lib/AST/JSONNodeDumper.cpp --- a/clang/lib/AST/JSONNodeDumper.cpp +++ b/clang/lib/AST/JSONNodeDumper.cpp @@ -167,6 +167,10 @@ void JSONNodeDumper::Visit(const OMPClause *C) {} +void JSONNodeDumper::Visit(const TransformClause *C) {} + +void JSONNodeDumper::Visit(const Transform *T) {} + void JSONNodeDumper::Visit(const BlockDecl::Capture &C) { JOS.attribute("kind", "Capture"); attributeOnlyIfTrue("byref", C.isByRef()); diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -30,6 +30,7 @@ #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtTransform.h" #include "clang/AST/StmtVisitor.h" #include "clang/AST/TemplateBase.h" #include "clang/AST/Type.h" @@ -635,6 +636,19 @@ if (Policy.IncludeNewlines) OS << NL; } +void StmtPrinter::VisitTransformExecutableDirective( + TransformExecutableDirective *Node) { + Indent() << "#pragma clang transform " + << Transform ::getTransformDirectiveKeyword( + Node->getTransformKind()); + for (TransformClause *Clause : Node->clauses()) { + OS << ' '; + Clause->print(OS, Policy); + } + OS << NL; + PrintStmt(Node->getAssociated()); +} + //===----------------------------------------------------------------------===// // OpenMP directives printing methods //===----------------------------------------------------------------------===// diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -771,6 +771,11 @@ } } +void StmtProfiler::VisitTransformExecutableDirective( + const TransformExecutableDirective *S) { + VisitStmt(S); +} + void StmtProfiler::VisitOMPExecutableDirective(const OMPExecutableDirective *S) { VisitStmt(S); diff --git a/clang/lib/AST/StmtTransform.cpp b/clang/lib/AST/StmtTransform.cpp --- 
a/clang/lib/AST/StmtTransform.cpp +++ b/clang/lib/AST/StmtTransform.cpp @@ -17,6 +17,32 @@ using namespace clang; +TransformExecutableDirective *TransformExecutableDirective::create( + ASTContext &Ctx, SourceRange Range, Stmt *Associated, + ArrayRef Clauses, Transform::Kind TransKind) { + void *Mem = Ctx.Allocate(totalSizeToAlloc(Clauses.size())); + return new (Mem) + TransformExecutableDirective(Range, Associated, Clauses, TransKind); +} + +TransformExecutableDirective * +TransformExecutableDirective::createEmpty(ASTContext &Ctx, + unsigned NumClauses) { + void *Mem = Ctx.Allocate(totalSizeToAlloc(NumClauses)); + return new (Mem) TransformExecutableDirective(NumClauses); +} + +llvm::StringRef TransformClause::getClauseName(Kind K) { + assert(K >= UnknownKind); + assert(K <= LastKind); + const char *Names[LastKind + 1] = { + "Unknown", +#define TRANSFORM_CLAUSE(Keyword, Name) #Name, +#include "clang/AST/TransformClauseKinds.def" + }; + return Names[K]; +} + bool TransformClause::isValidForTransform(Transform::Kind TransformKind, TransformClause::Kind ClauseKind) { switch (TransformKind) { @@ -56,6 +82,56 @@ return ClauseKeyword[ClauseKind - 1]; } +TransformClause ::child_range TransformClause ::children() { + switch (getKind()) { + case UnknownKind: + llvm_unreachable("Unknown child"); +#define TRANSFORM_CLAUSE(Keyword, Name) \ + case TransformClause::Kind::Name##Kind: \ + return static_cast(this)->children(); +#include "clang/AST/TransformClauseKinds.def" + } + llvm_unreachable("Unhandled clause kind"); +} + +void TransformClause ::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + assert(getKind() > UnknownKind); + assert(getKind() <= LastKind); + static decltype(&TransformClause::print) PrintFuncs[LastKind] = { +#define TRANSFORM_CLAUSE(Keyword, Name) \ + static_cast(&Name##Clause ::print), +#include "clang/AST/TransformClauseKinds.def" + }; + (this->*PrintFuncs[getKind() - 1])(OS, Policy); +} + +void FullClause::print(llvm::raw_ostream &OS, 
+ const PrintingPolicy &Policy) const { + OS << "full"; +} + +void PartialClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + OS << "partial("; + Factor->printPretty(OS, nullptr, Policy, 0); + OS << ')'; +} + +void WidthClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + OS << "width("; + Width->printPretty(OS, nullptr, Policy, 0); + OS << ')'; +} + +void FactorClause::print(llvm::raw_ostream &OS, + const PrintingPolicy &Policy) const { + OS << "factor("; + Factor->printPretty(OS, nullptr, Policy, 0); + OS << ')'; +} + const Stmt *clang::getAssociatedLoop(const Stmt *S) { switch (S->getStmtClass()) { case Stmt::ForStmtClass: @@ -67,6 +143,9 @@ return getAssociatedLoop(cast(S)->getCapturedStmt()); case Stmt::AttributedStmtClass: return getAssociatedLoop(cast(S)->getSubStmt()); + case Stmt::TransformExecutableDirectiveClass: + return getAssociatedLoop( + cast(S)->getAssociated()); default: if (auto LD = dyn_cast(S)) return getAssociatedLoop(LD->getAssociatedStmt()); diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp --- a/clang/lib/AST/TextNodeDumper.cpp +++ b/clang/lib/AST/TextNodeDumper.cpp @@ -320,6 +320,32 @@ OS << " "; } +void TextNodeDumper::VisitTransformExecutableDirective( + const TransformExecutableDirective *S) { +#if 0 + if (S) + AddChild([=] { + for (TransformClause *C : S->clauses()) + Visit( C); + }); +#endif +} + +void TextNodeDumper::Visit(const TransformClause *C) { + if (!C) { + ColorScope Color(OS, ShowColors, NullColor); + OS << "<<>> TransformClause"; + return; + } + { + ColorScope Color(OS, ShowColors, AttrColor); + StringRef ClauseName = TransformClause::getClauseName(C->getKind()); + OS << ClauseName << "Clause"; + } + dumpPointer(C); + dumpSourceRange(C->getRange()); +} + void TextNodeDumper::Visit(const GenericSelectionExpr::ConstAssociation &A) { const TypeSourceInfo *TSI = A.getTypeSourceInfo(); if (TSI) { diff --git a/clang/lib/Basic/Transform.cpp 
b/clang/lib/Basic/Transform.cpp --- a/clang/lib/Basic/Transform.cpp +++ b/clang/lib/Basic/Transform.cpp @@ -47,3 +47,45 @@ return Keywords[K]; } +int Transform::getLoopPipelineStage() const { + switch (getKind()) { + case Transform::Kind::LoopDistributionKind: + return 1; + case Transform::Kind::LoopUnrollingKind: + return cast(this)->isFull() ? 0 : 4; + case Transform::Kind::LoopInterleavingKind: + case Transform::Kind::LoopVectorizationKind: + case Transform::Kind::LoopVectorizationInterleavingKind: + return 2; + case Transform::Kind::LoopUnrollAndJamKind: + return 3; + default: + return -1; + } +} + +int Transform::getNumInputs() const { + assert(getKind() > UnknownKind); + assert(getKind() <= LastKind); + static const decltype( + &Transform::getNumInputs) GetNumInputFuncs[Transform::Kind::LastKind] = { +#define TRANSFORM_DIRECTIVE(Keyword, Name) \ + static_cast( \ + &Name##Transform ::getNumInputs), +#include "clang/Basic/TransformKinds.def" + }; + return (this->*GetNumInputFuncs[getKind() - 1])(); +} + +int Transform::getNumFollowups() const { + assert(getKind() > UnknownKind); + assert(getKind() <= LastKind); + static const decltype(&Transform::getNumInputs) + GetNumFollowupFuncs[Transform::Kind::LastKind] = { +#define TRANSFORM_DIRECTIVE(Keyword, Name) \ + static_cast( \ + &Name##Transform ::getNumFollowups), +#include "clang/Basic/TransformKinds.def" + }; + return (this->*GetNumFollowupFuncs[getKind() - 1])(); +} diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp --- a/clang/lib/CodeGen/CGStmt.cpp +++ b/clang/lib/CodeGen/CGStmt.cpp @@ -354,6 +354,9 @@ EmitOMPTargetTeamsDistributeSimdDirective( cast(*S)); break; + case Stmt::TransformExecutableDirectiveClass: + llvm_unreachable("not implemented"); + break; } } diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -14083,6 +14083,8 @@ "Leftover expressions for odr-use checking"); } + 
HandleLoopTransformations(FD); + if (!IsInstantiation) PopDeclContext(); diff --git a/clang/lib/Sema/SemaTransform.cpp b/clang/lib/Sema/SemaTransform.cpp --- a/clang/lib/Sema/SemaTransform.cpp +++ b/clang/lib/Sema/SemaTransform.cpp @@ -10,8 +10,11 @@ // //===----------------------------------------------------------------------===// +#include "clang/Sema/SemaTransform.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/AST/StmtTransform.h" +#include "clang/Analysis/AnalysisTransform.h" +#include "clang/Analysis/TransformedTree.h" #include "clang/Basic/Transform.h" #include "clang/Sema/Sema.h" #include "clang/Sema/SemaDiagnostic.h" @@ -20,30 +23,65 @@ using namespace clang; +struct SemaExtractTransform : ExtractTransform { + Sema &Sem; + SemaExtractTransform(TransformExecutableDirective *Directive, Sema &Sem) + : ExtractTransform(Sem.getASTContext(), Directive), Sem(Sem) {} + + auto Diag(SourceLocation Loc, unsigned DiagID) { + return Sem.Diag(Loc, DiagID); + } +}; + StmtResult -Sema::ActOnLoopTransformDirective(Transform::Kind Kind, +Sema::ActOnLoopTransformDirective(Transform::Kind Kind, llvm::ArrayRef Clauses, Stmt *AStmt, SourceRange Loc) { - // TOOD: implement - return StmtError(); + const Stmt *Loop = getAssociatedLoop(AStmt); + if (!Loop) + return StmtError( + Diag(Loc.getBegin(), diag::err_sema_transform_expected_loop)); + + auto *Result = + TransformExecutableDirective::create(Context, Loc, AStmt, Clauses, Kind); + + // Emit errors and warnings. 
+ SemaExtractTransform VerifyTransform(Result, *this); + VerifyTransform.createTransform(); + + return Result; } TransformClause *Sema::ActOnFullClause(SourceRange Loc) { - // TOOD: implement - return nullptr; + return FullClause::create(Context, Loc); } TransformClause *Sema::ActOnPartialClause(SourceRange Loc, Expr *Factor) { - // TOOD: implement - return nullptr; + return PartialClause::create(Context, Loc, Factor); } TransformClause *Sema::ActOnWidthClause(SourceRange Loc, Expr *Width) { - // TOOD: implement - return nullptr; + return WidthClause::create(Context, Loc, Width); } TransformClause *Sema::ActOnFactorClause(SourceRange Loc, Expr *Factor) { - // TOOD: implement - return nullptr; + return FactorClause::create(Context, Loc, Factor); +} + +void Sema::HandleLoopTransformations(FunctionDecl *FD) { + if (!FD || FD->isInvalidDecl()) + return; + + // Note: this is called on both the template code and the instantiated code. + llvm::DenseMap StmtToTree; + llvm::SmallVector AllNodes; + llvm::SmallVector AllTransforms; + SemaTransformedTreeBuilder Builder(getASTContext(), AllNodes, AllTransforms, + *this); + Builder.computeTransformedStructure(FD->getBody(), StmtToTree); + + for (auto N : AllNodes) + delete N; + for (auto T : AllTransforms) + delete T; } diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -26,6 +26,7 @@ #include "clang/AST/StmtCXX.h" #include "clang/AST/StmtObjC.h" #include "clang/AST/StmtOpenMP.h" +#include "clang/AST/StmtTransform.h" #include "clang/Sema/Designator.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/Ownership.h" @@ -341,6 +342,51 @@ /// \returns the transformed statement.
StmtResult TransformStmt(Stmt *S, StmtDiscardKind SDK = SDK_Discarded); + Transform *TransformTransform(Transform *T) { return T; } + + TransformClause *TransformTransformClause(TransformClause *S); + + TransformClause *TransformFullClause(FullClause *C) { + return RebuildFullClause(C->getRange()); + } + + TransformClause *RebuildFullClause(SourceRange Range) { + return getSema().ActOnFullClause(Range); + } + + TransformClause *TransformPartialClause(PartialClause *C) { + ExprResult E = getDerived().TransformExpr(C->getFactor()); + if (E.isInvalid()) + return nullptr; + return getDerived().RebuildPartialClause(C->getRange(), E.get()); + } + + TransformClause *RebuildPartialClause(SourceRange Range, Expr *Factor) { + return getSema().ActOnPartialClause(Range, Factor); + } + + TransformClause *TransformWidthClause(WidthClause *C) { + ExprResult E = getDerived().TransformExpr(C->getWidth()); + if (E.isInvalid()) + return nullptr; + return getDerived().RebuildWidthClause(C->getRange(), E.get()); + } + + TransformClause *RebuildWidthClause(SourceRange Range, Expr *Width) { + return getSema().ActOnWidthClause(Range, Width); + } + + TransformClause *TransformFactorClause(FactorClause *C) { + ExprResult E = getDerived().TransformExpr(C->getFactor()); + if (E.isInvalid()) + return nullptr; + return getDerived().RebuildFactorClause(C->getRange(), E.get()); + } + + TransformClause *RebuildFactorClause(SourceRange Range, Expr *Factor) { + return getSema().ActOnFactorClause(Range, Factor); + } + /// Transform the given statement. 
/// /// By default, this routine transforms a statement by delegating to the @@ -1503,6 +1549,16 @@ return getSema().BuildObjCAtThrowStmt(AtLoc, Operand); } + StmtResult + RebuildTransformExecutableDirective(Transform::Kind Kind, llvm::ArrayRef Clauses, + Stmt *AStmt, SourceRange Loc) { + StmtResult Result = + getSema().ActOnLoopTransformDirective(Kind, Clauses, AStmt, Loc); + assert(!Result.isUsable() || + isa(Result.get())); + return Result; + } + /// Build a new OpenMP executable directive. /// /// By default, performs semantic analysis to build the new statement. @@ -7864,6 +7920,42 @@ return S; } +template +TransformClause * +TreeTransform::TransformTransformClause(TransformClause *S) { + if (!S) + return S; + + switch (S->getKind()) { +#define TRANSFORM_CLAUSE(Keyword, Name) \ + case TransformClause::Kind::Name##Kind: \ + return getDerived().Transform##Name##Clause(cast(S)); +#include "clang/AST/TransformClauseKinds.def" + case TransformClause::Kind::UnknownKind: + llvm_unreachable("Should not be unknown"); + } + + return S; +} + +template +StmtResult TreeTransform::TransformTransformExecutableDirective( + TransformExecutableDirective *D) { + llvm::SmallVector TClauses; + for (TransformClause *C : D->clauses()) { + TransformClause *TClause = getDerived().TransformTransformClause(C); + TClauses.push_back(TClause); + } + + Stmt *AStmt = D->getAssociated(); + StmtResult TBody = getDerived().TransformStmt(AStmt); + assert(TBody.isUsable()); + + StmtResult TDirective = getDerived().RebuildTransformExecutableDirective( + D->getTransformKind(), TClauses, TBody.get(), D->getRange()); + return TDirective; +} + //===----------------------------------------------------------------------===// // OpenMP directive transformation //===----------------------------------------------------------------------===// diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ 
b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2007,6 +2007,15 @@ E->SrcExpr = Record.readSubExpr(); } +//===----------------------------------------------------------------------===// +// Transformation Directives. +//===----------------------------------------------------------------------===// + +void ASTStmtReader::VisitTransformExecutableDirective( + TransformExecutableDirective *D) { + llvm_unreachable("not implemented"); +} + //===----------------------------------------------------------------------===// // OpenMP Directives. //===----------------------------------------------------------------------===// diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -1951,6 +1951,14 @@ Record.AddSourceLocation(S->getLeaveLoc()); Code = serialization::STMT_SEH_LEAVE; } +//===----------------------------------------------------------------------===// +// Transformation Directives. +//===----------------------------------------------------------------------===// + +void ASTStmtWriter::VisitTransformExecutableDirective( + TransformExecutableDirective *D) { + llvm_unreachable("not implemented"); +} //===----------------------------------------------------------------------===// // OpenMP Directives. 
diff --git a/clang/test/AST/ast-dump-transform-unroll.c b/clang/test/AST/ast-dump-transform-unroll.c
new file mode 100644
--- /dev/null
+++ b/clang/test/AST/ast-dump-transform-unroll.c
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -fexperimental-transform-pragma -ast-dump %s | FileCheck %s
+
+void unroll_full(int n) {
+#pragma clang transform unroll full
+  for (int i = 0; i < 4; i+=1)
+    ;
+}
+
+// CHECK-LABEL: FunctionDecl {{.*}} unroll_full
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: FullClause
+// CHECK-NEXT: ForStmt
+
+
+void unroll_partial(int n) {
+#pragma clang transform unroll partial(4)
+  for (int i = 0; i < n; i+=1)
+    ;
+}
+
+// CHECK-LABEL: FunctionDecl {{.*}} unroll_partial
+// CHECK: TransformExecutableDirective
+// CHECK-NEXT: PartialClause
+// CHECK-NEXT: IntegerLiteral
+// CHECK-NEXT: ForStmt
diff --git a/clang/test/AST/ast-print-pragma-transform-distribute.cpp b/clang/test/AST/ast-print-pragma-transform-distribute.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/AST/ast-print-pragma-transform-distribute.cpp
@@ -0,0 +1,9 @@
+// RUN: %clang_cc1 -fexperimental-transform-pragma -ast-print %s -o - | FileCheck %s
+
+void distribute(int *List, int Length) {
+// CHECK: #pragma clang transform distribute
+#pragma clang transform distribute
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    List[i] = i;
+}
diff --git a/clang/test/AST/ast-print-pragma-transform-interleave.cpp b/clang/test/AST/ast-print-pragma-transform-interleave.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/AST/ast-print-pragma-transform-interleave.cpp
@@ -0,0 +1,9 @@
+// RUN: %clang_cc1 -fexperimental-transform-pragma -ast-print %s -o - | FileCheck %s
+
+void unroll(int *List, int Length) {
+// CHECK: #pragma clang transform interleave factor(4)
+#pragma clang transform interleave factor(4)
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    List[i] = i;
+}
diff --git a/clang/test/AST/ast-print-pragma-transform-unroll.cpp b/clang/test/AST/ast-print-pragma-transform-unroll.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/AST/ast-print-pragma-transform-unroll.cpp
@@ -0,0 +1,15 @@
+// RUN: %clang_cc1 -fexperimental-transform-pragma -ast-print %s -o - | FileCheck %s
+
+void unroll(int *List, int Length) {
+// CHECK: #pragma clang transform unroll partial(4)
+#pragma clang transform unroll partial(4)
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    List[i] = i;
+
+// CHECK: #pragma clang transform unroll full
+#pragma clang transform unroll full
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    List[i] = i;
+}
diff --git a/clang/test/AST/ast-print-pragma-transform-unrollandjam.cpp b/clang/test/AST/ast-print-pragma-transform-unrollandjam.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/AST/ast-print-pragma-transform-unrollandjam.cpp
@@ -0,0 +1,10 @@
+// RUN: %clang_cc1 -fexperimental-transform-pragma -ast-print %s -o - | FileCheck %s
+
+void unroll(int *List, int Length) {
+// CHECK: #pragma clang transform unrollandjam partial(4)
+#pragma clang transform unrollandjam partial(4)
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    for (int j = 0; j < Length; j += 1)
+      List[i] += j;
+}
diff --git a/clang/test/AST/ast-print-pragma-transform-vectorize.cpp b/clang/test/AST/ast-print-pragma-transform-vectorize.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/AST/ast-print-pragma-transform-vectorize.cpp
@@ -0,0 +1,9 @@
+// RUN: %clang_cc1 -fexperimental-transform-pragma -ast-print %s -o - | FileCheck %s
+
+void unroll(int *List, int Length) {
+// CHECK: #pragma clang transform vectorize width(4)
+#pragma clang transform vectorize width(4)
+// CHECK-NEXT: for (int i = 0; i < Length; i += 1)
+  for (int i = 0; i < Length; i += 1)
+    List[i] = i;
+}
diff --git a/clang/test/SemaCXX/pragma-transform-interleave.cpp b/clang/test/SemaCXX/pragma-transform-interleave.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCXX/pragma-transform-interleave.cpp
@@ -0,0 +1,15 @@
+// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -fsyntax-only -verify %s
+
+void interleave(int *List, int Length, int Value) {
+
+/* expected-error@+1 {{the factor clause can be specified at most once}} */
+#pragma clang transform interleave factor(4) factor(4)
+  for (int i = 0; i < Length; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{clause argument must me at least 2}} */
+#pragma clang transform interleave factor(-42)
+  for (int i = 0; i < Length; i++)
+    List[i] = Value;
+
+}
diff --git a/clang/test/SemaCXX/pragma-transform-legacymix.cpp b/clang/test/SemaCXX/pragma-transform-legacymix.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCXX/pragma-transform-legacymix.cpp
@@ -0,0 +1,61 @@
+// RUN: %clang_cc1 -std=c++11 -fopenmp -fexperimental-transform-pragma -fsyntax-only -verify %s
+
+void legacymix(int *List, int Length, int Value) {
+
+/* expected-error@+2 {{expected a for, while, or do-while loop to follow '#pragma unroll'}} */
+#pragma unroll
+#pragma clang transform unroll
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+2 {{expected a for, while, or do-while loop to follow '#pragma clang loop'}} */
+#pragma clang loop unroll(enable)
+#pragma clang transform unroll
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{Cannot combine #pragma clang transform with other transformations}} */
+#pragma clang transform unroll
+#pragma clang loop unroll(enable)
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{Cannot combine #pragma clang transform with other transformations}} */
+#pragma clang transform unroll
+#pragma clang loop unroll(disable)
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+2 {{expected a for, while, or do-while loop to follow '#pragma clang loop'}} */
+#pragma clang loop unroll(disable)
+#pragma clang transform unroll
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{Cannot combine #pragma clang transform with other transformations}} */
+#pragma clang transform unrollandjam partial(2)
+  for (int i = 0; i < 8; i++)
+#pragma clang loop unroll(disable)
+    for (int j = 0; j < 16; j++)
+      List[i] = Value;
+
+/* expected-error@+3 {{Cannot combine #pragma clang transform with other transformations}} */
+#pragma unroll_and_jam(2)
+  for (int i = 0; i < 8; i++)
+#pragma clang transform unroll
+    for (int j = 0; j < 16; j++)
+      List[i] = Value;
+
+/* expected-error@+1 {{Cannot combine #pragma clang transform with other transformations}} */
+#pragma clang transform unroll
+#pragma omp simd
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+2 {{statement after '#pragma omp simd' must be a for loop}} */
+#pragma omp simd
+#pragma clang transform unroll
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+}
diff --git a/clang/test/SemaCXX/pragma-transform-unroll.cpp b/clang/test/SemaCXX/pragma-transform-unroll.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCXX/pragma-transform-unroll.cpp
@@ -0,0 +1,37 @@
+// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -fsyntax-only -verify %s
+
+void unroll(int *List, int Length, int Value) {
+
+/* expected-error@+1 {{clause argument must me at least 2}} */
+#pragma clang transform unroll partial(0)
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{the full clause can be specified at most once}} */
+#pragma clang transform unroll full full
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{the partial clause can be specified at most once}} */
+#pragma clang transform unroll partial(4) partial(4)
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{the full and partial clauses are mutually exclusive}} */
+#pragma clang transform unroll full partial(4)
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{transformation did not find its loop to transform}} */
+#pragma clang transform unroll
+#pragma clang transform unroll full
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+int f = 4;
+/* expected-error@+1 {{clause expects an integer argument}} */
+#pragma clang transform unroll partial(f)
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+
+}
diff --git a/clang/test/SemaCXX/pragma-transform-unrollandjam.cpp b/clang/test/SemaCXX/pragma-transform-unrollandjam.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCXX/pragma-transform-unrollandjam.cpp
@@ -0,0 +1,30 @@
+// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -fsyntax-only -verify %s
+
+void unrollandjam(int *List, int Length, int Value) {
+/* expected-error@+1 {{the partial clause can be specified at most once}} */
+#pragma clang transform unrollandjam partial(4) partial(4)
+  for (int i = 0; i < Length; i++)
+    for (int j = 0; j < Length; j++)
+      List[i] += j*Value;
+
+/* expected-error@+1 {{unroll-and-jam requires exactly one nested loop}} */
+#pragma clang transform unrollandjam
+  for (int i = 0; i < Length; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{unroll-and-jam requires exactly one nested loop}} */
+#pragma clang transform unrollandjam
+  for (int i = 0; i < Length; i++) {
+    for (int j = 0; j < Length; j++)
+      List[i] += j*Value;
+    for (int j = 0; j < Length; j++)
+      List[i] += j*Value;
+  }
+
+/* expected-error@+1 {{inner loop of unroll-and-jam is not the innermost}} */
+#pragma clang transform unrollandjam
+  for (int i = 0; i < Length; i++)
+    for (int j = 0; j < Length; j++)
+      for (int k = 0; k < Length; k++)
+        List[i] += k + j*Value;
+}
diff --git a/clang/test/SemaCXX/pragma-transform-vectorize.cpp b/clang/test/SemaCXX/pragma-transform-vectorize.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCXX/pragma-transform-vectorize.cpp
@@ -0,0 +1,14 @@
+// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -fsyntax-only -verify %s
+
+void vectorize(int *List, int Length, int Value) {
+/* expected-error@+1 {{the width clause can be specified at most once}} */
+#pragma clang transform vectorize width(4) width(4)
+  for (int i = 0; i < Length; i++)
+    List[i] = Value;
+
+/* expected-error@+1 {{clause argument must me at least 2}} */
+#pragma clang transform vectorize width(-42)
+  for (int i = 0; i < Length; i++)
+    List[i] = Value;
+
+}
diff --git a/clang/test/SemaCXX/pragma-transform-wrongorder.cpp b/clang/test/SemaCXX/pragma-transform-wrongorder.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCXX/pragma-transform-wrongorder.cpp
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -fsyntax-only -verify %s
+
+void wrongorder(int *List, int Length, int Value) {
+
+/* expected-warning@+1 {{the LLVM pass structure currently is not able to apply the transformations in this order}} */
+#pragma clang transform distribute
+#pragma clang transform vectorize
+  for (int i = 0; i < 8; i++)
+    List[i] = Value;
+
+/* expected-warning@+1 {{the LLVM pass structure currently is not able to apply the transformations in this order}} */
+#pragma clang transform vectorize
+#pragma clang transform unrollandjam
+  for (int i = 0; i < 8; i++)
+    for (int j = 0; j < 8; j++)
+      List[i] += j;
+
+}
diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp
--- a/clang/tools/libclang/CXCursor.cpp
+++ b/clang/tools/libclang/CXCursor.cpp
@@ -732,6 +732,9 @@
     break;
   case Stmt::BuiltinBitCastExprClass:
     K = CXCursor_BuiltinBitCastExpr;
+    break;
+  case Stmt::TransformExecutableDirectiveClass:
+    K = CXCursor_UnexposedStmt; // No dedicated cursor kind for this yet.
   }
 
   CXCursor C = { K, 0, { Parent, S, TU } };