diff --git a/clang/include/clang/AST/StmtTransform.h b/clang/include/clang/AST/StmtTransform.h
new file mode 100644
--- /dev/null
+++ b/clang/include/clang/AST/StmtTransform.h
@@ -0,0 +1,54 @@
+//===--- StmtTransform.h - Code transformation AST nodes --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transformation directive statement and clauses for the AST.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_STMTTRANSFORM_H
+#define LLVM_CLANG_AST_STMTTRANSFORM_H
+
+#include "clang/AST/Stmt.h"
+#include "clang/Basic/Transform.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace clang {
+
+/// Represents a clause of a \p TransformExecutableDirective.
+class TransformClause {
+public:
+  enum Kind {
+    UnknownKind,
+#define TRANSFORM_CLAUSE(Keyword, Name) Name##Kind,
+#define TRANSFORM_CLAUSE_LAST(Keyword, Name) Name##Kind, LastKind = Name##Kind
+#include "clang/AST/TransformClauseKinds.def"
+  };
+
+  static bool isValidForTransform(Transform::Kind TransformKind,
+                                  TransformClause::Kind ClauseKind);
+  static Kind getClauseKind(Transform::Kind TransformKind, llvm::StringRef Str);
+  static llvm::StringRef getClauseKeyword(TransformClause::Kind ClauseKind);
+
+  // TODO: implement
+};
+
+/// Represents
+///
+///   #pragma clang transform
+///
+/// in the AST.
+class TransformExecutableDirective final {
+  // TODO: implement
+};
+
+/// Returns \p S if it is a loop statement (for, while, do-while or
+/// range-based for) that a transformation can apply to, nullptr otherwise.
+const Stmt *getAssociatedLoop(const Stmt *S);
+} // namespace clang
+
+#endif /* LLVM_CLANG_AST_STMTTRANSFORM_H */
diff --git a/clang/include/clang/AST/TransformClauseKinds.def b/clang/include/clang/AST/TransformClauseKinds.def
new file mode 100644
--- /dev/null
+++ b/clang/include/clang/AST/TransformClauseKinds.def
@@ -0,0 +1,16 @@
+
+#ifndef TRANSFORM_CLAUSE
+# define TRANSFORM_CLAUSE(Keyword, Name)
+#endif
+#ifndef TRANSFORM_CLAUSE_LAST
+# define TRANSFORM_CLAUSE_LAST(Keyword, Name) TRANSFORM_CLAUSE(Keyword, Name)
+#endif
+
+TRANSFORM_CLAUSE(full,Full)
+TRANSFORM_CLAUSE(partial,Partial)
+
+TRANSFORM_CLAUSE(width,Width)
+TRANSFORM_CLAUSE_LAST(factor,Factor)
+
+#undef TRANSFORM_CLAUSE
+#undef TRANSFORM_CLAUSE_LAST
diff --git a/clang/include/clang/Basic/DiagnosticGroups.td b/clang/include/clang/Basic/DiagnosticGroups.td
--- a/clang/include/clang/Basic/DiagnosticGroups.td
+++ b/clang/include/clang/Basic/DiagnosticGroups.td
@@ -1128,3 +1128,6 @@
 def CTADMaybeUnsupported : DiagGroup<"ctad-maybe-unsupported">;
 
 def FortifySource : DiagGroup<"fortify-source">;
+
+// Warnings for #pragma clang transform
+def ClangTransform : DiagGroup<"pragma-transform">;
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1245,6 +1245,18 @@
   "vectorize_width, interleave, interleave_count, unroll, unroll_count, "
   "pipeline, pipeline_initiation_interval, vectorize_predicate, or distribute">;
 
+// Pragma transform support.
+def err_pragma_transform_expected_directive : Error<
+  "expected a transformation name">;
+def err_pragma_transform_unknown_directive : Error<
+  "unknown transformation">;
+def err_pragma_transform_expected_loop : Error<
+  "expected loop after transformation pragma">;
+def err_pragma_transform_expected_clause : Error<
+  "expected a clause name">;
+def err_pragma_transform_unknown_clause : Error<
+  "unknown clause name">;
+
 def err_pragma_fp_invalid_option : Error<
   "%select{invalid|missing}0 option%select{ %1|}0; expected contract">;
 def err_pragma_fp_invalid_argument : Error<
