diff --git a/mlir/include/mlir/Tools/PDLL/AST/Context.h b/mlir/include/mlir/Tools/PDLL/AST/Context.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/AST/Context.h @@ -0,0 +1,52 @@ +//===- Context.h - PDLL AST Context -----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_AST_CONTEXT_H_ +#define MLIR_TOOLS_PDLL_AST_CONTEXT_H_ + +#include "mlir/Support/StorageUniquer.h" +#include "mlir/Tools/PDLL/AST/Diagnostic.h" + +namespace mlir { +namespace pdll { +namespace ast { +/// This class represents the main context of the PDLL AST. It handles +/// allocating all of the AST constructs, and manages all state necessary for +/// the AST. +class Context { +public: + Context(); + Context(const Context &) = delete; + Context &operator=(const Context &) = delete; + + /// Return the allocator owned by this context. + llvm::BumpPtrAllocator &getAllocator() { return allocator; } + + /// Return the storage uniquer used for AST types. + StorageUniquer &getTypeUniquer() { return typeUniquer; } + + /// Return the diagnostic engine of this context. + DiagnosticEngine &getDiagEngine() { return diagEngine; } + +private: + /// The diagnostic engine of this AST context. + DiagnosticEngine diagEngine; + + /// The allocator used for AST nodes, and other entities allocated within the + /// context. + llvm::BumpPtrAllocator allocator; + + /// The uniquer used for creating AST types.
+ StorageUniquer typeUniquer; +}; + +} // namespace ast +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_AST_CONTEXT_H_ diff --git a/mlir/include/mlir/Tools/PDLL/AST/Diagnostic.h b/mlir/include/mlir/Tools/PDLL/AST/Diagnostic.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/AST/Diagnostic.h @@ -0,0 +1,182 @@ +//===- Diagnostic.h - PDLL AST Diagnostics ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_ +#define MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_ + +#include + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/FunctionExtras.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir { +namespace pdll { +namespace ast { +class DiagnosticEngine; + +//===----------------------------------------------------------------------===// +// Diagnostic +//===----------------------------------------------------------------------===// + +/// This class provides a simple implementation of a PDLL diagnostic. +class Diagnostic { +public: + using Severity = llvm::SourceMgr::DiagKind; + + /// Return the severity of this diagnostic. + Severity getSeverity() const { return severity; } + + /// Return the message of this diagnostic. + StringRef getMessage() const { return message; } + + /// Return the location of this diagnostic. + llvm::SMRange getLocation() const { return location; } + + /// Return the notes of this diagnostic. + auto getNotes() const { return llvm::make_pointee_range(notes); } + + /// Attach a note to this diagnostic, defaulting to this diagnostic's location.
+ Diagnostic &attachNote(const Twine &msg, + Optional noteLoc = llvm::None) { + assert(getSeverity() != Severity::DK_Note && + "cannot attach a Note to a Note"); + notes.emplace_back( + new Diagnostic(Severity::DK_Note, noteLoc.getValueOr(location), msg)); + return *notes.back(); + } + + /// Allow a diagnostic to be converted to 'failure'. The conversion is + /// unconditional: it yields 'failure' regardless of severity. + operator LogicalResult() const { return failure(); } + +private: + Diagnostic(Severity severity, llvm::SMRange loc, const Twine &msg) + : severity(severity), message(msg.str()), location(loc) {} + + // Allow access to the constructor. + friend DiagnosticEngine; + + /// The severity of this diagnostic. + Severity severity; + /// The message held by this diagnostic. + std::string message; + /// The raw location of this diagnostic. + llvm::SMRange location; + /// Any additional note diagnostics attached to this diagnostic. + std::vector> notes; +}; + +//===----------------------------------------------------------------------===// +// InFlightDiagnostic +//===----------------------------------------------------------------------===// + +/// This class represents a diagnostic that is inflight and set to be reported. +/// This allows for last minute modifications of the diagnostic before it is +/// emitted by a DiagnosticEngine. +class InFlightDiagnostic { +public: + InFlightDiagnostic() = default; + InFlightDiagnostic(InFlightDiagnostic &&rhs) + : owner(rhs.owner), impl(std::move(rhs.impl)) { + // Reset the rhs diagnostic. + rhs.impl.reset(); + rhs.abandon(); + } + ~InFlightDiagnostic() { + if (isInFlight()) + report(); + } + + /// Access the internal diagnostic. + Diagnostic &operator*() { return *impl; } + Diagnostic *operator->() { return &*impl; } + + /// Reports the diagnostic to the engine. + void report(); + + /// Abandons this diagnostic so that it will no longer be reported.
+ void abandon() { owner = nullptr; } + + /// Allow an inflight diagnostic to be converted to 'failure', otherwise + /// 'success' if this is an empty diagnostic. + operator LogicalResult() const { return failure(isActive()); } + +private: + InFlightDiagnostic &operator=(const InFlightDiagnostic &) = delete; + InFlightDiagnostic &operator=(InFlightDiagnostic &&) = delete; + InFlightDiagnostic(DiagnosticEngine *owner, Diagnostic &&rhs) + : owner(owner), impl(std::move(rhs)) {} + + /// Returns true if the diagnostic is still active, i.e. it has a live + /// diagnostic. + bool isActive() const { return impl.hasValue(); } + + /// Returns true if the diagnostic is still in flight to be reported. + bool isInFlight() const { return owner; } + + // Allow access to the constructor. + friend DiagnosticEngine; + + /// The engine that this diagnostic is to report to. + DiagnosticEngine *owner = nullptr; + + /// The raw diagnostic that is inflight to be reported. + Optional impl; +}; + +//===----------------------------------------------------------------------===// +// DiagnosticEngine +//===----------------------------------------------------------------------===// + +/// This class manages the construction and emission of PDLL diagnostics. +class DiagnosticEngine { +public: + /// A function used to handle diagnostics emitted by the engine. + using HandlerFn = llvm::unique_function; + + /// Emit an error (or, via emitWarning, a warning) to the diagnostic engine. + InFlightDiagnostic emitError(llvm::SMRange loc, const Twine &msg) { + return InFlightDiagnostic( + this, Diagnostic(Diagnostic::Severity::DK_Error, loc, msg)); + } + InFlightDiagnostic emitWarning(llvm::SMRange loc, const Twine &msg) { + return InFlightDiagnostic( + this, Diagnostic(Diagnostic::Severity::DK_Warning, loc, msg)); + } + + /// Report the given diagnostic to the current handler, if one is set. + void report(Diagnostic &&diagnostic) { + if (handler) + handler(diagnostic); + } + + /// Get the current handler function of this diagnostic engine.
+ const HandlerFn &getHandlerFn() const { return handler; } + + /// Take the current handler function, resetting the current handler to null. + HandlerFn takeHandlerFn() { + HandlerFn oldHandler = std::move(handler); + handler = {}; + return oldHandler; + } + + /// Set the handler function for this diagnostic engine. + void setHandlerFn(HandlerFn &&newHandler) { handler = std::move(newHandler); } + +private: + /// The registered diagnostic handler function. + HandlerFn handler; +}; + +} // namespace ast +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_AST_DIAGNOSTICS_H_ diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h @@ -0,0 +1,681 @@ +//===- Nodes.h --------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_AST_NODES_H_ +#define MLIR_TOOLS_PDLL_AST_NODES_H_ + +#include "mlir/Support/LLVM.h" +#include "mlir/Tools/PDLL/AST/Types.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/SMLoc.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TrailingObjects.h" + +namespace mlir { +namespace pdll { +namespace ast { +class Context; +class Decl; +class Expr; +class OpNameDecl; +class VariableDecl; + +//===----------------------------------------------------------------------===// +// Name +//===----------------------------------------------------------------------===// + +/// This class provides a convenient API for interacting with source names. It +/// contains a string name as well as the source location for that name.
+struct Name { + Name() = default; + Name(std::string &&name, llvm::SMRange location) + : name(std::move(name)), location(location) {} + Name(StringRef name, llvm::SMRange location) + : name(name.str()), location(location) {} + + /// The string form of this name. + std::string name; + /// The source location of this name. + llvm::SMRange location; +}; + +//===----------------------------------------------------------------------===// +// DeclScope +//===----------------------------------------------------------------------===// + +/// This class represents a scope for named AST decls. A scope determines the +/// visibility and lifetime of a named declaration. +class DeclScope { +public: + /// Create a new scope with an optional parent scope. + DeclScope(DeclScope *parent = nullptr) : parent(parent) {} + + /// Return the parent scope of this scope, or nullptr if there is no parent. + DeclScope *getParentScope() { return parent; } + const DeclScope *getParentScope() const { return parent; } + + /// Return the decls defined directly in this scope, excluding parent scopes. + auto getDecls() const { return llvm::make_second_range(decls); } + + /// Add a new decl to the scope. + void add(Decl *decl); + + /// Lookup a decl with the given name starting from this scope. Returns + /// nullptr if no decl could be found. + Decl *lookup(StringRef name); + template + T *lookup(StringRef name) { + return dyn_cast_or_null(lookup(name)); + } + const Decl *lookup(StringRef name) const { + return const_cast(this)->lookup(name); + } + template + const T *lookup(StringRef name) const { + return dyn_cast_or_null(lookup(name)); + } + +private: + /// The parent scope, or null if this is a top-level scope. + DeclScope *parent; + /// The decls defined within this scope. + llvm::StringMap decls; +}; + +//===----------------------------------------------------------------------===// +// Node +//===----------------------------------------------------------------------===// + +/// This class represents a base AST node.
All AST nodes are derived from this +/// class, and it contains much of the base functionality for interacting with +/// nodes. +class Node { +public: + /// This CRTP class provides several utilities when defining new AST nodes. + template + class NodeBase : public BaseT { + public: + using Base = NodeBase; + + /// Provide type casting support. + static bool classof(const Node *node) { + return node->getTypeID() == TypeID::get(); + } + + protected: + template + explicit NodeBase(llvm::SMRange loc, Args &&...args) + : BaseT(TypeID::get(), loc, std::forward(args)...) {} + }; + + /// Return the type identifier of this node. + TypeID getTypeID() const { return typeID; } + + /// Return the location of this node. + llvm::SMRange getLoc() const { return loc; } + + /// Print this node to the given stream. + void print(raw_ostream &os) const; + +protected: + Node(TypeID typeID, llvm::SMRange loc) : typeID(typeID), loc(loc) {} + +private: + /// A unique type identifier for this node. + TypeID typeID; + + /// The location of this node. + llvm::SMRange loc; +}; + +//===----------------------------------------------------------------------===// +// Stmt +//===----------------------------------------------------------------------===// + +/// This class represents a base AST Statement node. +class Stmt : public Node { +public: + using Node::Node; + + /// Provide type casting support. + static bool classof(const Node *node); +}; + +//===----------------------------------------------------------------------===// +// CompoundStmt +//===----------------------------------------------------------------------===// + +/// This statement represents a compound statement, which contains a collection +/// of other statements. +class CompoundStmt final : public Node::NodeBase, + private llvm::TrailingObjects { +public: + static CompoundStmt *create(Context &ctx, llvm::SMRange location, + ArrayRef children); + + /// Return the children of this compound statement.
+ MutableArrayRef getChildren() { + return llvm::makeMutableArrayRef(getTrailingObjects(), numChildren); + } + ArrayRef getChildren() const { + return const_cast(this)->getChildren(); + } + ArrayRef::iterator begin() const { return getChildren().begin(); } + ArrayRef::iterator end() const { return getChildren().end(); } + +private: + CompoundStmt(llvm::SMRange location, unsigned numChildren) + : Base(location), numChildren(numChildren) {} + + /// The number of child statements held as trailing objects. + unsigned numChildren; + + // Allow access to various privates. + friend class llvm::TrailingObjects; +}; + +//===----------------------------------------------------------------------===// +// LetStmt +//===----------------------------------------------------------------------===// + +/// This statement represents a `let` statement in PDLL. This statement is used +/// to define variables. +class LetStmt final : public Node::NodeBase { +public: + static LetStmt *create(Context &ctx, llvm::SMRange loc, + VariableDecl *varDecl); + + /// Return the variable defined by this statement. + VariableDecl *getVarDecl() const { return varDecl; } + +private: + LetStmt(llvm::SMRange loc, VariableDecl *varDecl) + : Base(loc), varDecl(varDecl) {} + + /// The variable defined by this statement. + VariableDecl *varDecl; +}; + +//===----------------------------------------------------------------------===// +// OpRewriteStmt +//===----------------------------------------------------------------------===// + +/// This class represents a base operation rewrite statement. Operation rewrite +/// statements perform a set of transformations on a given root operation. +class OpRewriteStmt : public Stmt { +public: + /// Provide type casting support. + static bool classof(const Node *node); + + /// Return the root operation of this rewrite.
+ Expr *getRootOpExpr() const { return rootOp; } + +protected: + OpRewriteStmt(TypeID typeID, llvm::SMRange loc, Expr *rootOp) + : Stmt(typeID, loc), rootOp(rootOp) {} + +protected: + /// The root operation being rewritten. + Expr *rootOp; +}; + +//===----------------------------------------------------------------------===// +// EraseStmt + +/// This statement represents the `erase` statement in PDLL. This statement +/// erases the given root operation, corresponding roughly to the +/// PatternRewriter::eraseOp API. +class EraseStmt final : public Node::NodeBase { +public: + static EraseStmt *create(Context &ctx, llvm::SMRange loc, Expr *rootOp); + +private: + EraseStmt(llvm::SMRange loc, Expr *rootOp) : Base(loc, rootOp) {} +}; + +//===----------------------------------------------------------------------===// +// Expr +//===----------------------------------------------------------------------===// + +/// This class represents a base AST Expression node. +class Expr : public Stmt { +public: + /// Return the type of this expression. + Type getType() const { return type; } + + /// Provide type casting support. + static bool classof(const Node *node); + +protected: + Expr(TypeID typeID, llvm::SMRange loc, Type type) + : Stmt(typeID, loc), type(type) {} + +private: + /// The type of this expression. + Type type; +}; + +//===----------------------------------------------------------------------===// +// DeclRefExpr +//===----------------------------------------------------------------------===// + +/// This expression represents a reference to a Decl node. +class DeclRefExpr : public Node::NodeBase { +public: + static DeclRefExpr *create(Context &ctx, llvm::SMRange loc, Decl *decl, + Type type); + + /// Get the decl referenced by this expression. + Decl *getDecl() const { return decl; } + +private: + DeclRefExpr(llvm::SMRange loc, Decl *decl, Type type) + : Base(loc, type), decl(decl) {} + + /// The decl referenced by this expression.
+ Decl *decl; +}; + +//===----------------------------------------------------------------------===// +// MemberAccessExpr +//===----------------------------------------------------------------------===// + +/// This expression represents a named member or field access of a given parent +/// expression. +class MemberAccessExpr : public Node::NodeBase { +public: + static MemberAccessExpr *create(Context &ctx, llvm::SMRange loc, + const Expr *parentExpr, StringRef memberName, + Type type); + + /// Get the parent expression of this access. + const Expr *getParentExpr() const { return parentExpr; } + + /// Return the name of the member being accessed. + StringRef getMemberName() const { return memberName; } + +private: + MemberAccessExpr(llvm::SMRange loc, const Expr *parentExpr, + StringRef memberName, Type type) + : Base(loc, type), parentExpr(parentExpr), memberName(memberName) {} + + /// The parent expression of this access. + const Expr *parentExpr; + + /// The name of the member being accessed from the parent (non-owning). + StringRef memberName; +}; + +//===----------------------------------------------------------------------===// +// Decl +//===----------------------------------------------------------------------===// + +/// This class represents the base Decl node. +class Decl : public Node { +public: + /// Return the name of the decl, or None if it doesn't have one. + const Optional &getName() const { return name; } + + /// Provide type casting support. + static bool classof(const Node *node); + +protected: + Decl(TypeID typeID, llvm::SMRange loc, Optional name = llvm::None) + : Node(typeID, loc), name(name) {} + +private: + /// The name of the decl. This is optional for some decls, such as + /// PatternDecl.
+ Optional name; +}; + +//===----------------------------------------------------------------------===// +// ConstraintDecl +//===----------------------------------------------------------------------===// + +/// This class represents the base of all AST Constraint decls. Constraints +/// apply matcher conditions to, and define the type of, PDLL variables. +class ConstraintDecl : public Decl { +public: + /// Provide type casting support. + static bool classof(const Node *node); + +protected: + ConstraintDecl(TypeID typeID, llvm::SMRange loc, + Optional name = llvm::None) + : Decl(typeID, loc, name) {} +}; + +/// This class represents a reference to a constraint, and contains a constraint +/// and the location of the reference. +struct ConstraintRef { + ConstraintRef(const ConstraintDecl *constraint, llvm::SMRange refLoc) + : constraint(constraint), referenceLoc(refLoc) {} + explicit ConstraintRef(const ConstraintDecl *constraint) + : ConstraintRef(constraint, constraint->getLoc()) {} + + const ConstraintDecl *constraint; + llvm::SMRange referenceLoc; +}; + +//===----------------------------------------------------------------------===// +// CoreConstraintDecl +//===----------------------------------------------------------------------===// + +/// This class represents the base of all "core" constraints. Core constraints +/// are those that generally represent a concrete IR construct, such as +/// `Type`s or `Value`s. +class CoreConstraintDecl : public ConstraintDecl { +public: + /// Provide type casting support. + static bool classof(const Node *node); + +protected: + CoreConstraintDecl(TypeID typeID, llvm::SMRange loc, + Optional name = llvm::None) + : ConstraintDecl(typeID, loc, name) {} +}; + +//===----------------------------------------------------------------------===// +// AttrConstraintDecl + +/// This class represents an Attribute constraint, and constrains a variable to +/// be an Attribute.
+class AttrConstraintDecl + : public Node::NodeBase { +public: + static AttrConstraintDecl *create(Context &ctx, llvm::SMRange loc, + Expr *typeExpr = nullptr); + + /// Return the optional type the attribute is constrained to, or nullptr. + Expr *getTypeExpr() { return typeExpr; } + const Expr *getTypeExpr() const { return typeExpr; } + +protected: + AttrConstraintDecl(llvm::SMRange loc, Expr *typeExpr) + : Base(loc), typeExpr(typeExpr) {} + + /// An optional type that the attribute is constrained to. + Expr *typeExpr; +}; + +//===----------------------------------------------------------------------===// +// OpConstraintDecl + +/// This class represents an Operation constraint, and constrains a variable to +/// be an Operation. +class OpConstraintDecl + : public Node::NodeBase { +public: + static OpConstraintDecl *create(Context &ctx, llvm::SMRange loc, + const OpNameDecl *nameDecl = nullptr); + + /// Return the name of the operation, or None if there isn't one. + Optional getName() const; + + /// Return the declaration of the operation name, or nullptr if there is none. + const OpNameDecl *getNameDecl() const { return nameDecl; } + +protected: + explicit OpConstraintDecl(llvm::SMRange loc, const OpNameDecl *nameDecl) + : Base(loc), nameDecl(nameDecl) {} + + /// The operation name of this constraint. + const OpNameDecl *nameDecl; +}; + +//===----------------------------------------------------------------------===// +// TypeConstraintDecl + +/// This class represents a Type constraint, and constrains a variable to be a +/// Type. +class TypeConstraintDecl + : public Node::NodeBase { +public: + static TypeConstraintDecl *create(Context &ctx, llvm::SMRange loc); + +protected: + using Base::Base; +}; + +//===----------------------------------------------------------------------===// +// TypeRangeConstraintDecl + +/// This class represents a TypeRange constraint, and constrains a variable to be +/// a TypeRange.
+class TypeRangeConstraintDecl + : public Node::NodeBase { +public: + static TypeRangeConstraintDecl *create(Context &ctx, llvm::SMRange loc); + +protected: + using Base::Base; +}; + +//===----------------------------------------------------------------------===// +// ValueConstraintDecl + +/// This class represents a Value constraint, and constrains a variable to be a +/// Value. +class ValueConstraintDecl + : public Node::NodeBase { +public: + static ValueConstraintDecl *create(Context &ctx, llvm::SMRange loc, + Expr *typeExpr); + + /// Return the optional type the value is constrained to, or nullptr. + Expr *getTypeExpr() { return typeExpr; } + const Expr *getTypeExpr() const { return typeExpr; } + +protected: + ValueConstraintDecl(llvm::SMRange loc, Expr *typeExpr) + : Base(loc), typeExpr(typeExpr) {} + + /// An optional type that the value is constrained to. + Expr *typeExpr; +}; + +//===----------------------------------------------------------------------===// +// ValueRangeConstraintDecl + +/// This class represents a ValueRange constraint, and constrains a variable to +/// be a ValueRange. +class ValueRangeConstraintDecl + : public Node::NodeBase { +public: + static ValueRangeConstraintDecl *create(Context &ctx, llvm::SMRange loc, + Expr *typeExpr); + + /// Return the optional type the value range is constrained to, or nullptr. + Expr *getTypeExpr() { return typeExpr; } + const Expr *getTypeExpr() const { return typeExpr; } + +protected: + ValueRangeConstraintDecl(llvm::SMRange loc, Expr *typeExpr) + : Base(loc), typeExpr(typeExpr) {} + + /// An optional type that the value range is constrained to. + Expr *typeExpr; +}; + +//===----------------------------------------------------------------------===// +// OpNameDecl +//===----------------------------------------------------------------------===// + +/// This Decl represents an OperationName.
+class OpNameDecl : public Node::NodeBase { +public: + static OpNameDecl *create(Context &ctx, const Name &name); + static OpNameDecl *create(Context &ctx, llvm::SMRange loc); + + /// Return the name of this operation, or None if the name is unknown. + Optional getName() const { + const Optional &name = Decl::getName(); + return name ? Optional(name->name) : llvm::None; + } + +private: + explicit OpNameDecl(Name name) : Base(name.location, name) {} + explicit OpNameDecl(llvm::SMRange loc) : Base(loc) {} +}; + +//===----------------------------------------------------------------------===// +// PatternDecl +//===----------------------------------------------------------------------===// + +/// This Decl represents a single Pattern. +class PatternDecl : public Node::NodeBase { +public: + static PatternDecl *create(Context &ctx, llvm::SMRange location, + Optional name, Optional benefit, + bool hasBoundedRecursion, + const CompoundStmt *body); + + /// Return the benefit of this pattern if specified, or None. + Optional getBenefit() const { return benefit; } + + /// Return true if this pattern has bounded rewrite recursion. + bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; } + + /// Return the body of this pattern. + const CompoundStmt *getBody() const { return patternBody; } + + /// Return the root rewrite statement, i.e. the last statement of the body. + const OpRewriteStmt *getRootRewriteStmt() const { + return cast(patternBody->getChildren().back()); + } + +private: + PatternDecl(llvm::SMRange loc, Optional name, + Optional benefit, bool hasBoundedRecursion, + const CompoundStmt *body) + : Base(loc, name), benefit(benefit), + hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {} + + /// The benefit of the pattern if it was explicitly specified, None otherwise. + Optional benefit; + + /// Whether the pattern has properly bounded rewrite recursion. + bool hasBoundedRecursion; + + /// The compound statement representing the body of the pattern.
+ const CompoundStmt *patternBody; +}; + +//===----------------------------------------------------------------------===// +// VariableDecl +//===----------------------------------------------------------------------===// + +/// This Decl represents the definition of a PDLL variable. +class VariableDecl final + : public Node::NodeBase, + private llvm::TrailingObjects { +public: + static VariableDecl *create(Context &ctx, Name name, Type type, + Expr *initExpr, + ArrayRef constraints); + + /// Return the constraints of this variable. + ArrayRef getConstraints() const { + return {getTrailingObjects(), numConstraints}; + } + + /// Return the initializer expression of this variable, or nullptr if there + /// was no initializer. + Expr *getInitExpr() const { return initExpr; } + + /// Return the name of the variable; a VariableDecl is always named. + const Name &getName() const { return *Decl::getName(); } + + /// Return the type of the variable. + Type getType() const { return type; } + +private: + VariableDecl(Name name, Type type, Expr *initExpr, unsigned numConstraints) + : Base(name.location, name), type(type), initExpr(initExpr), + numConstraints(numConstraints) {} + + /// The type of the variable. + Type type; + + /// The optional initializer expression of this variable. + Expr *initExpr; + + /// The number of constraints attached to this variable. + unsigned numConstraints; + + /// Allow access to various internals. + friend llvm::TrailingObjects; +}; + +//===----------------------------------------------------------------------===// +// Module +//===----------------------------------------------------------------------===// + +/// This class represents a top-level AST module. +class Module final : public Node::NodeBase, + private llvm::TrailingObjects { +public: + static Module *create(Context &ctx, llvm::SMLoc loc, + ArrayRef children); + + /// Return the children of this module.
+ MutableArrayRef getChildren() { + return llvm::makeMutableArrayRef(getTrailingObjects(), numChildren); + } + ArrayRef getChildren() const { + return const_cast(this)->getChildren(); + } + +private: + Module(llvm::SMLoc loc, unsigned numChildren) + : Base(llvm::SMRange{loc, loc}), numChildren(numChildren) {} + + /// The number of decls held by this module. + unsigned numChildren; + + /// Allow access to various internals. + friend llvm::TrailingObjects; +}; + +//===----------------------------------------------------------------------===// +// Deferred Method Definitions +//===----------------------------------------------------------------------===// + +inline bool Decl::classof(const Node *node) { + return isa(node); +} + +inline bool ConstraintDecl::classof(const Node *node) { + return isa(node); +} + +inline bool CoreConstraintDecl::classof(const Node *node) { + return isa(node); +} + +inline bool Expr::classof(const Node *node) { + return isa(node); +} + +inline bool OpRewriteStmt::classof(const Node *node) { + return isa(node); +} + +inline bool Stmt::classof(const Node *node) { + return isa(node); +} + +} // namespace ast +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_AST_NODES_H_ diff --git a/mlir/include/mlir/Tools/PDLL/AST/Types.h b/mlir/include/mlir/Tools/PDLL/AST/Types.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/AST/Types.h @@ -0,0 +1,264 @@ +//===- Types.h --------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_AST_TYPES_H_ +#define MLIR_TOOLS_PDLL_AST_TYPES_H_ + +#include "mlir/Support/LLVM.h" +#include "mlir/Support/StorageUniquer.h" + +namespace mlir { +namespace pdll { +namespace ast { +class Context; + +namespace detail { +struct AttributeTypeStorage; +struct ConstraintTypeStorage; +struct OperationTypeStorage; +struct RangeTypeStorage; +struct TypeTypeStorage; +struct ValueTypeStorage; +} // namespace detail + +//===----------------------------------------------------------------------===// +// Type +//===----------------------------------------------------------------------===// + +class Type { +public: + /// This class represents the internal storage of the Type class. + struct Storage; + + /// This class provides several utilities when defining derived type classes. + template + class TypeBase : public BaseT { + public: + using Base = TypeBase; + using ImplTy = ImplT; + using BaseT::BaseT; + + /// Provide type casting support. + static bool classof(Type type) { + return type.getTypeID() == TypeID::get(); + } + }; + + Type(Storage *impl = nullptr) : impl(impl) {} + Type(const Type &other) = default; + + bool operator==(const Type &other) const { return impl == other.impl; } + bool operator!=(const Type &other) const { return !(*this == other); } + explicit operator bool() const { return impl; } + + /// Provide type casting support. + template + bool isa() const { + assert(impl && "isa<> used on a null type."); + return U::classof(*this); + } + template + bool isa() const { + return isa() || isa(); + } + template + U dyn_cast() const { + return isa() ? U(impl) : U(nullptr); + } + template + U dyn_cast_or_null() const { + return (impl && isa()) ?
U(impl) : U(nullptr); + } + template + U cast() const { + assert(isa()); + return U(impl); + } + + /// Return the internal storage instance of this type. + Storage *getImpl() const { return impl; } + + /// Return the TypeID instance of this type. + TypeID getTypeID() const; + + /// Print this type to the given stream. + void print(raw_ostream &os) const; + + /// Try to refine this type with the one provided. Given two compatible types, + /// this returns a merged type containing as much detail as possible from both. + /// For example, if refining two operation types and one contains a name, + /// while the other doesn't, the refined type contains the name. If the two + /// types are incompatible, null is returned. + Type refineWith(Type other) const; + +protected: + /// Return the internal storage instance of this type reinterpreted as the + /// given derived storage type. + template + const T *getImplAs() const { + return static_cast(impl); + } + +private: + Storage *impl; +}; + +inline llvm::hash_code hash_value(Type type) { + return DenseMapInfo::getHashValue(type.getImpl()); +} + +inline raw_ostream &operator<<(raw_ostream &os, Type type) { + type.print(os); + return os; +} + +//===----------------------------------------------------------------------===// +// AttributeType +//===----------------------------------------------------------------------===// + +/// This class represents a PDLL type that corresponds to an mlir::Attribute. +class AttributeType : public Type::TypeBase { +public: + using Base::Base; + + /// Return an instance of the Attribute type. + static AttributeType get(Context &context); +}; + +//===----------------------------------------------------------------------===// +// ConstraintType +//===----------------------------------------------------------------------===// + +/// This class represents a PDLL type that corresponds to a constraint. This +/// type has no MLIR C++ API correspondence.
+class ConstraintType : public Type::TypeBase<detail::ConstraintTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Return an instance of the Constraint type.
+  static ConstraintType get(Context &context);
+};
+
+//===----------------------------------------------------------------------===//
+// OperationType
+//===----------------------------------------------------------------------===//
+
+/// This class represents a PDLL type that corresponds to an mlir::Operation.
+class OperationType : public Type::TypeBase<detail::OperationTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Return an instance of the Operation type with an optional operation name.
+  /// If no name is provided, this type may refer to any operation.
+  static OperationType get(Context &context,
+                           Optional<StringRef> name = llvm::None);
+
+  /// Return the name of this operation type, or None if it doesn't have one.
+  Optional<StringRef> getName() const;
+};
+
+//===----------------------------------------------------------------------===//
+// RangeType
+//===----------------------------------------------------------------------===//
+
+/// This class represents a PDLL type that corresponds to a range of elements
+/// with a given element type.
+class RangeType : public Type::TypeBase<detail::RangeTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Return an instance of the Range type with the given element type.
+  static RangeType get(Context &context, Type elementType);
+
+  /// Return the element type of this range.
+  Type getElementType() const;
+};
+
+//===----------------------------------------------------------------------===//
+// TypeRangeType
+
+/// This class represents a PDLL type that corresponds to an mlir::TypeRange.
+class TypeRangeType : public RangeType {
+public:
+  using RangeType::RangeType;
+
+  /// Provide type casting support.
+  static bool classof(Type type);
+
+  /// Return an instance of the TypeRange type.
+  static TypeRangeType get(Context &context);
+};
+
+//===----------------------------------------------------------------------===//
+// ValueRangeType
+
+/// This class represents a PDLL type that corresponds to an mlir::ValueRange.
+class ValueRangeType : public RangeType {
+public:
+  using RangeType::RangeType;
+
+  /// Provide type casting support.
+  static bool classof(Type type);
+
+  /// Return an instance of the ValueRange type.
+  static ValueRangeType get(Context &context);
+};
+
+//===----------------------------------------------------------------------===//
+// TypeType
+//===----------------------------------------------------------------------===//
+
+/// This class represents a PDLL type that corresponds to an mlir::Type.
+class TypeType : public Type::TypeBase<detail::TypeTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Return an instance of the Type type.
+  static TypeType get(Context &context);
+};
+
+//===----------------------------------------------------------------------===//
+// ValueType
+//===----------------------------------------------------------------------===//
+
+/// This class represents a PDLL type that corresponds to an mlir::Value.
+class ValueType : public Type::TypeBase<detail::ValueTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Return an instance of the Value type.
+  static ValueType get(Context &context);
+};
+
+} // namespace ast
+} // namespace pdll
+} // namespace mlir
+
+namespace llvm {
+template <>
+struct DenseMapInfo<mlir::pdll::ast::Type> {
+  static mlir::pdll::ast::Type getEmptyKey() {
+    void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::pdll::ast::Type(
+        static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
+  }
+  static mlir::pdll::ast::Type getTombstoneKey() {
+    void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::pdll::ast::Type(
+        static_cast<mlir::pdll::ast::Type::Storage *>(pointer));
+  }
+  static unsigned getHashValue(mlir::pdll::ast::Type val) {
+    return llvm::hash_value(val.getImpl());
+  }
+  static bool isEqual(mlir::pdll::ast::Type lhs, mlir::pdll::ast::Type rhs) {
+    return lhs == rhs;
+  }
+};
+} // namespace llvm
+
+#endif // MLIR_TOOLS_PDLL_AST_TYPES_H_
diff --git a/mlir/include/mlir/Tools/PDLL/Parser/Parser.h b/mlir/include/mlir/Tools/PDLL/Parser/Parser.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Tools/PDLL/Parser/Parser.h
@@ -0,0 +1,35 @@
+//===- Parser.h - MLIR PDLL Frontend Parser ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PDLL_PARSER_PARSER_H_
+#define MLIR_TOOLS_PDLL_PARSER_PARSER_H_
+
+#include <string>
+#include <vector>
+
+#include "mlir/Support/LogicalResult.h"
+
+namespace llvm {
+class SourceMgr;
+} // namespace llvm
+
+namespace mlir {
+namespace pdll {
+namespace ast {
+class Context;
+class Module;
+} // namespace ast
+
+/// Parse an AST module from the main file of the given source manager.
+FailureOr<ast::Module *>
+parsePDLAST(ast::Context &context, llvm::SourceMgr &sourceMgr,
+            const std::vector<std::string> &includeDirs);
+} // namespace pdll
+} // namespace mlir
+
+#endif // MLIR_TOOLS_PDLL_PARSER_PARSER_H_
diff --git a/mlir/lib/Tools/PDLL/AST/Context.cpp b/mlir/lib/Tools/PDLL/AST/Context.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/AST/Context.cpp
@@ -0,0 +1,23 @@
+//===- Context.cpp --------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/AST/Context.h"
+#include "TypeDetail.h"
+
+using namespace mlir;
+using namespace mlir::pdll::ast;
+
+Context::Context() {
+  typeUniquer.registerSingletonStorageType<detail::AttributeTypeStorage>();
+  typeUniquer.registerSingletonStorageType<detail::ConstraintTypeStorage>();
+  typeUniquer.registerSingletonStorageType<detail::TypeTypeStorage>();
+  typeUniquer.registerSingletonStorageType<detail::ValueTypeStorage>();
+
+  typeUniquer.registerParametricStorageType<detail::OperationTypeStorage>();
+  typeUniquer.registerParametricStorageType<detail::RangeTypeStorage>();
+}
diff --git a/mlir/lib/Tools/PDLL/AST/Diagnostic.cpp b/mlir/lib/Tools/PDLL/AST/Diagnostic.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/AST/Diagnostic.cpp
@@ -0,0 +1,26 @@
+//===- Diagnostic.cpp -----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/AST/Diagnostic.h"
+
+using namespace mlir;
+using namespace mlir::pdll::ast;
+
+//===----------------------------------------------------------------------===//
+// InFlightDiagnostic
+//===----------------------------------------------------------------------===//
+
+void InFlightDiagnostic::report() {
+  // If this diagnostic is still inflight and it hasn't been abandoned, then
+  // report it.
+  if (isInFlight()) {
+    owner->report(std::move(*impl));
+    owner = nullptr;
+  }
+  impl.reset();
+}
diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -0,0 +1,267 @@
+//===- NodePrinter.cpp ----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/AST/Context.h"
+#include "mlir/Tools/PDLL/AST/Nodes.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/SaveAndRestore.h"
+#include "llvm/Support/ScopedPrinter.h"
+
+using namespace mlir;
+using namespace mlir::pdll::ast;
+
+//===----------------------------------------------------------------------===//
+// NodePrinter
+//===----------------------------------------------------------------------===//
+
+namespace {
+class NodePrinter {
+public:
+  NodePrinter(raw_ostream &os) : os(os) {}
+
+  /// Print the given type to the stream.
+  void print(Type type);
+
+  /// Print the given node to the stream.
+  void print(const Node *node);
+
+private:
+  /// Print a range containing children of a node.
+  template <typename RangeT,
+            std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
+                * = nullptr>
+  void printChildren(RangeT &&range) {
+    if (llvm::empty(range))
+      return;
+
+    // Print the first N-1 elements with a prefix of "|-".
+    auto it = std::begin(range);
+    for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
+      print(*it);
+
+    // Print the last element.
+    elementIndentStack.back() = true;
+    print(*it);
+  }
+  template <typename RangeT, typename... OthersT,
+            std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
+                * = nullptr>
+  void printChildren(RangeT &&range, OthersT &&...others) {
+    printChildren(ArrayRef<const Node *>({range, others...}));
+  }
+  /// Print a range containing children of a node, nesting the children under
+  /// the given label.
+  template <typename RangeT>
+  void printChildren(StringRef label, RangeT &&range) {
+    if (llvm::empty(range))
+      return;
+    // Use an index, not a reference: nested printing may grow the stack.
+    size_t labelIndex = elementIndentStack.size() - 1;
+    bool wasLastElt = std::exchange(elementIndentStack[labelIndex], true);
+    printIndent();
+    os << label << "`\n";
+    elementIndentStack.push_back(/*isLastElt*/ false);
+    printChildren(std::forward<RangeT>(range));
+    elementIndentStack.pop_back();
+    elementIndentStack[labelIndex] = wasLastElt;
+  }
+
+  /// Print the given derived node to the stream.
+  void printImpl(const CompoundStmt *stmt);
+  void printImpl(const EraseStmt *stmt);
+  void printImpl(const LetStmt *stmt);
+
+  void printImpl(const DeclRefExpr *expr);
+  void printImpl(const MemberAccessExpr *expr);
+
+  void printImpl(const AttrConstraintDecl *decl);
+  void printImpl(const OpConstraintDecl *decl);
+  void printImpl(const TypeConstraintDecl *decl);
+  void printImpl(const TypeRangeConstraintDecl *decl);
+  void printImpl(const ValueConstraintDecl *decl);
+  void printImpl(const ValueRangeConstraintDecl *decl);
+  void printImpl(const OpNameDecl *decl);
+  void printImpl(const PatternDecl *decl);
+  void printImpl(const VariableDecl *decl);
+  void printImpl(const Module *module);
+
+  /// Print the current indent stack.
+  void printIndent() {
+    if (elementIndentStack.empty())
+      return;
+
+    for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back())
+      os << (isLastElt ? "  " : " |");
+    os << (elementIndentStack.back() ? " `" : " |");
+  }
+
+  /// The raw output stream.
+  raw_ostream &os;
+
+  /// A stack of indents and a flag indicating if the current element being
+  /// printed at that indent is the last element.
+  SmallVector<bool> elementIndentStack;
+};
+} // namespace
+
+void NodePrinter::print(Type type) {
+  // Protect against invalid inputs.
+  if (!type) {
+    os << "Type<NULL>";
+    return;
+  }
+
+  TypeSwitch<Type>(type)
+      .Case([&](AttributeType) { os << "Attr"; })
+      .Case([&](ConstraintType) { os << "Constraint"; })
+      .Case([&](OperationType type) {
+        os << "Op";
+        if (Optional<StringRef> name = type.getName())
+          os << "<" << *name << ">";
+      })
+      .Case([&](RangeType type) {
+        print(type.getElementType());
+        os << "Range";
+      })
+      .Case([&](TypeType) { os << "Type"; })
+      .Case([&](ValueType) { os << "Value"; })
+      .Default([](Type) { llvm_unreachable("unknown AST type"); });
+}
+
+void NodePrinter::print(const Node *node) {
+  printIndent();
+  os << "-";
+
+  elementIndentStack.push_back(/*isLastElt*/ false);
+  TypeSwitch<const Node *>(node)
+      .Case<
+          // Statements.
+          const CompoundStmt, const EraseStmt, const LetStmt,
+
+          // Expressions.
+          const DeclRefExpr, const MemberAccessExpr,
+
+          // Decls.
+          const AttrConstraintDecl, const OpConstraintDecl,
+          const TypeConstraintDecl, const TypeRangeConstraintDecl,
+          const ValueConstraintDecl, const ValueRangeConstraintDecl,
+          const OpNameDecl, const PatternDecl, const VariableDecl,
+
+          const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
+      .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
+  elementIndentStack.pop_back();
+}
+
+void NodePrinter::printImpl(const CompoundStmt *stmt) {
+  os << "CompoundStmt " << stmt << "\n";
+  printChildren(stmt->getChildren());
+}
+
+void NodePrinter::printImpl(const EraseStmt *stmt) {
+  os << "EraseStmt " << stmt << "\n";
+  printChildren(stmt->getRootOpExpr());
+}
+
+void NodePrinter::printImpl(const LetStmt *stmt) {
+  os << "LetStmt " << stmt << "\n";
+  printChildren(stmt->getVarDecl());
+}
+
+void NodePrinter::printImpl(const DeclRefExpr *expr) {
+  os << "DeclRefExpr " << expr << " Type<";
+  print(expr->getType());
+  os << ">\n";
+  printChildren(expr->getDecl());
+}
+
+void NodePrinter::printImpl(const MemberAccessExpr *expr) {
+  os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
+     << "> Type<";
+  print(expr->getType());
+  os << ">\n";
+  printChildren(expr->getParentExpr());
+}
+
+void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
+  os << "AttrConstraintDecl " << decl << "\n";
+  if (const auto *typeExpr = decl->getTypeExpr())
+    printChildren(typeExpr);
+}
+
+void NodePrinter::printImpl(const OpConstraintDecl *decl) {
+  os << "OpConstraintDecl " << decl << "\n";
+  printChildren(decl->getNameDecl());
+}
+
+void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
+  os << "TypeConstraintDecl " << decl << "\n";
+}
+
+void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
+  os << "TypeRangeConstraintDecl " << decl << "\n";
+}
+
+void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
+  os << "ValueConstraintDecl " << decl << "\n";
+  if (const auto *typeExpr = decl->getTypeExpr())
+    printChildren(typeExpr);
+}
+
+void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
+  os << "ValueRangeConstraintDecl " << decl << "\n";
+  if (const auto *typeExpr = decl->getTypeExpr())
+    printChildren(typeExpr);
+}
+
+void NodePrinter::printImpl(const OpNameDecl *decl) {
+  os << "OpNameDecl " << decl;
+  if (Optional<StringRef> name = decl->getName())
+    os << " Name<" << *name << ">";
+  os << "\n";
+}
+
+void NodePrinter::printImpl(const PatternDecl *decl) {
+  os << "PatternDecl " << decl;
+  if (const Optional<Name> &name = decl->getName())
+    os << " Name<" << name->name << ">";
+  if (Optional<uint16_t> benefit = decl->getBenefit())
+    os << " Benefit<" << *benefit << ">";
+  if (decl->hasBoundedRewriteRecursion())
+    os << " Recursion";
+
+  os << "\n";
+  printChildren(decl->getBody());
+}
+
+void NodePrinter::printImpl(const VariableDecl *decl) {
+  os << "VariableDecl " << decl << " Name<" << decl->getName().name
+     << "> Type<";
+  print(decl->getType());
+  os << ">\n";
+  if (Expr *initExpr = decl->getInitExpr())
+    printChildren(initExpr);
+
+  auto constraints =
+      llvm::map_range(decl->getConstraints(),
+                      [](const ConstraintRef &ref) { return ref.constraint; });
+  printChildren("Constraints", constraints);
+}
+
+void NodePrinter::printImpl(const Module *module) {
+  os << "Module " << module << "\n";
+  printChildren(module->getChildren());
+}
+
+//===----------------------------------------------------------------------===//
+// Entry point
+//===----------------------------------------------------------------------===//
+
+void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
+
+void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -0,0 +1,223 @@
+//===- Nodes.cpp ----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache
License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/AST/Nodes.h"
+#include "mlir/Tools/PDLL/AST/Context.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::pdll::ast;
+
+/// Copy a string reference into the context with a null terminator.
+static StringRef copyStringWithNull(Context &ctx, StringRef str) {
+  if (str.empty())
+    return ""; // A literal is null-terminated; `str` might not be.
+
+  char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
+  std::copy(str.begin(), str.end(), data);
+  data[str.size()] = 0;
+  return StringRef(data, str.size());
+}
+
+//===----------------------------------------------------------------------===//
+// DeclScope
+//===----------------------------------------------------------------------===//
+
+void DeclScope::add(Decl *decl) {
+  assert(decl->getName() && "expected a named decl");
+  assert(!decls.count(decl->getName()->name) &&
+         "decl with this name already exists");
+  decls.try_emplace(decl->getName()->name, decl);
+}
+
+Decl *DeclScope::lookup(StringRef name) {
+  if (Decl *decl = decls.lookup(name))
+    return decl;
+  return parent ? parent->lookup(name) : nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// CompoundStmt
+//===----------------------------------------------------------------------===//
+
+CompoundStmt *CompoundStmt::create(Context &ctx, llvm::SMRange loc,
+                                   ArrayRef<Stmt *> children) {
+  unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
+  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));
+
+  CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
+  std::uninitialized_copy(children.begin(), children.end(), stmt->begin());
+  return stmt;
+}
+
+//===----------------------------------------------------------------------===//
+// LetStmt
+//===----------------------------------------------------------------------===//
+
+LetStmt *LetStmt::create(Context &ctx, llvm::SMRange loc,
+                         VariableDecl *varDecl) {
+  return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
+}
+
+//===----------------------------------------------------------------------===//
+// OpRewriteStmt
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// EraseStmt
+
+EraseStmt *EraseStmt::create(Context &ctx, llvm::SMRange loc, Expr *rootOp) {
+  return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
+}
+
+//===----------------------------------------------------------------------===//
+// DeclRefExpr
+//===----------------------------------------------------------------------===//
+
+DeclRefExpr *DeclRefExpr::create(Context &ctx, llvm::SMRange loc, Decl *decl,
+                                 Type type) {
+  return new (ctx.getAllocator().Allocate<DeclRefExpr>())
+      DeclRefExpr(loc, decl, type);
+}
+
+//===----------------------------------------------------------------------===//
+// MemberAccessExpr
+//===----------------------------------------------------------------------===//
+
+MemberAccessExpr *MemberAccessExpr::create(Context &ctx, llvm::SMRange loc,
+                                           const Expr *parentExpr,
+                                           StringRef memberName, Type type) {
+  return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
+      loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
+}
+
+//===----------------------------------------------------------------------===//
+// AttrConstraintDecl
+//===----------------------------------------------------------------------===//
+
+AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, llvm::SMRange loc,
+                                               Expr *typeExpr) {
+  return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
+      AttrConstraintDecl(loc, typeExpr);
+}
+
+//===----------------------------------------------------------------------===//
+// OpConstraintDecl
+//===----------------------------------------------------------------------===//
+
+OpConstraintDecl *OpConstraintDecl::create(Context &ctx, llvm::SMRange loc,
+                                           const OpNameDecl *nameDecl) {
+  if (!nameDecl)
+    nameDecl = OpNameDecl::create(ctx, llvm::SMRange());
+
+  return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
+      OpConstraintDecl(loc, nameDecl);
+}
+
+Optional<StringRef> OpConstraintDecl::getName() const {
+  return getNameDecl()->getName();
+}
+
+//===----------------------------------------------------------------------===//
+// TypeConstraintDecl
+//===----------------------------------------------------------------------===//
+
+TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx,
+                                               llvm::SMRange loc) {
+  return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
+      TypeConstraintDecl(loc);
+}
+
+//===----------------------------------------------------------------------===//
+// TypeRangeConstraintDecl
+//===----------------------------------------------------------------------===//
+
+TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
+                                                         llvm::SMRange loc) {
+  return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
+      TypeRangeConstraintDecl(loc);
+}
+
+//===----------------------------------------------------------------------===//
+// ValueConstraintDecl
+//===----------------------------------------------------------------------===//
+
+ValueConstraintDecl *
+ValueConstraintDecl::create(Context &ctx, llvm::SMRange loc, Expr *typeExpr) {
+  return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
+      ValueConstraintDecl(loc, typeExpr);
+}
+
+//===----------------------------------------------------------------------===//
+// ValueRangeConstraintDecl
+//===----------------------------------------------------------------------===//
+
+ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
+                                                           llvm::SMRange loc,
+                                                           Expr *typeExpr) {
+  return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
+      ValueRangeConstraintDecl(loc, typeExpr);
+}
+
+//===----------------------------------------------------------------------===//
+// OpNameDecl
+//===----------------------------------------------------------------------===//
+
+OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
+  return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name);
+}
+OpNameDecl *OpNameDecl::create(Context &ctx, llvm::SMRange loc) {
+  return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc);
+}
+
+//===----------------------------------------------------------------------===//
+// PatternDecl
+//===----------------------------------------------------------------------===//
+
+PatternDecl *PatternDecl::create(Context &ctx, llvm::SMRange loc,
+                                 Optional<Name> name,
+                                 Optional<uint16_t> benefit,
+                                 bool hasBoundedRecursion,
+                                 const CompoundStmt *body) {
+  return new (ctx.getAllocator().Allocate<PatternDecl>())
+      PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
+}
+
+//===----------------------------------------------------------------------===//
+// VariableDecl
+//===----------------------------------------------------------------------===//
+
+VariableDecl *VariableDecl::create(Context &ctx, Name name, Type type,
+                                   Expr *initExpr,
+                                   ArrayRef<ConstraintRef> constraints) {
+  unsigned allocSize =
+      VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size());
+  void *rawData =
+      ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl));
+
+  VariableDecl *varDecl =
+      new (rawData) VariableDecl(name, type, initExpr, constraints.size());
+  std::uninitialized_copy(constraints.begin(), constraints.end(),
+                          varDecl->getConstraints().begin());
+  return varDecl;
+}
+
+//===----------------------------------------------------------------------===//
+// Module
+//===----------------------------------------------------------------------===//
+
+Module *Module::create(Context &ctx, llvm::SMLoc loc,
+                       ArrayRef<Decl *> children) {
+  unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
+  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
+
+  Module *module = new (rawData) Module(loc, children.size());
+  std::uninitialized_copy(children.begin(), children.end(),
+                          module->getChildren().begin());
+  return module;
+}
diff --git a/mlir/lib/Tools/PDLL/AST/TypeDetail.h b/mlir/lib/Tools/PDLL/AST/TypeDetail.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/AST/TypeDetail.h
@@ -0,0 +1,116 @@
+//===- TypeDetail.h ---------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
+#define LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
+
+#include "mlir/Tools/PDLL/AST/Types.h"
+
+namespace mlir {
+namespace pdll {
+namespace ast {
+//===----------------------------------------------------------------------===//
+// Type
+//===----------------------------------------------------------------------===//
+
+struct Type::Storage : public StorageUniquer::BaseStorage {
+  Storage(TypeID typeID) : typeID(typeID) {}
+
+  /// The type identifier for the derived type class.
+  TypeID typeID;
+};
+
+namespace detail {
+
+/// A utility CRTP base class that defines many of the necessary utilities for
+/// defining a PDLL AST Type.
+template <typename ConcreteT, typename KeyT = void>
+struct TypeStorageBase : public Type::Storage {
+  using KeyTy = KeyT;
+  using Base = TypeStorageBase<ConcreteT, KeyT>;
+
+  /// Construct an instance with the given storage allocator.
+  static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
+                              const KeyTy &key) {
+    return new (alloc.allocate<ConcreteT>()) ConcreteT(key);
+  }
+
+  /// Utility methods required by the storage allocator.
+  bool operator==(const KeyTy &key) const { return this->key == key; }
+
+  /// Return the key value of this storage class.
+  const KeyTy &getValue() const { return key; }
+
+protected:
+  TypeStorageBase(KeyTy key)
+      : Type::Storage(TypeID::get<ConcreteT>()), key(key) {}
+
+  KeyTy key;
+};
+/// A specialization of the storage base for singleton types.
+template <typename ConcreteT>
+struct TypeStorageBase<ConcreteT, void> : public Type::Storage {
+  using Base = TypeStorageBase<ConcreteT, void>;
+
+protected:
+  TypeStorageBase() : Type::Storage(TypeID::get<ConcreteT>()) {}
+};
+
+//===----------------------------------------------------------------------===//
+// AttributeType
+//===----------------------------------------------------------------------===//
+
+struct AttributeTypeStorage : public TypeStorageBase<AttributeTypeStorage> {};
+
+//===----------------------------------------------------------------------===//
+// ConstraintType
+//===----------------------------------------------------------------------===//
+
+struct ConstraintTypeStorage : public TypeStorageBase<ConstraintTypeStorage> {};
+
+//===----------------------------------------------------------------------===//
+// OperationType
+//===----------------------------------------------------------------------===//
+
+struct OperationTypeStorage
+    : public TypeStorageBase<OperationTypeStorage, StringRef> {
+  using Base::Base;
+
+  static OperationTypeStorage *
+  construct(StorageUniquer::StorageAllocator &alloc, StringRef key) {
+    return new (alloc.allocate<OperationTypeStorage>())
+        OperationTypeStorage(alloc.copyInto(key));
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// RangeType
+//===----------------------------------------------------------------------===//
+
+struct RangeTypeStorage : public TypeStorageBase<RangeTypeStorage, Type> {
+  using Base::Base;
+};
+
+//===----------------------------------------------------------------------===//
+// TypeType
+//===----------------------------------------------------------------------===//
+
+struct TypeTypeStorage : public TypeStorageBase<TypeTypeStorage> {};
+
+//===----------------------------------------------------------------------===//
+// ValueType
+//===----------------------------------------------------------------------===//
+
+struct ValueTypeStorage : public TypeStorageBase<ValueTypeStorage> {};
+
+} // namespace detail
+} // namespace ast
+} // namespace pdll
+} // namespace mlir
+
+#endif // LIB_MLIR_TOOLS_PDLL_AST_TYPEDETAIL_H_
diff --git a/mlir/lib/Tools/PDLL/AST/Types.cpp b/mlir/lib/Tools/PDLL/AST/Types.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/AST/Types.cpp
@@ -0,0 +1,124 @@
+//===- Types.cpp ----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/AST/Types.h"
+#include "TypeDetail.h"
+#include "mlir/Tools/PDLL/AST/Context.h"
+
+using namespace mlir;
+using namespace mlir::pdll::ast;
+
+//===----------------------------------------------------------------------===//
+// Type
+//===----------------------------------------------------------------------===//
+
+TypeID Type::getTypeID() const { return impl->typeID; }
+
+Type Type::refineWith(Type other) const {
+  if (*this == other)
+    return *this;
+
+  // Operation types are compatible if the operation names don't conflict.
+  if (auto opTy = dyn_cast_or_null<OperationType>()) {
+    auto otherOpTy = other.dyn_cast_or_null<OperationType>();
+    if (!otherOpTy)
+      return nullptr;
+    if (!otherOpTy.getName())
+      return *this;
+    if (!opTy.getName())
+      return other;
+
+    return nullptr;
+  }
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// AttributeType
+//===----------------------------------------------------------------------===//
+
+AttributeType AttributeType::get(Context &context) {
+  return context.getTypeUniquer().get<ImplTy>();
+}
+
+//===----------------------------------------------------------------------===//
+// ConstraintType
+//===----------------------------------------------------------------------===//
+
+ConstraintType ConstraintType::get(Context &context) {
+  return context.getTypeUniquer().get<ImplTy>();
+}
+
+//===----------------------------------------------------------------------===//
+// OperationType
+//===----------------------------------------------------------------------===//
+
+OperationType OperationType::get(Context &context, Optional<StringRef> name) {
+  return context.getTypeUniquer().get<ImplTy>(
+      /*initFn=*/function_ref<void(ImplTy *)>(), name.getValueOr(""));
+}
+
+Optional<StringRef> OperationType::getName() const {
+  StringRef name = getImplAs<ImplTy>()->getValue();
+  return name.empty() ? Optional<StringRef>() : Optional<StringRef>(name);
+}
+
+//===----------------------------------------------------------------------===//
+// RangeType
+//===----------------------------------------------------------------------===//
+
+RangeType RangeType::get(Context &context, Type elementType) {
+  return context.getTypeUniquer().get<ImplTy>(
+      /*initFn=*/function_ref<void(ImplTy *)>(), elementType);
+}
+
+Type RangeType::getElementType() const {
+  return getImplAs<ImplTy>()->getValue();
+}
+
+//===----------------------------------------------------------------------===//
+// TypeRangeType
+
+bool TypeRangeType::classof(Type type) {
+  RangeType range = type.dyn_cast<RangeType>();
+  return range && range.getElementType().isa<TypeType>();
+}
+
+TypeRangeType TypeRangeType::get(Context &context) {
+  return RangeType::get(context, TypeType::get(context)).cast<TypeRangeType>();
+}
+
+//===----------------------------------------------------------------------===//
+// ValueRangeType
+
+bool ValueRangeType::classof(Type type) {
+  RangeType range = type.dyn_cast<RangeType>();
+  return range && range.getElementType().isa<ValueType>();
+}
+
+ValueRangeType ValueRangeType::get(Context &context) {
+  return RangeType::get(context, ValueType::get(context))
+      .cast<ValueRangeType>();
+}
+
+//===----------------------------------------------------------------------===//
+// TypeType
+//===----------------------------------------------------------------------===//
+
+TypeType TypeType::get(Context &context) {
+  return context.getTypeUniquer().get<ImplTy>();
+}
+
+//===----------------------------------------------------------------------===//
+// ValueType
+//===----------------------------------------------------------------------===//
+
+ValueType ValueType::get(Context &context) {
+  return context.getTypeUniquer().get<ImplTy>();
+}
diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h
@@ -0,0 +1,221 @@
+//===- Lexer.h - MLIR PDLL Frontend Lexer -----------------------*- C++ -*-===//
+//
+// Part of the LLVM
Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_TOOLS_PDLL_PARSER_LEXER_H_ +#define LIB_TOOLS_PDLL_PARSER_LEXER_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/SMLoc.h" + +namespace llvm { +class SourceMgr; +} // namespace llvm + +namespace mlir { +struct LogicalResult; + +namespace pdll { +namespace ast { +class DiagnosticEngine; +} // namespace ast + +//===----------------------------------------------------------------------===// +// Token +//===----------------------------------------------------------------------===// + +class Token { +public: + enum Kind { + // Markers. + eof, + error, + + // Keywords. + KW_BEGIN, + // Dependent keywords, i.e. those that are treated as keywords depending on + // the current parser context. + KW_DEPENDENT_BEGIN, + kw_attr, + kw_op, + kw_type, + KW_DEPENDENT_END, + + // General keywords. + kw_Attr, + kw_erase, + kw_let, + kw_Constraint, + kw_Op, + kw_OpName, + kw_Pattern, + kw_replace, + kw_rewrite, + kw_Type, + kw_TypeRange, + kw_Value, + kw_ValueRange, + kw_with, + KW_END, + + // Punctuation. + arrow, + colon, + comma, + dot, + equal, + equal_arrow, + semicolon, + // Paired punctuation. + less, + greater, + l_brace, + r_brace, + l_paren, + r_paren, + l_square, + r_square, + underscore, + + // Tokens. + directive, + identifier, + integer, + string_block, + string + }; + Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {} + + /// Given a token containing a string literal, return its value, including + /// removing the quote characters and unescaping the contents of the string. + std::string getStringValue() const; + + /// Returns true if the current token is a string literal.
+ bool isString() const { return isAny(Token::string, Token::string_block); } + + /// Returns true if the current token is a keyword. + bool isKeyword() const { + return kind > Token::KW_BEGIN && kind < Token::KW_END; + } + + /// Returns true if the current token is a keyword in a dependent context, and + /// in any other situation (e.g. variable names) may be treated as an + /// identifier. + bool isDependentKeyword() const { + return kind > Token::KW_DEPENDENT_BEGIN && kind < Token::KW_DEPENDENT_END; + } + + /// Return the bytes that make up this token. + StringRef getSpelling() const { return spelling; } + + /// Return the kind of this token. + Kind getKind() const { return kind; } + + /// Return true if this token is one of the specified kinds. + bool isAny(Kind k1, Kind k2) const { return is(k1) || is(k2); } + template + bool isAny(Kind k1, Kind k2, Kind k3, T... others) const { + return is(k1) || isAny(k2, k3, others...); + } + + /// Return if the token does not have the given kind. + bool isNot(Kind k) const { return k != kind; } + template + bool isNot(Kind k1, Kind k2, T... others) const { + return !isAny(k1, k2, others...); + } + + /// Return if the token has the given kind. + bool is(Kind K) const { return kind == K; } + + /// Return a location for the start of this token. + llvm::SMLoc getStartLoc() const { + return llvm::SMLoc::getFromPointer(spelling.data()); + } + /// Return a location at the end of this token. + llvm::SMLoc getEndLoc() const { + return llvm::SMLoc::getFromPointer(spelling.data() + spelling.size()); + } + /// Return a location for the range of this token. + llvm::SMRange getLoc() const { + return llvm::SMRange(getStartLoc(), getEndLoc()); + } + +private: + /// Discriminator that indicates the kind of token this is. + Kind kind; + + /// A reference to the entire token contents; this is always a pointer into + /// a memory buffer owned by the source manager.
+ StringRef spelling; +}; + +//===----------------------------------------------------------------------===// +// Lexer +//===----------------------------------------------------------------------===// + +class Lexer { +public: + Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine); + ~Lexer(); + + /// Return a reference to the source manager used by the lexer. + llvm::SourceMgr &getSourceMgr() { return srcMgr; } + + /// Return a reference to the diagnostic engine used by the lexer. + ast::DiagnosticEngine &getDiagEngine() { return diagEngine; } + + /// Push an include of the given file. This will cause the lexer to start + /// processing the provided file. Returns failure if the file could not be + /// opened, success otherwise. + LogicalResult pushInclude(StringRef filename); + + /// Lex the next token and return it. + Token lexToken(); + + /// Change the position of the lexer cursor. The next token we lex will start + /// at the designated point in the input. + void resetPointer(const char *newPointer) { curPtr = newPointer; } + + /// Emit an error to the lexer with the given location and message. + Token emitError(llvm::SMRange loc, const Twine &msg); + Token emitError(const char *loc, const Twine &msg); + Token emitErrorAndNote(llvm::SMRange loc, const Twine &msg, + llvm::SMRange noteLoc, const Twine &note); + +private: + Token formToken(Token::Kind kind, const char *tokStart) { + return Token(kind, StringRef(tokStart, curPtr - tokStart)); + } + + /// Return the next character in the stream. + int getNextChar(); + + /// Lex methods. + void lexComment(); + Token lexDirective(const char *tokStart); + Token lexIdentifier(const char *tokStart); + Token lexNumber(const char *tokStart); + Token lexString(const char *tokStart, bool isStringBlock); + + llvm::SourceMgr &srcMgr; + int curBufferID; + StringRef curBuffer; + const char *curPtr; + + /// The engine used to emit diagnostics during lexing/parsing.
+ ast::DiagnosticEngine &diagEngine; + + /// A flag indicating if we added a default diagnostic handler to the provided + /// diagEngine. + bool addedHandlerToDiagEngine; +}; +} // namespace pdll +} // namespace mlir + +#endif // LIB_TOOLS_PDLL_PARSER_LEXER_H_ diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp @@ -0,0 +1,382 @@ +//===- Lexer.cpp ----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Lexer.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Tools/PDLL/AST/Diagnostic.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/SourceMgr.h" + +using namespace mlir; +using namespace mlir::pdll; + +//===----------------------------------------------------------------------===// +// Token +//===----------------------------------------------------------------------===// + +std::string Token::getStringValue() const { + assert(getKind() == string || getKind() == string_block); + + // Start by dropping the quotes.
+ StringRef bytes = getSpelling().drop_front().drop_back(); + if (is(string_block)) + bytes = bytes.drop_front().drop_back(); + + std::string result; + result.reserve(bytes.size()); + for (unsigned i = 0, e = bytes.size(); i != e;) { + auto c = bytes[i++]; + if (c != '\\') { + result.push_back(c); + continue; + } + + assert(i + 1 <= e && "invalid string should be caught by lexer"); + auto c1 = bytes[i++]; + switch (c1) { + case '"': + case '\\': + result.push_back(c1); + continue; + case 'n': + result.push_back('\n'); + continue; + case 't': + result.push_back('\t'); + continue; + default: + break; + } + + assert(i + 1 <= e && "invalid string should be caught by lexer"); + auto c2 = bytes[i++]; + + assert(llvm::isHexDigit(c1) && llvm::isHexDigit(c2) && "invalid escape"); + result.push_back((llvm::hexDigitValue(c1) << 4) | llvm::hexDigitValue(c2)); + } + + return result; +} + +//===----------------------------------------------------------------------===// +// Lexer +//===----------------------------------------------------------------------===// + +Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine) + : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false) { + curBufferID = mgr.getMainFileID(); + curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); + curPtr = curBuffer.begin(); + + // If the diag engine has no handler, add a default that emits to the + // SourceMgr.
+ if (!diagEngine.getHandlerFn()) { + diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) { + srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(), + diag.getMessage()); + for (const ast::Diagnostic &note : diag.getNotes()) + srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(), + note.getMessage()); + }); + addedHandlerToDiagEngine = true; + } +} + +Lexer::~Lexer() { + if (addedHandlerToDiagEngine) + diagEngine.setHandlerFn(nullptr); +} + +LogicalResult Lexer::pushInclude(StringRef filename) { + std::string includedFile; + int bufferID = srcMgr.AddIncludeFile( + filename.str(), llvm::SMLoc::getFromPointer(curPtr), includedFile); + if (!bufferID) + return failure(); + + curBufferID = bufferID; + curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); + curPtr = curBuffer.begin(); + return success(); +} + +Token Lexer::emitError(llvm::SMRange loc, const Twine &msg) { + diagEngine.emitError(loc, msg); + return formToken(Token::error, loc.Start.getPointer()); +} +Token Lexer::emitErrorAndNote(llvm::SMRange loc, const Twine &msg, + llvm::SMRange noteLoc, const Twine &note) { + diagEngine.emitError(loc, msg)->attachNote(note, noteLoc); + return formToken(Token::error, loc.Start.getPointer()); +} +Token Lexer::emitError(const char *loc, const Twine &msg) { + return emitError(llvm::SMRange(llvm::SMLoc::getFromPointer(loc), + llvm::SMLoc::getFromPointer(loc + 1)), + msg); +} + +int Lexer::getNextChar() { + char curChar = *curPtr++; + switch (curChar) { + default: + return static_cast(curChar); + case 0: { + // A nul character in the stream is either the end of the current buffer + // or a random nul in the file. Disambiguate that here. + if (curPtr - 1 != curBuffer.end()) + return 0; + + // Otherwise, return end of file. + --curPtr; + return EOF; + } + case '\n': + case '\r': + // Normalize any newline character to '\n'. However, be careful about + // 'dos style' files with \n\r in them.
+ // Only treat a \n\r or \r\n as a single line. + if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) + ++curPtr; + return '\n'; + } +} + +Token Lexer::lexToken() { + while (true) { + const char *tokStart = curPtr; + + // This always consumes at least one character. + int curChar = getNextChar(); + switch (curChar) { + default: + // Handle identifiers: [a-zA-Z_] + if (isalpha(curChar) || curChar == '_') + return lexIdentifier(tokStart); + + // Unknown character, emit an error. + return emitError(tokStart, "unexpected character"); + case EOF: { + // Return EOF denoting the end of lexing. + Token eof = formToken(Token::eof, tokStart); + + // Check to see if we are in an included file. + llvm::SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID); + if (parentIncludeLoc.isValid()) { + curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc); + curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer(); + curPtr = parentIncludeLoc.getPointer(); + } + + return eof; + } + + // Lex punctuation.
+ case '-': + if (*curPtr == '>') { + ++curPtr; + return formToken(Token::arrow, tokStart); + } + return emitError(tokStart, "unexpected character"); + case ':': + return formToken(Token::colon, tokStart); + case ',': + return formToken(Token::comma, tokStart); + case '.': + return formToken(Token::dot, tokStart); + case '=': + if (*curPtr == '>') { + ++curPtr; + return formToken(Token::equal_arrow, tokStart); + } + return formToken(Token::equal, tokStart); + case ';': + return formToken(Token::semicolon, tokStart); + case '[': + if (*curPtr == '{') { + ++curPtr; + return lexString(tokStart, /*isStringBlock=*/true); + } + return formToken(Token::l_square, tokStart); + case ']': + return formToken(Token::r_square, tokStart); + + case '<': + return formToken(Token::less, tokStart); + case '>': + return formToken(Token::greater, tokStart); + case '{': + return formToken(Token::l_brace, tokStart); + case '}': + return formToken(Token::r_brace, tokStart); + case '(': + return formToken(Token::l_paren, tokStart); + case ')': + return formToken(Token::r_paren, tokStart); + case '/': + if (*curPtr == '/') { + lexComment(); + continue; + } + return emitError(tokStart, "unexpected character"); + + // Ignore whitespace characters. + case 0: + case ' ': + case '\t': + case '\n': + continue; + + case '#': + return lexDirective(tokStart); + case '"': + return lexString(tokStart, /*isStringBlock=*/false); + + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + return lexNumber(tokStart); + } + } +} + +/// Skip a comment line, starting with a '//'. +void Lexer::lexComment() { + // Advance over the second '/' in a '//' comment. + assert(*curPtr == '/'); + ++curPtr; + + while (true) { + switch (*curPtr++) { + case '\n': + case '\r': + // Newline is end of comment. + return; + case 0: + // If this is the end of the buffer, end the comment.
+ if (curPtr - 1 == curBuffer.end()) { + --curPtr; + return; + } + LLVM_FALLTHROUGH; + default: + // Skip over other characters. + break; + } + } +} + +Token Lexer::lexDirective(const char *tokStart) { + // Match the rest with an identifier regex: [0-9a-zA-Z_]* + while (isalnum(*curPtr) || *curPtr == '_') + ++curPtr; + + StringRef str(tokStart, curPtr - tokStart); + return Token(Token::directive, str); +} + +Token Lexer::lexIdentifier(const char *tokStart) { + // Match the rest of the identifier regex: [0-9a-zA-Z_]* + while (isalnum(*curPtr) || *curPtr == '_') + ++curPtr; + + // Check to see if this identifier is a keyword. + StringRef str(tokStart, curPtr - tokStart); + Token::Kind kind = StringSwitch(str) + .Case("attr", Token::kw_attr) + .Case("Attr", Token::kw_Attr) + .Case("erase", Token::kw_erase) + .Case("let", Token::kw_let) + .Case("Constraint", Token::kw_Constraint) + .Case("op", Token::kw_op) + .Case("Op", Token::kw_Op) + .Case("OpName", Token::kw_OpName) + .Case("Pattern", Token::kw_Pattern) + .Case("replace", Token::kw_replace) + .Case("rewrite", Token::kw_rewrite) + .Case("type", Token::kw_type) + .Case("Type", Token::kw_Type) + .Case("TypeRange", Token::kw_TypeRange) + .Case("Value", Token::kw_Value) + .Case("ValueRange", Token::kw_ValueRange) + .Case("with", Token::kw_with) + .Case("_", Token::underscore) + .Default(Token::identifier); + return Token(kind, str); +} + +Token Lexer::lexNumber(const char *tokStart) { + assert(isdigit(curPtr[-1])); + + // Handle the normal decimal case. + while (isdigit(*curPtr)) + ++curPtr; + + return formToken(Token::integer, tokStart); +} + +Token Lexer::lexString(const char *tokStart, bool isStringBlock) { + while (true) { + switch (*curPtr++) { + case '"': + // If this is a string block, we only end the string when we encounter a + // `}]`. + if (!isStringBlock) + return formToken(Token::string, tokStart); + continue; + case '}': + // If this is a string block, we only end the string when we encounter a + // `}]`.
+ if (!isStringBlock || *curPtr != ']') + continue; + ++curPtr; + return formToken(Token::string_block, tokStart); + case 0: + // If this is a random nul character in the middle of a string, just + // include it. If it is the end of file, then it is an error. + if (curPtr - 1 != curBuffer.end()) + continue; + // Step back onto the terminating nul so the cursor never moves past the + // end of the buffer, even for an unterminated string block. + --curPtr; + return emitError(curPtr - 1, isStringBlock + ? "expected '}]' in string literal" + : "expected '\"' in string literal"); + case '\n': + case '\v': + case '\f': + // String blocks allow multiple lines. + if (!isStringBlock) + return emitError(curPtr - 1, "expected '\"' in string literal"); + continue; + + case '\\': + // Handle explicitly a few escapes. + if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || + *curPtr == 't') { + ++curPtr; + } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) { + // Support \xx for two hex digits. + curPtr += 2; + } else { + return emitError(curPtr - 1, "unknown escape in string literal"); + } + continue; + + default: + continue; + } + } +} diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -0,0 +1,1119 @@ +//===- Parser.cpp ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/PDLL/Parser/Parser.h" +#include "Lexer.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Tools/PDLL/AST/Context.h" +#include "mlir/Tools/PDLL/AST/Diagnostic.h" +#include "mlir/Tools/PDLL/AST/Nodes.h" +#include "mlir/Tools/PDLL/AST/Types.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/FormatAdapters.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/SaveAndRestore.h" +#include + +using namespace mlir; +using namespace mlir::pdll; + +//===----------------------------------------------------------------------===// +// Parser +//===----------------------------------------------------------------------===// + +namespace { +class Parser { +public: + Parser(ast::Context &context, llvm::SourceMgr &sourceMgr, + const std::vector &includeDirs) + : context(context), lexer(sourceMgr, context.getDiagEngine()), + curToken(lexer.lexToken()), curDeclScope(nullptr), + includeDirs(includeDirs), valueTy(ast::ValueType::get(context)), + valueRangeTy(ast::ValueRangeType::get(context)), + typeTy(ast::TypeType::get(context)), + typeRangeTy(ast::TypeRangeType::get(context)) {} + + /// Try to parse a new module. Returns failure if the module could not be parsed. + FailureOr parseModule(); + +private: + /// The current context of the parser. It allows for the parser to know a bit + /// about the construct it is nested within during parsing. This is used + /// specifically to provide additional verification during parsing, e.g. to + /// prevent using rewrites within a match context, matcher constraints within + /// a rewrite section, etc. + enum class ParserContext { + /// The parser is in the global context.
+ Global, + /// The parser is currently within the matcher portion of a Pattern, which + /// allows a terminal operation rewrite statement but no other rewrite + /// transformations. + PatternMatch, + }; + + //===--------------------------------------------------------------------===// + // Parsing + //===--------------------------------------------------------------------===// + + /// Push a new decl scope onto the parser. + ast::DeclScope *pushDeclScope() { + ast::DeclScope *newScope = + new (scopeAllocator.Allocate()) ast::DeclScope(curDeclScope); + return (curDeclScope = newScope); + } + void pushDeclScope(ast::DeclScope *scope) { curDeclScope = scope; } + + /// Pop the last decl scope from the parser. + void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } + + /// Parse the body of an AST module. + LogicalResult parseModuleBody(SmallVector &decls); + + /// Try to convert the given expression to `type`. Returns failure and emits + /// an error if a conversion is not viable. On failure, `noteAttachFn` is + /// invoked to attach notes to the emitted error diagnostic. On success, + /// `expr` is updated to the expression used to convert to `type`. + LogicalResult convertExpressionTo( + ast::Expr *&expr, ast::Type type, + function_ref noteAttachFn = {}); + + //===--------------------------------------------------------------------===// + // Directives + + LogicalResult parseDirective(SmallVector &decls); + LogicalResult parseInclude(SmallVector &decls); + + //===--------------------------------------------------------------------===// + // Decls + + FailureOr parseTopLevelDecl(); + FailureOr parsePatternDecl(); + + /// Check to see if a decl has already been defined with the given name. If + /// one has, emit an error and return failure. Returns success otherwise. + LogicalResult checkDefineNamedDecl(const ast::Name &name); + + /// Try to define a variable decl with the given components, returns the + /// variable on success.
+ FailureOr + defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type, + ast::Expr *initExpr, + ArrayRef constraints); + FailureOr + defineVariableDecl(StringRef name, llvm::SMRange nameLoc, ast::Type type, + ArrayRef constraints); + + /// Parse the constraint reference list for a variable decl. + LogicalResult parseVariableDeclConstraintList( + SmallVectorImpl &constraints); + + /// Parse the expression used within a type constraint, e.g. Attr. + FailureOr parseTypeConstraintExpr(); + + /// Try to parse a single reference to a constraint. `typeConstraint` is the + /// location of a previously parsed type constraint for the entity that will + /// be constrained by the parsed constraint. `existingConstraints` are any + /// existing constraints that have already been parsed for the same entity + /// that will be constrained by this constraint. + FailureOr + parseConstraint(Optional &typeConstraint, + ArrayRef existingConstraints); + + //===--------------------------------------------------------------------===// + // Exprs + + FailureOr parseExpr(); + + /// Identifier expressions.
+ FailureOr parseDeclRefExpr(StringRef name, llvm::SMRange loc); + FailureOr parseIdentifierExpr(); + FailureOr parseMemberAccessExpr(ast::Expr *parentExpr); + FailureOr parseOperationName(); + FailureOr parseWrappedOperationName(); + FailureOr parseUnderscoreExpr(); + + //===--------------------------------------------------------------------===// + // Stmts + + FailureOr parseStmt(bool expectTerminalSemicolon = true); + FailureOr parseCompoundStmt(); + FailureOr parseEraseStmt(); + FailureOr parseLetStmt(); + + //===--------------------------------------------------------------------===// + // Creation+Analysis + //===--------------------------------------------------------------------===// + + //===--------------------------------------------------------------------===// + // Decls + + /// Try to create a pattern decl with the given components, returning the + /// Pattern on success. + FailureOr createPatternDecl(llvm::SMRange loc, + Optional name, + Optional benefit, + bool hasBoundedRecursion, + ast::CompoundStmt *body); + + /// Try to create a variable decl with the given components, returning the + /// Variable on success. + FailureOr + createVariableDecl(StringRef name, llvm::SMRange loc, ast::Expr *initializer, + ArrayRef constraints); + + /// Validate the constraints used to constrain a variable decl. + /// `inferredType` is the type of the variable inferred by the constraints + /// within the list, and is updated to the most refined type as determined by + /// the constraints. Returns success if the constraint list is valid, failure + /// otherwise. + LogicalResult + validateVariableConstraints(ArrayRef constraints, + ast::Type &inferredType); + /// Validate a single reference to a constraint. `inferredType` contains the + /// currently inferred variable type and is refined within the type defined + /// by the constraint. Returns success if the constraint is valid, failure + /// otherwise.
+ LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, + ast::Type &inferredType); + LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); + LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); + + //===--------------------------------------------------------------------===// + // Exprs + + FailureOr createDeclRefExpr(llvm::SMRange loc, + ast::Decl *decl); + FailureOr + createInlineVariableExpr(ast::Type type, StringRef name, llvm::SMRange loc, + ArrayRef constraints); + FailureOr + createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, + llvm::SMRange loc); + + /// Validate the member access `name` into the given parent expression. On + /// success, this also returns the type of the member accessed. + FailureOr validateMemberAccess(ast::Expr *parentExpr, + StringRef name, llvm::SMRange loc); + + //===--------------------------------------------------------------------===// + // Stmts + + FailureOr createEraseStmt(llvm::SMRange loc, + ast::Expr *rootOp); + + //===--------------------------------------------------------------------===// + // Lexer Utilities + //===--------------------------------------------------------------------===// + + /// If the current token has the specified kind, consume it and return true. + /// If not, return false. + bool consumeIf(Token::Kind kind) { + if (curToken.isNot(kind)) + return false; + consumeToken(kind); + return true; + } + + /// Advance the current lexer onto the next token. + void consumeToken() { + assert(curToken.isNot(Token::eof, Token::error) && + "shouldn't advance past EOF or errors"); + curToken = lexer.lexToken(); + } + + /// Advance the current lexer onto the next token, asserting what the expected + /// current token is. This is preferred to the above method because it leads + /// to more self-documenting code with better checking.
+ void consumeToken(Token::Kind kind) { + assert(curToken.is(kind) && "consumed an unexpected token"); + consumeToken(); + } + + /// Consume the specified token if present and return success. On failure, + /// output a diagnostic and return failure. + LogicalResult parseToken(Token::Kind kind, const Twine &msg) { + if (curToken.getKind() != kind) + return emitError(curToken.getLoc(), msg); + consumeToken(); + return success(); + } + LogicalResult emitError(llvm::SMRange loc, const Twine &msg) { + lexer.emitError(loc, msg); + return failure(); + } + LogicalResult emitError(const Twine &msg) { + return emitError(curToken.getLoc(), msg); + } + LogicalResult emitErrorAndNote(llvm::SMRange loc, const Twine &msg, + llvm::SMRange noteLoc, const Twine &note) { + lexer.emitErrorAndNote(loc, msg, noteLoc, note); + return failure(); + } + + //===--------------------------------------------------------------------===// + // Fields + //===--------------------------------------------------------------------===// + + /// The owning AST context. + ast::Context &context; + + /// The lexer of this parser. + Lexer lexer; + + /// The current token within the lexer. + Token curToken; + + /// The most recently defined decl scope. + ast::DeclScope *curDeclScope; + llvm::SpecificBumpPtrAllocator scopeAllocator; + + /// The current context of the parser. + ParserContext parserContext = ParserContext::Global; + + /// The include directories of the parser context. + const std::vector &includeDirs; + + /// Cached types to simplify verification and expression creation. + ast::Type valueTy, valueRangeTy; + ast::Type typeTy, typeRangeTy; +}; +} // namespace + +FailureOr Parser::parseModule() { + llvm::SMLoc moduleLoc = curToken.getStartLoc(); + pushDeclScope(); + + // Parse the top-level decls of the module.
+ SmallVector decls; + if (failed(parseModuleBody(decls))) + return popDeclScope(), failure(); + + popDeclScope(); + return ast::Module::create(context, moduleLoc, decls); +} + +LogicalResult Parser::parseModuleBody(SmallVector &decls) { + while (curToken.isNot(Token::eof)) { + if (curToken.is(Token::directive)) { + if (failed(parseDirective(decls))) + return failure(); + continue; + } + + FailureOr decl = parseTopLevelDecl(); + if (failed(decl)) + return failure(); + decls.push_back(*decl); + } + return success(); +} + +LogicalResult Parser::convertExpressionTo( + ast::Expr *&expr, ast::Type type, + function_ref noteAttachFn) { + ast::Type exprType = expr->getType(); + if (exprType == type) + return success(); + + auto emitConvertError = [&]() -> ast::InFlightDiagnostic { + ast::InFlightDiagnostic diag = context.getDiagEngine().emitError( + expr->getLoc(), llvm::formatv("unable to convert expression of type " + "`{0}` to the expected type of " + "`{1}`", + exprType, type)); + if (noteAttachFn) + noteAttachFn(*diag); + return diag; + }; + + if (auto exprOpType = exprType.dyn_cast()) { + // Two operation types are compatible if they have the same name, or if the + // expected type is more general. + if (auto opType = type.dyn_cast()) { + if (opType.getName()) + return emitConvertError(); + return success(); + } + + // An operation can always convert to a ValueRange. + if (type == valueRangeTy) { + expr = ast::MemberAccessExpr::create(context, expr->getLoc(), expr, + "$results", valueRangeTy); + return success(); + } + + // Allow conversion to a single value by constraining the result range. + if (type == valueTy) { + expr = ast::MemberAccessExpr::create(context, expr->getLoc(), expr, + "$results", valueTy); + return success(); + } + return emitConvertError(); + } + + // FIXME: Decide how to allow/support converting a single result to multiple, + // and multiple to a single result.
For now, we allow both Single->Range and + // Range->Single, but this isn't something really supported in the PDL + // dialect. We should figure out some way to support both. + if ((exprType == valueTy || exprType == valueRangeTy) && + (type == valueTy || type == valueRangeTy)) + return success(); + if ((exprType == typeTy || exprType == typeRangeTy) && + (type == typeTy || type == typeRangeTy)) + return success(); + + return emitConvertError(); +} + +//===----------------------------------------------------------------------===// +// Directives + +LogicalResult Parser::parseDirective(SmallVector &decls) { + StringRef directive = curToken.getSpelling(); + if (directive == "#include") + return parseInclude(decls); + + return emitError("unknown directive `" + directive + "`"); +} + +LogicalResult Parser::parseInclude(SmallVector &decls) { + llvm::SMRange loc = curToken.getLoc(); + consumeToken(Token::directive); + + // Parse the file being included. + if (!curToken.isString()) + return emitError(loc, + "expected string file name after `include` directive"); + llvm::SMRange fileLoc = curToken.getLoc(); + std::string filenameStr = curToken.getStringValue(); + StringRef filename = filenameStr; + consumeToken(); + + // Check the type of include. If ending with `.pdll`, this is another pdl file + // to be parsed along with the current module. + if (filename.endswith(".pdll")) { + if (failed(lexer.pushInclude(filename))) + return emitError(fileLoc, + "unable to open include file `" + filename + "`"); + + // If we added the include successfully, parse it into the current module. + // Make sure to save the current token so that we can restore it when we + // finish parsing the nested file.
+ Token oldToken = curToken; + curToken = lexer.lexToken(); + LogicalResult result = parseModuleBody(decls); + curToken = oldToken; + return result; + } + + return emitError(fileLoc, "expected include filename to end with `.pdll`"); +} + +//===----------------------------------------------------------------------===// +// Decls + +FailureOr Parser::parseTopLevelDecl() { + FailureOr decl; + switch (curToken.getKind()) { + case Token::kw_Pattern: + decl = parsePatternDecl(); + break; + default: + return emitError("expected top-level declaration, such as a `Pattern`"); + } + if (failed(decl)) + return failure(); + + // If the decl has a name, add it to the current scope. + if (const Optional &name = (*decl)->getName()) { + if (failed(checkDefineNamedDecl(*name))) + return failure(); + curDeclScope->add(*decl); + } + return decl; +} + +FailureOr Parser::parsePatternDecl() { + llvm::SMRange loc = curToken.getLoc(); + consumeToken(Token::kw_Pattern); + llvm::SaveAndRestore saveCtx(parserContext, + ParserContext::PatternMatch); + + // Check for an optional identifier for the pattern name. + Optional name; + if (curToken.is(Token::identifier)) { + name.emplace(curToken.getSpelling(), curToken.getLoc()); + consumeToken(Token::identifier); + } + + // TODO: Parse any pattern metadata. + Optional benefit; + bool hasBoundedRecursion = false; + + // Parse the pattern body. + ast::CompoundStmt *body; + + if (curToken.isNot(Token::l_brace)) + return emitError("expected `{` to start pattern body"); + FailureOr bodyResult = parseCompoundStmt(); + if (failed(bodyResult)) + return failure(); + body = *bodyResult; + + // Verify the body of the pattern. + auto bodyIt = body->begin(), bodyE = body->end(); + for (; bodyIt != bodyE; ++bodyIt) { + // Break when we've found the rewrite statement.
+ if (isa(*bodyIt)) + break; + } + if (bodyIt == bodyE) { + return emitError(loc, + "expected Pattern body to terminate with an operation " + "rewrite statement, such as `erase`"); + } + if (std::next(bodyIt) != bodyE) { + return emitError((*std::next(bodyIt))->getLoc(), + "Pattern body was terminated by an operation " + "rewrite statement, but found trailing statements"); + } + + return createPatternDecl(loc, name, benefit, hasBoundedRecursion, body); +} + +FailureOr Parser::parseTypeConstraintExpr() { + consumeToken(Token::less); + + FailureOr typeExpr = parseExpr(); + if (failed(typeExpr) || + failed(parseToken(Token::greater, + "expected '>' after variable type constraint"))) + return failure(); + return typeExpr; +} + +LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) { + assert(curDeclScope && "defining decl outside of a decl scope"); + if (ast::Decl *lastDecl = curDeclScope->lookup(name.name)) { + return emitErrorAndNote( + name.location, "`" + name.name + "` has already been defined", + lastDecl->getName()->location, "see previous definition here"); + } + return success(); +} + +FailureOr +Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc, + ast::Type type, ast::Expr *initExpr, + ArrayRef constraints) { + assert(curDeclScope && "defining variable outside of decl scope"); + if (name.empty() || name == "_") { + return ast::VariableDecl::create(context, ast::Name(name, nameLoc), type, + initExpr, constraints); + } + ast::Name nameDecl(name, nameLoc); + if (failed(checkDefineNamedDecl(nameDecl))) + return failure(); + + auto *varDecl = + ast::VariableDecl::create(context, nameDecl, type, initExpr, constraints); + curDeclScope->add(varDecl); + return varDecl; +} + +FailureOr +Parser::defineVariableDecl(StringRef name, llvm::SMRange nameLoc, + ast::Type type, + ArrayRef constraints) { + return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, + constraints); +} + +LogicalResult Parser::parseVariableDeclConstraintList(
SmallVectorImpl &constraints) { + Optional typeConstraint; + auto parseSingleConstraint = [&] { + FailureOr constraint = + parseConstraint(typeConstraint, constraints); + if (failed(constraint)) + return failure(); + constraints.push_back(*constraint); + return success(); + }; + + // Check to see if this is a single constraint, or a list. + if (!consumeIf(Token::l_square)) + return parseSingleConstraint(); + + do { + if (failed(parseSingleConstraint())) + return failure(); + } while (consumeIf(Token::comma)); + return parseToken(Token::r_square, "expected `]` after constraint list"); +} + +FailureOr +Parser::parseConstraint(Optional &typeConstraint, + ArrayRef existingConstraints) { + auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { + if (typeConstraint) + return emitErrorAndNote( + curToken.getLoc(), + "the type of this variable has already been constrained", + *typeConstraint, "see previous constraint location here"); + FailureOr constraintExpr = parseTypeConstraintExpr(); + if (failed(constraintExpr)) + return failure(); + typeExpr = *constraintExpr; + typeConstraint = typeExpr->getLoc(); + return success(); + }; + + llvm::SMRange loc = curToken.getLoc(); + switch (curToken.getKind()) { + case Token::kw_Attr: { + consumeToken(Token::kw_Attr); + + // Check for a type constraint. + ast::Expr *typeExpr = nullptr; + if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) + return failure(); + return ast::ConstraintRef( + ast::AttrConstraintDecl::create(context, loc, typeExpr), loc); + } + case Token::kw_Op: { + consumeToken(Token::kw_Op); + + // Parse an optional operation name. 
+ FailureOr opName = parseWrappedOperationName(); + if (failed(opName)) + return failure(); + + return ast::ConstraintRef( + ast::OpConstraintDecl::create(context, loc, *opName), loc); + } + case Token::kw_Type: + consumeToken(Token::kw_Type); + return ast::ConstraintRef(ast::TypeConstraintDecl::create(context, loc), + loc); + case Token::kw_TypeRange: + consumeToken(Token::kw_TypeRange); + return ast::ConstraintRef( + ast::TypeRangeConstraintDecl::create(context, loc), loc); + case Token::kw_Value: { + consumeToken(Token::kw_Value); + + // Check for a type constraint. + ast::Expr *typeExpr = nullptr; + if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) + return failure(); + + return ast::ConstraintRef( + ast::ValueConstraintDecl::create(context, loc, typeExpr), loc); + } + case Token::kw_ValueRange: { + consumeToken(Token::kw_ValueRange); + + // Check for a type constraint. + ast::Expr *typeExpr = nullptr; + if (curToken.is(Token::less) && failed(parseTypeConstraint(typeExpr))) + return failure(); + + return ast::ConstraintRef( + ast::ValueRangeConstraintDecl::create(context, loc, typeExpr), loc); + } + case Token::identifier: { + StringRef constraintName = curToken.getSpelling(); + consumeToken(Token::identifier); + + // Lookup the referenced constraint. + ast::Decl *cstDecl = curDeclScope->lookup(constraintName); + if (!cstDecl) { + return emitError(loc, "unknown reference to constraint `" + + constraintName + "`"); + } + + // Handle a reference to a proper constraint. 
+ if (auto *cst = dyn_cast(cstDecl)) + return ast::ConstraintRef(cst, loc); + + return emitErrorAndNote( + loc, "invalid reference to non-constraint", cstDecl->getLoc(), + "see the definition of `" + constraintName + "` here"); + } + default: + break; + } + return emitError(loc, "expected identifier constraint"); +} + +//===----------------------------------------------------------------------===// +// Exprs + +FailureOr Parser::parseExpr() { + if (curToken.is(Token::underscore)) + return parseUnderscoreExpr(); + + // Parse the LHS expression. + FailureOr lhsExpr; + switch (curToken.getKind()) { + case Token::identifier: + lhsExpr = parseIdentifierExpr(); + break; + default: + return emitError("expected expression"); + } + if (failed(lhsExpr)) + return failure(); + + // Check for an operator expression. + while (true) { + switch (curToken.getKind()) { + case Token::dot: + lhsExpr = parseMemberAccessExpr(*lhsExpr); + break; + default: + return lhsExpr; + } + if (failed(lhsExpr)) + return failure(); + } +} + +FailureOr Parser::parseDeclRefExpr(StringRef name, + llvm::SMRange loc) { + ast::Decl *decl = curDeclScope->lookup(name); + if (!decl) + return emitError(loc, "undefined reference to `" + name + "`"); + + return createDeclRefExpr(loc, decl); +} + +FailureOr Parser::parseIdentifierExpr() { + StringRef name = curToken.getSpelling(); + llvm::SMRange nameLoc = curToken.getLoc(); + consumeToken(); + + // Check to see if this is a decl ref expression that defines a variable + // inline. 
+ if (consumeIf(Token::colon)) { + SmallVector constraints; + if (failed(parseVariableDeclConstraintList(constraints))) + return failure(); + ast::Type type; + if (failed(validateVariableConstraints(constraints, type))) + return failure(); + return createInlineVariableExpr(type, name, nameLoc, constraints); + } + + return parseDeclRefExpr(name, nameLoc); +} + +FailureOr Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { + llvm::SMRange loc = curToken.getLoc(); + consumeToken(Token::dot); + + // Parse the member name. + Token memberNameTok = curToken; + if (memberNameTok.isNot(Token::identifier, Token::integer) && + !memberNameTok.isKeyword()) + return emitError(loc, "expected identifier or numeric member name"); + StringRef memberName = memberNameTok.getSpelling(); + consumeToken(); + + return createMemberAccessExpr(parentExpr, memberName, loc); +} + +FailureOr Parser::parseOperationName() { + llvm::SMRange loc = curToken.getLoc(); + + // Handle the case of an no operation name. + if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) + return ast::OpNameDecl::create(context, llvm::SMRange()); + + StringRef name = curToken.getSpelling(); + consumeToken(); + + // Otherwise, this is a literal operation name. 
+ if (failed(parseToken(Token::dot, "expected `.` after dialect namespace"))) + return failure(); + + if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) + return emitError("expected operation name after dialect namespace"); + + name = StringRef(name.data(), name.size() + 1); + do { + name = StringRef(name.data(), name.size() + curToken.getSpelling().size()); + loc.End = curToken.getEndLoc(); + consumeToken(); + } while (curToken.isAny(Token::identifier, Token::dot) || + curToken.isKeyword()); + return ast::OpNameDecl::create(context, ast::Name(name, loc)); +} + +FailureOr Parser::parseWrappedOperationName() { + if (!consumeIf(Token::less)) + return ast::OpNameDecl::create(context, llvm::SMRange()); + + FailureOr opNameDecl = parseOperationName(); + if (failed(opNameDecl)) + return failure(); + + if (failed(parseToken(Token::greater, "expected `>` after operation name"))) + return failure(); + return opNameDecl; +} + +FailureOr Parser::parseUnderscoreExpr() { + StringRef name = curToken.getSpelling(); + llvm::SMRange nameLoc = curToken.getLoc(); + consumeToken(Token::underscore); + + // Underscore expressions require a constraint list. + if (failed(parseToken(Token::colon, "expected `:` after `_` variable"))) + return failure(); + + // Parse the constraints for the expression. 
+ SmallVector constraints; + if (failed(parseVariableDeclConstraintList(constraints))) + return failure(); + + ast::Type type; + if (failed(validateVariableConstraints(constraints, type))) + return failure(); + return createInlineVariableExpr(type, name, nameLoc, constraints); +} + +//===----------------------------------------------------------------------===// +// Stmts + +FailureOr Parser::parseStmt(bool expectTerminalSemicolon) { + FailureOr stmt; + switch (curToken.getKind()) { + case Token::kw_erase: + stmt = parseEraseStmt(); + break; + case Token::kw_let: + stmt = parseLetStmt(); + break; + default: + stmt = parseExpr(); + break; + } + if (failed(stmt) || + (expectTerminalSemicolon && + failed(parseToken(Token::semicolon, "expected `;` after statement")))) + return failure(); + return stmt; +} + +FailureOr Parser::parseCompoundStmt() { + llvm::SMLoc startLoc = curToken.getStartLoc(); + consumeToken(Token::l_brace); + + // Push a new block scope and parse any nested statements. + pushDeclScope(); + SmallVector statements; + while (curToken.isNot(Token::r_brace)) { + FailureOr statement = parseStmt(); + if (failed(statement)) + return popDeclScope(), failure(); + statements.push_back(*statement); + } + popDeclScope(); + + // Consume the end brace. + llvm::SMRange location(startLoc, curToken.getEndLoc()); + consumeToken(Token::r_brace); + + return ast::CompoundStmt::create(context, location, statements); +} + +FailureOr Parser::parseEraseStmt() { + llvm::SMRange loc = curToken.getLoc(); + consumeToken(Token::kw_erase); + + // Parse the root operation expression. + FailureOr rootOp = parseExpr(); + if (failed(rootOp)) + return failure(); + + return createEraseStmt(loc, *rootOp); +} + +FailureOr Parser::parseLetStmt() { + llvm::SMRange loc = curToken.getLoc(); + consumeToken(Token::kw_let); + + // Parse the name of the new variable. 
+ llvm::SMRange varLoc = curToken.getLoc(); + if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) { + // `_` is a reserved variable name. + if (curToken.is(Token::underscore)) { + return emitError(varLoc, + "`_` may only be used to define \"inline\" variables"); + } + return emitError(varLoc, + "expected identifier after `let` to name a new variable"); + } + StringRef varName = curToken.getSpelling(); + consumeToken(); + + // Parse the optional set of constraints. + SmallVector constraints; + if (consumeIf(Token::colon) && + failed(parseVariableDeclConstraintList(constraints))) + return failure(); + + // Parse the optional initializer expression. + ast::Expr *initializer = nullptr; + if (consumeIf(Token::equal)) { + FailureOr initOrFailure = parseExpr(); + if (failed(initOrFailure)) + return failure(); + initializer = *initOrFailure; + + // Check that the constraints are compatible with having an initializer, + // e.g. type constraints cannot be used with initializers. 
+ for (ast::ConstraintRef constraint : constraints) { + LogicalResult result = + TypeSwitch(constraint.constraint) + .Case([&](const auto *cst) { + if (auto *typeConstraintExpr = cst->getTypeExpr()) { + return emitError( + constraint.referenceLoc, + "type constraints are not permitted on variables with " + "initializers"); + } + return success(); + }) + .Default(success()); + if (failed(result)) + return failure(); + } + } + + FailureOr varDecl = + createVariableDecl(varName, varLoc, initializer, constraints); + if (failed(varDecl)) + return failure(); + return ast::LetStmt::create(context, loc, *varDecl); +} + +//===----------------------------------------------------------------------===// +// Creation+Analysis +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Decls + +FailureOr +Parser::createPatternDecl(llvm::SMRange loc, Optional name, + Optional benefit, bool hasBoundedRecursion, + ast::CompoundStmt *body) { + return ast::PatternDecl::create(context, loc, name, benefit, + hasBoundedRecursion, body); +} + +FailureOr +Parser::createVariableDecl(StringRef name, llvm::SMRange loc, + ast::Expr *initializer, + ArrayRef constraints) { + // The type of the variable, which is expected to be inferred by either a + // constraint or an initializer expression. + ast::Type type; + if (failed(validateVariableConstraints(constraints, type))) + return failure(); + + if (initializer) { + // Update the variable type based on the initializer, or try to convert the + // initializer to the existing type. + if (!type) + type = initializer->getType(); + else if (ast::Type mergedType = type.refineWith(initializer->getType())) + type = mergedType; + else if (failed(convertExpressionTo(initializer, type))) + return failure(); + + // Otherwise, if there is no initializer check that the type has already + // been resolved from the constraint list. 
+ } else if (!type) { + return emitErrorAndNote( + loc, "unable to infer type for variable `" + name + "`", loc, + "the type of a variable must be inferable from the constraint " + "list or the initializer"); + } + + // Try to define a variable with the given name. + FailureOr varDecl = + defineVariableDecl(name, loc, type, initializer, constraints); + if (failed(varDecl)) + return failure(); + + return *varDecl; +} + +LogicalResult +Parser::validateVariableConstraints(ArrayRef constraints, + ast::Type &inferredType) { + for (const ast::ConstraintRef &ref : constraints) + if (failed(validateVariableConstraint(ref, inferredType))) + return failure(); + return success(); +} + +LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, + ast::Type &inferredType) { + ast::Type constraintType; + if (const auto *cst = dyn_cast(ref.constraint)) { + if (const ast::Expr *typeExpr = cst->getTypeExpr()) { + if (failed(validateTypeConstraintExpr(typeExpr))) + return failure(); + } + constraintType = ast::AttributeType::get(context); + } else if (const auto *cst = + dyn_cast(ref.constraint)) { + constraintType = ast::OperationType::get(context, cst->getName()); + } else if (isa(ref.constraint)) { + constraintType = typeTy; + } else if (isa(ref.constraint)) { + constraintType = typeRangeTy; + } else if (const auto *cst = + dyn_cast(ref.constraint)) { + if (const ast::Expr *typeExpr = cst->getTypeExpr()) { + if (failed(validateTypeConstraintExpr(typeExpr))) + return failure(); + } + constraintType = valueTy; + } else if (const auto *cst = + dyn_cast(ref.constraint)) { + if (const ast::Expr *typeExpr = cst->getTypeExpr()) { + if (failed(validateTypeRangeConstraintExpr(typeExpr))) + return failure(); + } + constraintType = valueRangeTy; + } else { + llvm_unreachable("unknown constraint type"); + } + + // Check that the constraint type is compatible with the current inferred + // type. 
+ if (!inferredType) { + inferredType = constraintType; + } else if (ast::Type mergedTy = inferredType.refineWith(constraintType)) { + inferredType = mergedTy; + } else { + return emitError(ref.referenceLoc, + llvm::formatv("constraint type `{0}` is incompatible " + "with the previously inferred type `{1}`", + constraintType, inferredType)); + } + return success(); +} + +LogicalResult Parser::validateTypeConstraintExpr(const ast::Expr *typeExpr) { + ast::Type typeExprType = typeExpr->getType(); + if (typeExprType != typeTy) { + return emitError(typeExpr->getLoc(), + "expected expression of `Type` in type constraint"); + } + return success(); +} + +LogicalResult +Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { + ast::Type typeExprType = typeExpr->getType(); + if (typeExprType != typeRangeTy) { + return emitError(typeExpr->getLoc(), + "expected expression of `TypeRange` in type constraint"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Exprs + +FailureOr Parser::createDeclRefExpr(llvm::SMRange loc, + ast::Decl *decl) { + // Check the type of decl being referenced. + ast::Type declType; + if (auto *varDecl = dyn_cast(decl)) + declType = varDecl->getType(); + else + return emitError(loc, + "invalid reference to `" + decl->getName()->name + "`"); + + return ast::DeclRefExpr::create(context, loc, decl, declType); +} + +FailureOr +Parser::createInlineVariableExpr(ast::Type type, StringRef name, + llvm::SMRange loc, + ArrayRef constraints) { + FailureOr decl = + defineVariableDecl(name, loc, type, constraints); + if (failed(decl)) + return failure(); + return ast::DeclRefExpr::create(context, loc, *decl, type); +} + +FailureOr +Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, + llvm::SMRange loc) { + // Validate the member name for the given parent expression. 
+ FailureOr memberType = validateMemberAccess(parentExpr, name, loc); + if (failed(memberType)) + return failure(); + + return ast::MemberAccessExpr::create(context, loc, parentExpr, name, + *memberType); +} + +FailureOr Parser::validateMemberAccess(ast::Expr *parentExpr, + StringRef name, + llvm::SMRange loc) { + ast::Type parentType = parentExpr->getType(); + if (ast::OperationType opType = parentType.dyn_cast()) { + // $results is a special member access representing all of the results. + // TODO: Should we have special AST expressions for these? How does the + // user reference these in the language itself? + if (name == "$results") + return valueRangeTy; + } + return emitError( + loc, + llvm::formatv("invalid member access `{0}` on expression of type `{1}`", + name, parentType)); +} + +//===----------------------------------------------------------------------===// +// Stmts + +FailureOr Parser::createEraseStmt(llvm::SMRange loc, + ast::Expr *rootOp) { + // Check that root is an Operation. 
+ ast::Type rootType = rootOp->getType(); + if (!rootType.isa()) + return emitError(rootOp->getLoc(), "expected `Op` expression"); + + return ast::EraseStmt::create(context, loc, rootOp); +} + +//===----------------------------------------------------------------------===// +// Parser +//===----------------------------------------------------------------------===// + +FailureOr +mlir::pdll::parsePDLAST(ast::Context &context, llvm::SourceMgr &sourceMgr, + const std::vector &includeDirs) { + Parser parser(context, sourceMgr, includeDirs); + return parser.parseModule(); +} diff --git a/mlir/test/mlir-pdll/Parser/directive-failure.pdll b/mlir/test/mlir-pdll/Parser/directive-failure.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/directive-failure.pdll @@ -0,0 +1,23 @@ +// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s + +// CHECK: unknown directive `#foo` +#foo + +// ----- + +//===----------------------------------------------------------------------===// +// Include +//===----------------------------------------------------------------------===// + +// CHECK: expected string file name after `include` directive +#include <> + +// ----- + +// CHECK: unable to open include file `unknown_file.pdll` +#include "unknown_file.pdll" + +// ----- + +// CHECK: expected include filename to end with `.pdll` +#include "unknown_file.foo" diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -0,0 +1,62 @@ +// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s + +//===----------------------------------------------------------------------===// +// Reference Expr +//===----------------------------------------------------------------------===// + +Pattern { + // CHECK: expected identifier constraint + let foo = Foo: ; +} + +// ----- + +Pattern { + // CHECK: undefined reference to `bar` + let foo = bar; +} 
+ +// ----- + +Pattern FooPattern { + erase _: Op; +} + +Pattern { + // CHECK: invalid reference to `FooPattern` + let foo = FooPattern; +} + +// ----- + +Pattern { + // CHECK: expected `:` after `_` variable + let foo = _; +} + +// ----- + +Pattern { + // CHECK: expected identifier constraint + let foo = _: ; +} + +// ----- + +//===----------------------------------------------------------------------===// +// Member Access Expr +//===----------------------------------------------------------------------===// + +Pattern { + // CHECK: expected identifier or numeric member name + let root: Op; + erase root.<>; +} + +// ----- + +Pattern { + // CHECK: invalid member access `unknown_result` on expression of type `Op` + let root: Op; + erase root.unknown_result; +} diff --git a/mlir/test/mlir-pdll/Parser/include.pdll b/mlir/test/mlir-pdll/Parser/include.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/include.pdll @@ -0,0 +1,15 @@ +// RUN: mlir-pdll %s -I %S | FileCheck %s + +Pattern BeforeIncludedPattern { + erase _: Op; +} + +#include "include/included.pdll" + +Pattern AfterIncludedPattern { + erase _: Op; +} + +// CHECK: PatternDecl {{.*}} Name +// CHECK: PatternDecl {{.*}} Name +// CHECK: PatternDecl {{.*}} Name diff --git a/mlir/test/mlir-pdll/Parser/include/included.pdll b/mlir/test/mlir-pdll/Parser/include/included.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/include/included.pdll @@ -0,0 +1,5 @@ +// This file is included by 'include.pdll' as part of testing include files. 
+ +Pattern IncludedPattern { + erase _: Op; +} diff --git a/mlir/test/mlir-pdll/Parser/lit.local.cfg b/mlir/test/mlir-pdll/Parser/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/lit.local.cfg @@ -0,0 +1,2 @@ +config.suffixes = ['.pdll'] +config.excludes = ['include'] diff --git a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll @@ -0,0 +1,26 @@ +// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s + +// CHECK: expected `{` to start pattern body +Pattern } + +// ----- + +// CHECK: :6:9: error: `Foo` has already been defined +// CHECK: :5:9: note: see previous definition here +Pattern Foo { erase root: Op; } +Pattern Foo { erase root: Op; } + +// ----- + +// CHECK: expected Pattern body to terminate with an operation rewrite statement +Pattern { + let value: Value; +} + +// ----- + +// CHECK: Pattern body was terminated by an operation rewrite statement, but found trailing statements +Pattern { + erase root: Op; + let value: Value; +} diff --git a/mlir/test/mlir-pdll/Parser/pattern.pdll b/mlir/test/mlir-pdll/Parser/pattern.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/pattern.pdll @@ -0,0 +1,17 @@ +// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s + +// CHECK: Module +// CHECK: `-PatternDecl +// CHECK: `-CompoundStmt +// CHECK: `-EraseStmt +Pattern { + erase _: Op; +} + +// ----- + +// CHECK: Module +// CHECK: `-PatternDecl {{.*}} Name +Pattern NamedPattern { + erase _: Op; +} diff --git a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll @@ -0,0 +1,222 @@ +// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s + +// CHECK: expected top-level declaration, such as a `Pattern` +10 + +// ----- + +Pattern { + // CHECK: 
expected `;` after statement + erase _: Op +} + +// ----- + +//===----------------------------------------------------------------------===// +// `erase` +//===----------------------------------------------------------------------===// + +Pattern { + // CHECK: expected expression + erase; +} + +// ----- + +Pattern { + // CHECK: expected `Op` expression + erase _: Attr; +} + +// ----- + +//===----------------------------------------------------------------------===// +// `let` +//===----------------------------------------------------------------------===// + +Pattern { + // CHECK: expected identifier after `let` to name a new variable + let 5; +} + +// ----- + +Pattern { + // CHECK: `_` may only be used to define "inline" variables + let _; +} + +// ----- + +Pattern { + // CHECK: expected expression + let foo: Attr<>; +} + +// ----- + +Pattern { + // CHECK: expected expression of `Type` in type constraint + let foo: Attr<_: Attr>; +} + +// ----- + +Pattern { + // CHECK: expected '>' after variable type constraint + let foo: Attr<_: Type{}; +} + +// ----- + +Pattern { + // CHECK: the type of this variable has already been constrained + let foo: [Attr<_: Type>, Attr<_: Type]; +} + +// ----- + +Pattern { + // CHECK: expected `.` after dialect namespace + let foo: Op; +} + +// ----- + +Pattern { + // CHECK: expected operation name after dialect namespace + let foo: Op; +} + +// ----- + +Pattern { + // CHECK: expected `>` after operation name + let foo: Op; +} + +// ----- + +Pattern { + // CHECK: expected expression of `Type` in type constraint + let foo: Value<_: Attr>; +} + +// ----- + +Pattern { + // CHECK: expected '>' after variable type constraint + let foo: Value<_: Type{}; +} + +// ----- + +Pattern { + // CHECK: the type of this variable has already been constrained + let foo: [Value<_: Type>, Value<_: Type]; +} + +// ----- + +Pattern { + // CHECK: expected expression + let foo: ValueRange<10>; +} + +// ----- + +Pattern { + // CHECK: expected expression of 
`TypeRange` in type constraint + let foo: ValueRange<_: Type>; +} + +// ----- + +Pattern { + // CHECK: expected '>' after variable type constraint + let foo: ValueRange<_: Type{}; +} + +// ----- + +Pattern { + // CHECK: the type of this variable has already been constrained + let foo: [ValueRange<_: Type>, ValueRange<_: Type]; +} + +// ----- + +Pattern { + // CHECK: unknown reference to constraint `UnknownConstraint` + let foo: UnknownConstraint; +} + +// ----- + +Pattern Foo { + erase root: Op; +} + +Pattern { + // CHECK: invalid reference to non-constraint + let foo: Foo; +} + +// ----- + +Pattern { + // CHECK: constraint type `Attr` is incompatible with the previously inferred type `Value` + let foo: [Value, Attr]; +} + +// ----- + +Pattern { + // CHECK: expected `]` after constraint list + let foo: [Attr[]; +} + +// ----- + +Pattern { + // CHECK: expected expression + let foo: Attr = ; +} + +// ----- + +Pattern { + // CHECK: type constraints are not permitted on variables with initializers + let foo: ValueRange<_: Type> = _: Op; +} + +// ----- + +Pattern { + // CHECK: unable to infer type for variable `foo` + // CHECK: note: the type of a variable must be inferable from the constraint list or the initializer + let foo; +} + +// ----- + +Pattern { + // CHECK: unable to convert expression of type `Attr` to the expected type of `Value` + let foo: Value = _: Attr; +} + +// ----- + +Pattern { + // CHECK: :7:7: error: `foo` has already been defined + // CHECK: :6:7: note: see previous definition here + let foo: Attr; + let foo: Attr; +} diff --git a/mlir/test/mlir-pdll/Parser/stmt.pdll b/mlir/test/mlir-pdll/Parser/stmt.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/stmt.pdll @@ -0,0 +1,155 @@ +// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s + +//===----------------------------------------------------------------------===// +// CompoundStmt +//===----------------------------------------------------------------------===// + +// 
CHECK: Module +// CHECK: CompoundStmt +// CHECK: |-LetStmt +// CHECK: `-EraseStmt +Pattern { + let root: Op; + erase root; +} + +// ----- + +//===----------------------------------------------------------------------===// +// EraseStmt +//===----------------------------------------------------------------------===// + +// CHECK: Module +// CHECK: EraseStmt +// CHECK: `-DeclRefExpr {{.*}} Type +Pattern { + erase _: Op; +} + +// ----- + +//===----------------------------------------------------------------------===// +// LetStmt +//===----------------------------------------------------------------------===// + +// CHECK: Module +// CHECK: LetStmt +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-AttrConstraintDecl +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-OpConstraintDecl +// CHECK: `-OpNameDecl +Pattern { + let attrVar: Attr; + let var: Op; + erase var; +} + +// ----- + +// Check for proper refinement between constraint types. + +// CHECK: Module +// CHECK: LetStmt +// CHECK: `-VariableDecl {{.*}} Name Type> +// CHECK: `Constraints` +// CHECK: `-OpConstraintDecl +// CHECK: `-OpNameDecl +// CHECK: `-OpConstraintDecl +// CHECK: `-OpNameDecl {{.*}} Name +Pattern { + let var: [Op, Op]; + erase var; +} + +// ----- + +// Check for proper conversion between initializer and constraint type. + +// CHECK: Module +// CHECK: LetStmt +// CHECK: `-VariableDecl {{.*}} Name Type> +// CHECK: `-DeclRefExpr {{.*}} Type> +// CHECK: `-VariableDecl {{.*}} Name +// CHECK: `Constraints` +// CHECK: `-OpConstraintDecl +// CHECK: `-OpNameDecl +Pattern { + let input: Op; + let var: Op = input; + erase var; +} + +// ----- + +// Check for proper conversion between initializer and constraint type. 
+ +// CHECK: Module +// CHECK: LetStmt +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type +// CHECK: `-DeclRefExpr {{.*}} Type> +// CHECK: `-VariableDecl {{.*}} Name +// CHECK: `Constraints` +// CHECK: `-ValueConstraintDecl +Pattern { + let input: Op; + let var: Value = input; + erase _: Op; +} + +// ----- + +// Check for proper conversion between initializer and constraint type. + +// CHECK: Module +// CHECK: LetStmt +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type +// CHECK: `-DeclRefExpr {{.*}} Type> +// CHECK: `-VariableDecl {{.*}} Name +// CHECK: `Constraints` +// CHECK: `-ValueRangeConstraintDecl +Pattern { + let input: Op; + let var: ValueRange = input; + erase _: Op; +} + +// ----- + +// Check for proper handling of type constraints. + +// CHECK: Module +// CHECK: LetStmt +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-ValueConstraintDecl +// CHECK: `-DeclRefExpr {{.*}} Type +// CHECK: `-VariableDecl {{.*}} Name<_> Type +// CHECK: `Constraints` +// CHECK: `-TypeConstraintDecl +Pattern { + let var: Value<_: Type>; + erase _: Op; +} + +// ----- + +// Check for proper handling of type constraints. + +// CHECK: Module +// CHECK: LetStmt +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-ValueRangeConstraintDecl +// CHECK: `-DeclRefExpr {{.*}} Type +// CHECK: `-VariableDecl {{.*}} Name<_> Type +// CHECK: `Constraints` +// CHECK: `-TypeRangeConstraintDecl +Pattern { + let var: ValueRange<_: TypeRange>; + erase _: Op; +} diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp @@ -0,0 +1,112 @@ +//===- mlir-pdll.cpp - MLIR PDLL frontend -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/ToolUtilities.h" +#include "mlir/Tools/PDLL/AST/Context.h" +#include "mlir/Tools/PDLL/AST/Nodes.h" +#include "mlir/Tools/PDLL/Parser/Parser.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" + +using namespace mlir; +using namespace mlir::pdll; + +//===----------------------------------------------------------------------===// +// main +//===----------------------------------------------------------------------===// + +/// The desired output type. +enum class OutputType { + AST, +}; + +static LogicalResult +processBuffer(raw_ostream &os, std::unique_ptr chunkBuffer, + OutputType outputType, std::vector &includeDirs) { + llvm::SourceMgr sourceMgr; + sourceMgr.setIncludeDirs(includeDirs); + sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), llvm::SMLoc()); + + ast::Context astContext; + FailureOr module = + parsePDLAST(astContext, sourceMgr, includeDirs); + if (failed(module)) + return failure(); + + switch (outputType) { + case OutputType::AST: + (*module)->print(os); + break; + } + + return success(); +} + +int main(int argc, char **argv) { + llvm::cl::opt inputFilename( + llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-"), + llvm::cl::value_desc("filename")); + + llvm::cl::opt outputFilename( + "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("-")); + + llvm::cl::list includeDirs( + "I", llvm::cl::desc("Directory of include files"), + llvm::cl::value_desc("directory"), llvm::cl::Prefix); + + llvm::cl::opt splitInputFile( + "split-input-file", + llvm::cl::desc("Split the input file into pieces and process each " + "chunk independently"), 
+ llvm::cl::init(false)); + llvm::cl::opt outputType( + "x", llvm::cl::init(OutputType::AST), + llvm::cl::desc("The type of output desired"), + llvm::cl::values(clEnumValN(OutputType::AST, "ast", + "generate the AST for the input file"))); + + llvm::InitLLVM y(argc, argv); + llvm::cl::ParseCommandLineOptions(argc, argv, "PDLL Frontend"); + + // Set up the input file. + std::string errorMessage; + std::unique_ptr inputFile = + openInputFile(inputFilename, &errorMessage); + if (!inputFile) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + + // Set up the output file. + std::unique_ptr outputFile = + openOutputFile(outputFilename, &errorMessage); + if (!outputFile) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + + // The split-input-file mode is a very specific mode that slices the file + // up into small pieces and checks each independently. + auto processFn = [&](std::unique_ptr chunkBuffer, + raw_ostream &os) { + return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs); + }; + if (splitInputFile) { + if (failed(splitAndProcessBuffer(std::move(inputFile), processFn, + outputFile->os()))) + return 1; + } else if (failed(processFn(std::move(inputFile), outputFile->os()))) { + return 1; + } + outputFile->keep(); + return 0; +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7830,3 +7830,53 @@ "//mlir/test:TestDialect", ], ) + +cc_library( + name = "PDLLAST", + srcs = glob( + [ + "lib/Tools/PDLL/AST/*.cpp", + "lib/Tools/PDLL/AST/*.h", + ], + ), + hdrs = glob(["include/mlir/Tools/PDLL/AST/*.h"]), + includes = ["include"], + deps = [ + "//llvm:Support", + "//llvm:TableGen", + "//mlir:Support", + ], +) + +cc_library( + name = "PDLLParser", + srcs = glob( + [ + "lib/Tools/PDLL/Parser/*.cpp", + "lib/Tools/PDLL/Parser/*.h", + ], + ), + hdrs = 
glob(["include/mlir/Tools/PDLL/Parser/*.h"]), + includes = ["include"], + deps = [ + ":PDLLAST", + ":Support", + ":TableGen", + "//llvm:Support", + "//llvm:TableGen", + ], +) + +cc_binary( + name = "mlir-pdll", + srcs = [ + "tools/mlir-pdll/mlir-pdll.cpp", + ], + deps = [ + ":PDLLAST", + ":PDLLParser", + ":Support", + "//llvm:Support", + "//llvm:config", + ], +)