diff --git a/clang/include/clang/Basic/Transform.h b/clang/include/clang/Basic/Transform.h
new file mode 100644
--- /dev/null
+++ b/clang/include/clang/Basic/Transform.h
@@ -0,0 +1,388 @@
+//===--- Transform.h - Code transformation classes --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines classes used for code transformations such as
+// #pragma clang transform ...
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_BASIC_TRANSFORM_H
+#define LLVM_CLANG_BASIC_TRANSFORM_H
+
+#include "clang/Basic/SourceLocation.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace clang {
+
+class Transform {
+public:
+  enum Kind {
+    UnknownKind,
+#define TRANSFORM_DIRECTIVE(Name) Name##Kind,
+#define TRANSFORM_DIRECTIVE_LAST(Name) \
+  TRANSFORM_DIRECTIVE(Name) \
+  LastKind = Name##Kind
+#include "clang/Basic/TransformKinds.def"
+  };
+
+  static Kind getTransformDirectiveKind(llvm::StringRef Str);
+  static llvm::StringRef getTransformDirectiveKeyword(Kind K);
+
+private:
+  Kind TransformKind;
+  SourceRange LocRange;
+
+protected:
+  Transform(Kind K, SourceRange LocRange)
+      : TransformKind(K), LocRange(LocRange) {}
+
+public:
+  virtual ~Transform() {}
+
+  Kind getKind() const { return TransformKind; }
+  static bool classof(const Transform *Trans) { return true; }
+
+  /// Source location of the code transformation directive.
+  /// @{
+  SourceRange getRange() const { return LocRange; }
+  SourceLocation getBeginLoc() const { return LocRange.getBegin(); }
+  SourceLocation getEndLoc() const { return LocRange.getEnd(); }
+  void setRange(SourceRange L) { LocRange = L; }
+  void setRange(SourceLocation BeginLoc, SourceLocation EndLoc) {
+    LocRange = SourceRange(BeginLoc, EndLoc);
+  }
+  /// @}
+
+  /// Each transformation defines how many loops it consumes and generates.
+  /// Users of this class can store arrays holding the information regarding the
+  /// loops, such as pointer to the AST node or the loop name. The index in this
+  /// array is its "role".
+  /// @{
+  virtual int getNumInputs() const { return 1; }
+  virtual int getNumFollowups() const { return 0; }
+  /// @}
+
+  /// A meta role may apply to multiple output loops; its attributes are added
+  /// to each of them. A typical example is the 'all' followup which applies to
+  /// all loops emitted by a transformation. Only role 0 is treated as meta;
+  /// its attributes are added to all generated loops.
+  bool isMetaRole(int R) const { return R == 0; }
+
+  /// Used to warn users that the current LLVM pass pipeline cannot apply
+  /// arbitrary transformation orders yet.
+  int getLoopPipelineStage() const;
+};
+
+/// Partially or fully unroll a loop.
+///
+/// A full unroll transforms a loop such as
+///
+///   for (int i = 0; i < 2; i+=1)
+///     Stmt(i);
+///
+/// into
+///
+///   {
+///     Stmt(0);
+///     Stmt(1);
+///   }
+///
+/// Partial unrolling can also be applied when the loop trip count is only known
+/// at runtime. For instance, partial unrolling by a factor of 2 transforms
+///
+///   for (int i = 0; i < N; i+=1)
+///     Stmt(i);
+///
+/// into
+///
+///   int i = 0;
+///   for (; i+1 < N; i+=2) { // unrolled
+///     Stmt(i);
+///     Stmt(i+1);
+///   }
+///   for (; i < N; i+=1)     // epilogue/remainder
+///     Stmt(i);
+///
+/// LLVM's LoopUnroll pass uses the name runtime unrolling if N is not a
+/// constant.
+///
+/// When using heuristic unrolling, the optimizer decides itself whether to
+/// unroll fully or partially. Because the front-end does not know what the
+/// optimizer will do, there is no followup loop. Note that this is different
+/// from partial unrolling with an undefined factor, which always has followup
+/// loops but may not be executed.
+class LoopUnrollTransform final : public Transform {
+private:
+  int64_t Factor;
+
+  LoopUnrollTransform(SourceRange Loc, int64_t Factor)
+      : Transform(LoopUnrollKind, Loc), Factor(Factor) {
+    assert(Factor >= 2 || Factor == -1 || Factor == -2 || Factor == -3);
+  }
+
+public:
+  static bool classof(const LoopUnrollTransform *Trans) { return true; }
+  static bool classof(const Transform *Trans) {
+    return Trans->getKind() == LoopUnrollKind;
+  }
+
+  /// Create an instance of partial unrolling. The unroll factor must be at
+  /// least 2 or -1. When -1, the unroll factor can be chosen by the optimizer.
+  /// An unroll factor of 0 or 1 is not valid.
+  static LoopUnrollTransform *createPartial(SourceRange Loc,
+                                            int64_t Factor = -1) {
+    assert(Factor >= 2 || Factor == -1);
+    LoopUnrollTransform *Instance = new LoopUnrollTransform(Loc, Factor);
+    assert(Instance->isPartial());
+    return Instance;
+  }
+
+  static LoopUnrollTransform *createFull(SourceRange Loc) {
+    LoopUnrollTransform *Instance = new LoopUnrollTransform(Loc, -2);
+    assert(Instance->isFull());
+    return Instance;
+  }
+
+  static LoopUnrollTransform *createHeuristic(SourceRange Loc) {
+    LoopUnrollTransform *Instance = new LoopUnrollTransform(Loc, -3);
+    assert(Instance->isHeuristic());
+    return Instance;
+  }
+
+  bool isPartial() const { return Factor >= 2 || Factor == -1; }
+  bool isFull() const { return Factor == -2; }
+  bool isHeuristic() const { return Factor == -3; }
+
+  enum Input { InputToUnroll };
+  int getNumInputs() const override { return 1; }
+
+  enum Followup {
+    FollowupAll,
+    FollowupUnrolled, // only for partial unrolling
+    FollowupRemainder // only for partial unrolling
+  };
+  int getNumFollowups() const override {
+    if (isPartial())
+      return 3;
+    return 0;
+  }
+
+  int64_t getFactor() const { return Factor; }
+};
+
+/// Apply partial unroll-and-jam to a loop.
+///
+/// That is, with an unroll factor of 2, transform
+///
+///   for (int i = 0; i < N; i+=1)
+///     for (int j = 0; j < M; j+=1)
+///       Stmt(i,j);
+///
+/// into
+///
+///   int i = 0;
+///   for (; i+1 < N; i+=2)             // outer
+///     for (int j = 0; j < M; j+=1) {  // inner
+///       Stmt(i,j);
+///       Stmt(i+1,j);
+///     }
+///   for (; i < N; i+=1)               // remainder/epilogue
+///     for (int j = 0; j < M; j+=1)
+///       Stmt(i,j);
+///
+/// Note that LLVM's LoopUnrollAndJam pass does not support full unroll.
+class LoopUnrollAndJamTransform final : public Transform {
+private:
+  int64_t Factor;
+
+  LoopUnrollAndJamTransform(SourceRange Loc, int64_t Factor)
+      : Transform(LoopUnrollAndJamKind, Loc), Factor(Factor) {}
+
+public:
+  static bool classof(const LoopUnrollAndJamTransform *Trans) { return true; }
+  static bool classof(const Transform *Trans) {
+    return Trans->getKind() == LoopUnrollAndJamKind;
+  }
+
+  /// Create an instance of unroll-and-jam. The unroll factor must be at least 2
+  /// or -1. When -1, the unroll factor can be chosen by the optimizer. An
+  /// unroll factor of 0 or 1 is not valid.
+  static LoopUnrollAndJamTransform *createPartial(SourceRange Loc,
+                                                  int64_t Factor = -1) {
+    assert(Factor >= 2 || Factor == -1);
+    LoopUnrollAndJamTransform *Instance =
+        new LoopUnrollAndJamTransform(Loc, Factor);
+    assert(Instance->isPartial());
+    return Instance;
+  }
+
+  static LoopUnrollAndJamTransform *createHeuristic(SourceRange Loc) {
+    LoopUnrollAndJamTransform *Instance =
+        new LoopUnrollAndJamTransform(Loc, -3);
+    assert(Instance->isHeuristic());
+    return Instance;
+  }
+
+  bool isPartial() const { return Factor >= 2 || Factor == -1; }
+  bool isHeuristic() const { return Factor == -3; }
+
+  enum Input { InputOuter, InputInner };
+  int getNumInputs() const override { return 2; }
+
+  enum Followup { FollowupAll, FollowupOuter, FollowupInner };
+  int getNumFollowups() const override {
+    if (isPartial())
+      return 3;
+    return 0;
+  }
+
+  int64_t getFactor() const { return Factor; }
+};
+
+/// Apply loop distribution (aka fission) to a loop.
+///
+/// For example, transform the loop
+///
+///   for (int i = 0; i < N; i+=1) {
+///     StmtA(i);
+///     StmtB(i);
+///   }
+///
+/// into
+///
+///   for (int i = 0; i < N; i+=1)
+///     StmtA(i);
+///   for (int i = 0; i < N; i+=1)
+///     StmtB(i);
+///
+/// LLVM's LoopDistribute pass does not allow controlling how the loop is
+/// distributed. Hence, there are no non-meta followups.
+class LoopDistributionTransform final : public Transform {
+private:
+  explicit LoopDistributionTransform(SourceRange Loc)
+      : Transform(LoopDistributionKind, Loc) {}
+
+public:
+  static bool classof(const LoopDistributionTransform *Trans) { return true; }
+  static bool classof(const Transform *Trans) {
+    return Trans->getKind() == LoopDistributionKind;
+  }
+
+  static LoopDistributionTransform *create(SourceRange Loc) {
+    return new LoopDistributionTransform(Loc);
+  }
+
+  enum Input { InputToDistribute };
+  int getNumInputs() const override { return 1; }
+
+  enum Followup { FollowupAll };
+  int getNumFollowups() const override { return 1; }
+};
+
+/// Vectorize a loop by executing multiple loop iterations at the same time in
+/// vector lanes.
+///
+/// For example, transform
+///
+///   for (int i = 0; i < N; i+=1)
+///     Stmt(i);
+///
+/// into
+///
+///   int i = 0;
+///   for (; i+1 < N; i+=2) // vectorized
+///     Stmt(i:i+1);
+///   for (; i < N; i+=1)   // epilogue/remainder
+///     Stmt(i);
+class LoopVectorizationTransform final : public Transform {
+private:
+  int64_t VectorizeWidth;
+
+  LoopVectorizationTransform(SourceRange Loc, int64_t VectorizeWidth)
+      : Transform(LoopVectorizationKind, Loc), VectorizeWidth(VectorizeWidth) {
+    assert(VectorizeWidth >= 2 || VectorizeWidth == -1);
+  }
+
+public:
+  static bool classof(const LoopVectorizationTransform *Trans) { return true; }
+  static bool classof(const Transform *Trans) {
+    return Trans->getKind() == LoopVectorizationKind;
+  }
+
+  static LoopVectorizationTransform *create(SourceRange Loc,
+                                            int64_t VectorizeWidth = -1) {
+    assert(VectorizeWidth >= 2 || VectorizeWidth == -1);
+    return new LoopVectorizationTransform(Loc, VectorizeWidth);
+  }
+
+  enum Input { InputToVectorize };
+  int getNumInputs() const override { return 1; }
+
+  enum Followup { FollowupAll, FollowupVectorized, FollowupEpilogue };
+  int getNumFollowups() const override { return 3; }
+
+  int64_t getWidth() const { return VectorizeWidth; }
+};
+
+/// Execute multiple loop iterations at once by duplicating instructions. This
+/// is different from unrolling in that it copies each instruction n times
+/// instead of the entire loop body as loop unrolling does.
+///
+/// For example, transform
+///
+///   for (int i = 0; i < N; i+=1) {
+///     InstA(i);
+///     InstB(i);
+///     InstC(i);
+///   }
+///
+/// into
+///
+///   int i = 0;
+///   for (; i+1 < N; i+=2) { // interleaved
+///     InstA(i);
+///     InstA(i+1);
+///     InstB(i);
+///     InstB(i+1);
+///     InstC(i);
+///     InstC(i+1);
+///   }
+///   for (; i < N; i+=1) {   // epilogue/remainder
+///     InstA(i);
+///     InstB(i);
+///     InstC(i);
+///   }
+class LoopInterleavingTransform final : public Transform {
+private:
+  int64_t Factor;
+
+  LoopInterleavingTransform(SourceRange Loc, int64_t Factor)
+      : Transform(LoopInterleavingKind, Loc), Factor(Factor) {}
+
+public:
+  static bool classof(const LoopInterleavingTransform *Trans) { return true; }
+  static bool classof(const Transform *Trans) {
+    return Trans->getKind() == LoopInterleavingKind;
+  }
+
+  static LoopInterleavingTransform *create(SourceRange Loc, int64_t Factor) {
+    assert(Factor == -1 || Factor >= 2);
+    return new LoopInterleavingTransform(Loc, Factor);
+  }
+
+  enum Input { InputToVectorize };
+  int getNumInputs() const override { return 1; }
+
+  enum Followup { FollowupAll, FollowupInterleaved, FollowupEpilogue };
+  int getNumFollowups() const override { return 3; }
+
+  int64_t getFactor() const { return Factor; }
+};
+
+} // namespace clang
+#endif /* LLVM_CLANG_BASIC_TRANSFORM_H */
diff --git a/clang/include/clang/Basic/TransformKinds.def b/clang/include/clang/Basic/TransformKinds.def
new file mode 100644
--- /dev/null
+++ b/clang/include/clang/Basic/TransformKinds.def
@@ -0,0 +1,18 @@
+
+#ifndef TRANSFORM_DIRECTIVE
+# define TRANSFORM_DIRECTIVE(Name)
+#endif
+#ifndef TRANSFORM_DIRECTIVE_LAST
+# define TRANSFORM_DIRECTIVE_LAST(Name) TRANSFORM_DIRECTIVE(Name)
+#endif
+
+// Loop transformations accessible through "#pragma clang transform".
+TRANSFORM_DIRECTIVE(LoopUnroll)
+TRANSFORM_DIRECTIVE(LoopUnrollAndJam)
+TRANSFORM_DIRECTIVE(LoopDistribution)
+TRANSFORM_DIRECTIVE(LoopVectorization)
+TRANSFORM_DIRECTIVE_LAST(LoopInterleaving)
+
+
+#undef TRANSFORM_DIRECTIVE
+#undef TRANSFORM_DIRECTIVE_LAST
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -1647,6 +1647,17 @@
     IsTypeCast
   };
 
+  using TransformClauseResult = ActionResult<TransformClause *>;
+  static TransformClauseResult ClauseError() {
+    return TransformClauseResult(true);
+  }
+  static TransformClauseResult ClauseError(const DiagnosticBuilder &) {
+    return ClauseError();
+  }
+  static TransformClauseResult ClauseEmpty() {
+    return TransformClauseResult(false);
+  }
+
   ExprResult ParseExpression(TypeCastState isTypeCast = NotTypeCast);
   ExprResult ParseConstantExpressionInExprEvalContext(
       TypeCastState isTypeCast = NotTypeCast);
@@ -1983,6 +1994,12 @@
                                  SourceLocation *TrailingElseLoc,
                                  ParsedAttributesWithRange &Attrs);
 
+  Transform::Kind
+  tryParsePragmaTransform(SourceLocation BeginLoc, ParsedStmtContext StmtCtx,
+                          SmallVectorImpl<TransformClause *> &Clauses);
+  StmtResult ParsePragmaTransform(ParsedStmtContext StmtCtx);
+  TransformClauseResult ParseTransformClause(Transform::Kind TransformKind);
+
   /// Describes the behavior that should be taken for an __if_exists
   /// block.
   enum IfExistsBehavior {
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -29,6 +29,7 @@
 #include "clang/AST/NSAPI.h"
 #include "clang/AST/PrettyPrinter.h"
 #include "clang/AST/StmtCXX.h"
+#include "clang/AST/StmtTransform.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/AST/TypeOrdering.h"
 #include "clang/Basic/ExpressionTraits.h"
@@ -11749,6 +11750,16 @@
     ConstructorDestructor,
     BuiltinFunction
   };
+
+  StmtResult
+  ActOnLoopTransformDirective(Transform::Kind Kind,
+                              llvm::ArrayRef<TransformClause *> Clauses,
+                              Stmt *AStmt, SourceRange Loc);
+
+  TransformClause *ActOnFullClause(SourceRange Loc);
+  TransformClause *ActOnPartialClause(SourceRange Loc, Expr *Factor);
+  TransformClause *ActOnWidthClause(SourceRange Loc, Expr *Width);
+  TransformClause *ActOnFactorClause(SourceRange Loc, Expr *Factor);
 };
 
 /// RAII object that enters a new expression evaluation context.
diff --git a/clang/lib/AST/CMakeLists.txt b/clang/lib/AST/CMakeLists.txt
--- a/clang/lib/AST/CMakeLists.txt
+++ b/clang/lib/AST/CMakeLists.txt
@@ -100,6 +100,7 @@
   StmtOpenMP.cpp
   StmtPrinter.cpp
   StmtProfile.cpp
+  StmtTransform.cpp
   StmtViz.cpp
   TemplateBase.cpp
   TemplateName.cpp
diff --git a/clang/lib/AST/StmtTransform.cpp b/clang/lib/AST/StmtTransform.cpp
new file mode 100644
--- /dev/null
+++ b/clang/lib/AST/StmtTransform.cpp
@@ -0,0 +1,69 @@
+//===--- StmtTransform.cpp - Code transformation AST nodes ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transformation directive statement and clauses for the AST.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/StmtTransform.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Stmt.h"
+#include "clang/AST/StmtOpenMP.h"
+
+using namespace clang;
+
+bool TransformClause::isValidForTransform(Transform::Kind TransformKind,
+                                          TransformClause::Kind ClauseKind) {
+  switch (TransformKind) {
+  case clang::Transform::LoopUnrollKind:
+    return ClauseKind == PartialKind || ClauseKind == FullKind;
+  case clang::Transform::LoopUnrollAndJamKind:
+    return ClauseKind == PartialKind;
+  case clang::Transform::LoopVectorizationKind:
+    return ClauseKind == WidthKind;
+  case clang::Transform::LoopInterleavingKind:
+    return ClauseKind == FactorKind;
+  default:
+    return false;
+  }
+}
+
+TransformClause::Kind
+TransformClause::getClauseKind(Transform::Kind TransformKind,
+                               llvm::StringRef Str) {
+#define TRANSFORM_CLAUSE(Keyword, Name) \
+  if (isValidForTransform(TransformKind, TransformClause::Kind::Name##Kind) && \
+      Str == #Keyword) \
+    return TransformClause::Kind::Name##Kind;
+#include "clang/AST/TransformClauseKinds.def"
+  return TransformClause::UnknownKind;
+}
+
+llvm::StringRef
+TransformClause::getClauseKeyword(TransformClause::Kind ClauseKind) {
+  assert(ClauseKind > UnknownKind);
+  assert(ClauseKind <= LastKind);
+  static const char *const ClauseKeyword[LastKind] = {
+#define TRANSFORM_CLAUSE(Keyword, Name) #Keyword,
+#include "clang/AST/TransformClauseKinds.def"
+  };
+  // The table has no entry for UnknownKind, hence the offset of one.
+  return ClauseKeyword[ClauseKind - 1];
+}
+
+const Stmt *clang::getAssociatedLoop(const Stmt *S) {
+  switch (S->getStmtClass()) {
+  case Stmt::ForStmtClass:
+  case Stmt::WhileStmtClass:
+  case Stmt::DoStmtClass:
+  case Stmt::CXXForRangeStmtClass:
+    return S;
+  default:
+    return nullptr;
+  }
+}
diff --git a/clang/lib/Basic/CMakeLists.txt b/clang/lib/Basic/CMakeLists.txt
--- a/clang/lib/Basic/CMakeLists.txt
+++ b/clang/lib/Basic/CMakeLists.txt
@@ -87,6 +87,7 @@
   Targets/X86.cpp
   Targets/XCore.cpp
   TokenKinds.cpp
+  Transform.cpp
   Version.cpp
   Warnings.cpp
   XRayInstr.cpp
diff --git a/clang/lib/Basic/Transform.cpp b/clang/lib/Basic/Transform.cpp
new file mode 100644
--- /dev/null
+++ b/clang/lib/Basic/Transform.cpp
@@ -0,0 +1,63 @@
+//===--- Transform.cpp - Code transformation classes ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines classes used for code transformations such as
+// #pragma clang transform ...
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Basic/Transform.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace clang;
+
+Transform::Kind Transform::getTransformDirectiveKind(llvm::StringRef Str) {
+  return llvm::StringSwitch<Transform::Kind>(Str)
+      .Case("unroll", LoopUnrollKind)
+      .Case("unrollandjam", LoopUnrollAndJamKind)
+      .Case("vectorize", LoopVectorizationKind)
+      .Case("interleave", LoopInterleavingKind)
+      .Case("distribute", LoopDistributionKind)
+      .Default(UnknownKind);
+}
+
+llvm::StringRef Transform::getTransformDirectiveKeyword(Kind K) {
+  switch (K) {
+  case UnknownKind:
+    break;
+  case LoopUnrollKind:
+    return "unroll";
+  case LoopUnrollAndJamKind:
+    return "unrollandjam";
+  case LoopVectorizationKind:
+    return "vectorize";
+  case LoopInterleavingKind:
+    return "interleave";
+  case LoopDistributionKind:
+    return "distribute";
+  }
+  llvm_unreachable("Not a known transformation");
+}
+
+int Transform::getLoopPipelineStage() const {
+  switch (getKind()) {
+  case Transform::Kind::LoopUnrollKind:
+    return cast<LoopUnrollTransform>(this)->isFull() ? 0 : 4;
+  case Transform::Kind::LoopDistributionKind:
+    return 1;
+  case Transform::Kind::LoopInterleavingKind:
+  case Transform::Kind::LoopVectorizationKind:
+    return 2;
+  case Transform::Kind::LoopUnrollAndJamKind:
+    return 3;
+  default:
+    return -1;
+  }
+}
diff --git a/clang/lib/Parse/CMakeLists.txt b/clang/lib/Parse/CMakeLists.txt
--- a/clang/lib/Parse/CMakeLists.txt
+++ b/clang/lib/Parse/CMakeLists.txt
@@ -20,6 +20,7 @@
   ParseStmtAsm.cpp
   ParseTemplate.cpp
   ParseTentative.cpp
+  ParseTransform.cpp
   Parser.cpp
 
   LINK_LIBS
diff --git a/clang/lib/Parse/ParseStmt.cpp b/clang/lib/Parse/ParseStmt.cpp
--- a/clang/lib/Parse/ParseStmt.cpp
+++ b/clang/lib/Parse/ParseStmt.cpp
@@ -14,6 +14,7 @@
 #include "clang/AST/PrettyDeclStackTrace.h"
 #include "clang/Basic/Attributes.h"
 #include "clang/Basic/PrettyStackTrace.h"
+#include "clang/Basic/Transform.h"
 #include "clang/Parse/LoopHint.h"
 #include "clang/Parse/Parser.h"
 #include "clang/Parse/RAIIObjectsForParser.h"
@@ -400,6 +401,10 @@
     ProhibitAttributes(Attrs);
     return ParsePragmaLoopHint(Stmts, StmtCtx, TrailingElseLoc, Attrs);
 
+  case tok::annot_pragma_transform:
+    ProhibitAttributes(Attrs);
+    return ParsePragmaTransform(StmtCtx);
+
   case tok::annot_pragma_dump:
     HandlePragmaDump();
     return StmtEmpty();
diff --git a/clang/lib/Parse/ParseTransform.cpp b/clang/lib/Parse/ParseTransform.cpp
new file mode 100644
--- /dev/null
+++ b/clang/lib/Parse/ParseTransform.cpp
@@ -0,0 +1,145 @@
+//===--- ParseTransform.cpp - Parse #pragma clang transform ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Parse #pragma clang transform ...
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/StmtTransform.h"
+#include "clang/Parse/Parser.h"
+#include "clang/Parse/RAIIObjectsForParser.h"
+
+using namespace clang;
+
+Transform::Kind
+Parser::tryParsePragmaTransform(SourceLocation BeginLoc,
+                                ParsedStmtContext StmtCtx,
+                                SmallVectorImpl<TransformClause *> &Clauses) {
+  // ... Tok=<transform> | <...> tok::annot_pragma_transform_end ...
+  if (Tok.isNot(tok::identifier)) {
+    Diag(Tok, diag::err_pragma_transform_expected_directive);
+    return Transform::UnknownKind;
+  }
+  std::string DirectiveStr = PP.getSpelling(Tok);
+  Transform::Kind DirectiveKind =
+      Transform::getTransformDirectiveKind(DirectiveStr);
+  SourceLocation DirectiveLoc = ConsumeToken();
+
+  switch (DirectiveKind) {
+  case Transform::LoopUnrollKind:
+  case Transform::LoopUnrollAndJamKind:
+  case Transform::LoopDistributionKind:
+  case Transform::LoopVectorizationKind:
+  case Transform::LoopInterleavingKind:
+    break;
+  default:
+    Diag(DirectiveLoc, diag::err_pragma_transform_unknown_directive);
+    return Transform::UnknownKind;
+  }
+
+  // Parse clauses up to the end of the pragma. A clause that is valid but for
+  // which Sema did not create a node is dropped instead of ending the list.
+  while (Tok.isNot(tok::annot_pragma_transform_end)) {
+    TransformClauseResult Clause = ParseTransformClause(DirectiveKind);
+    if (Clause.isInvalid())
+      return Transform::UnknownKind;
+    if (Clause.isUsable())
+      Clauses.push_back(Clause.get());
+  }
+
+  assert(Tok.is(tok::annot_pragma_transform_end));
+  return DirectiveKind;
+}
+
+StmtResult Parser::ParsePragmaTransform(ParsedStmtContext StmtCtx) {
+  assert(Tok.is(tok::annot_pragma_transform) && "Not a transform directive!");
+
+  // ... Tok=annot_pragma_transform | <...> annot_pragma_transform_end
+  // ...
+  SourceLocation BeginLoc = ConsumeAnnotationToken();
+
+  ParenBraceBracketBalancer BalancerRAIIObj(*this);
+
+  SmallVector<TransformClause *, 8> DirectiveClauses;
+  Transform::Kind DirectiveKind =
+      tryParsePragmaTransform(BeginLoc, StmtCtx, DirectiveClauses);
+  if (DirectiveKind == Transform::UnknownKind) {
+    SkipUntil(tok::annot_pragma_transform_end);
+    return StmtError();
+  }
+
+  assert(Tok.is(tok::annot_pragma_transform_end));
+  SourceLocation EndLoc = ConsumeAnnotationToken();
+
+  SourceLocation PreStmtLoc = Tok.getLocation();
+  StmtResult AssociatedStmt = ParseStatement();
+  if (AssociatedStmt.isInvalid())
+    return AssociatedStmt;
+  if (!getAssociatedLoop(AssociatedStmt.get()))
+    return StmtError(
+        Diag(PreStmtLoc, diag::err_pragma_transform_expected_loop));
+
+  return Actions.ActOnLoopTransformDirective(DirectiveKind, DirectiveClauses,
+                                             AssociatedStmt.get(),
+                                             {BeginLoc, EndLoc});
+}
+
+Parser::TransformClauseResult
+Parser::ParseTransformClause(Transform::Kind TransformKind) {
+  // No more clauses
+  if (Tok.is(tok::annot_pragma_transform_end))
+    return ClauseEmpty();
+
+  SourceLocation StartLoc = Tok.getLocation();
+  if (Tok.isNot(tok::identifier))
+    return ClauseError(Diag(Tok, diag::err_pragma_transform_expected_clause));
+  std::string ClauseKeyword = PP.getSpelling(Tok);
+  ConsumeToken();
+  TransformClause::Kind Kind =
+      TransformClause::getClauseKind(TransformKind, ClauseKeyword);
+  switch (Kind) {
+  case TransformClause::UnknownKind:
+    return ClauseError(
+        Diag(StartLoc, diag::err_pragma_transform_unknown_clause));
+
+  // Clauses without arguments.
+  case TransformClause::FullKind:
+    return Actions.ActOnFullClause(SourceRange{StartLoc, StartLoc});
+
+  // Clauses with integer argument.
+  case TransformClause::PartialKind:
+  case TransformClause::WidthKind:
+  case TransformClause::FactorKind: {
+    BalancedDelimiterTracker T(*this, tok::l_paren,
+                               tok::annot_pragma_transform_end);
+    if (T.expectAndConsume(diag::err_expected_lparen_after,
+                           ClauseKeyword.data()))
+      return ClauseError();
+
+    ExprResult Expr = ParseConstantExpression();
+    if (Expr.isInvalid())
+      return ClauseError();
+
+    if (T.consumeClose())
+      return ClauseError();
+    SourceLocation EndLoc = T.getCloseLocation();
+    SourceRange Range{StartLoc, EndLoc};
+    switch (Kind) {
+    case TransformClause::PartialKind:
+      return Actions.ActOnPartialClause(Range, Expr.get());
+    case TransformClause::WidthKind:
+      return Actions.ActOnWidthClause(Range, Expr.get());
+    case TransformClause::FactorKind:
+      return Actions.ActOnFactorClause(Range, Expr.get());
+    default:
+      llvm_unreachable("Unhandled clause");
+    }
+  }
+  }
+  llvm_unreachable("Unhandled clause");
+}
diff --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt
--- a/clang/lib/Sema/CMakeLists.txt
+++ b/clang/lib/Sema/CMakeLists.txt
@@ -63,6 +63,7 @@
   SemaTemplateInstantiate.cpp
   SemaTemplateInstantiateDecl.cpp
   SemaTemplateVariadic.cpp
+  SemaTransform.cpp
   SemaType.cpp
   TypeLocBuilder.cpp
 
diff --git a/clang/lib/Sema/SemaTransform.cpp b/clang/lib/Sema/SemaTransform.cpp
new file mode 100644
--- /dev/null
+++ b/clang/lib/Sema/SemaTransform.cpp
@@ -0,0 +1,49 @@
+//===--- SemaTransform.cpp - Semantic analysis for code transformations ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Semantic analysis for code transformations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/AST/StmtTransform.h"
+#include "clang/Basic/Transform.h"
+#include "clang/Sema/Sema.h"
+#include "clang/Sema/SemaDiagnostic.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringMap.h"
+
+using namespace clang;
+
+StmtResult
+Sema::ActOnLoopTransformDirective(Transform::Kind Kind,
+                                  llvm::ArrayRef<TransformClause *> Clauses,
+                                  Stmt *AStmt, SourceRange Loc) {
+  // TODO: implement
+  return StmtError();
+}
+
+TransformClause *Sema::ActOnFullClause(SourceRange Loc) {
+  // TODO: implement
+  return nullptr;
+}
+
+TransformClause *Sema::ActOnPartialClause(SourceRange Loc, Expr *Factor) {
+  // TODO: implement
+  return nullptr;
+}
+
+TransformClause *Sema::ActOnWidthClause(SourceRange Loc, Expr *Width) {
+  // TODO: implement
+  return nullptr;
+}
+
+TransformClause *Sema::ActOnFactorClause(SourceRange Loc, Expr *Factor) {
+  // TODO: implement
+  return nullptr;
+}
diff --git a/clang/test/Parser/pragma-transform.cpp b/clang/test/Parser/pragma-transform.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/Parser/pragma-transform.cpp
@@ -0,0 +1,92 @@
+// RUN: %clang_cc1 -std=c++11 -fexperimental-transform-pragma -verify %s
+
+void pragma_transform(int *List, int Length) {
+// FIXME: This does not emit an error
+#pragma clang
+
+/* expected-error@+1 {{expected a transformation name}} */
+#pragma clang transform
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+
+/* expected-error@+1 {{unknown transformation}} */
+#pragma clang transform unknown_transformation
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+
+/* expected-error@+2 {{expected loop after transformation pragma}} */
+#pragma clang transform unroll
+  pragma_transform(List, Length);
+
+/* expected-error@+1 {{unknown clause name}} */
+#pragma clang transform unroll unknown_clause
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+
+/* expected-error@+1 {{expected '(' after 'partial'}} */
+#pragma clang transform unroll partial
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+
+/* expected-error@+1 {{expected expression}} */
+#pragma clang transform unroll partial(
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+
+/* expected-error@+1 {{expected '(' after 'partial'}} */
+#pragma clang transform unroll partial)
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+
+/* expected-error@+2 {{expected ')'}} */
+/* expected-note@+1 {{to match this '('}} */
+#pragma clang transform unroll partial(4
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+
+/* expected-error@+1 {{expected expression}} */
+#pragma clang transform unroll partial()
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+
+/* expected-error@+1 {{use of undeclared identifier 'badvalue'}} */
+#pragma clang transform unroll partial(badvalue)
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+
+  {
+/* expected-error@+2 {{expected statement}} */
+#pragma clang transform unroll
+  }
+}
+
+/* expected-error@+1 {{expected unqualified-id}} */
+#pragma clang transform unroll
+int I;
+
+/* expected-error@+1 {{expected unqualified-id}} */
+#pragma clang transform unroll
+void func();
+
+class C1 {
+/* expected-error@+3 {{this pragma cannot appear in class declaration}} */
+/* expected-error@+2 {{expected member name or ';' after declaration specifiers}} */
+/* expected-error@+1 {{unknown type name 'unroll'}} */
+#pragma clang transform unroll
+};
+
+template <int F>
+void pragma_transform_template_func(int *List, int Length) {
+#pragma clang transform unroll partial(F)
+  for (int i = 0; i < Length; i+=1)
+    List[i] = i;
+}
+
+template <int F>
+class C2 {
+  void pragma_transform_template_class(int *List, int Length) {
+#pragma clang transform unroll partial(F)
+    for (int i = 0; i < Length; i+=1)
+      List[i] = i;
+  }
+};