diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h --- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h @@ -300,6 +300,31 @@ CompoundStmt *rewriteBody; }; +//===----------------------------------------------------------------------===// +// ReturnStmt +//===----------------------------------------------------------------------===// + +/// This statement represents a return from a "callable" like decl, e.g. a +/// Constraint or a Rewrite. +class ReturnStmt final : public Node::NodeBase { +public: + static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr); + + /// Return the result expression of this statement. + Expr *getResultExpr() { return resultExpr; } + const Expr *getResultExpr() const { return resultExpr; } + + /// Set the result expression of this statement. + void setResultExpr(Expr *expr) { resultExpr = expr; } + +private: + ReturnStmt(SMRange loc, Expr *resultExpr) + : Base(loc), resultExpr(resultExpr) {} + + /// The result expression of this statement. + Expr *resultExpr; +}; + //===----------------------------------------------------------------------===// // Expr //===----------------------------------------------------------------------===// @@ -345,6 +370,43 @@ StringRef value; }; +//===----------------------------------------------------------------------===// +// CallExpr +//===----------------------------------------------------------------------===// + +/// This expression represents a call to a decl, such as a +/// UserConstraintDecl/UserRewriteDecl. +class CallExpr final : public Node::NodeBase, + private llvm::TrailingObjects { +public: + static CallExpr *create(Context &ctx, SMRange loc, Expr *callable, + ArrayRef arguments, Type resultType); + + /// Return the callable of this call. + Expr *getCallableExpr() const { return callable; } + + /// Return the arguments of this call. 
+ MutableArrayRef getArguments() { + return {getTrailingObjects(), numArgs}; + } + ArrayRef getArguments() const { + return const_cast(this)->getArguments(); + } + +private: + CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs) + : Base(loc, type), callable(callable), numArgs(numArgs) {} + + /// The callable of this call. + Expr *callable; + + /// The number of arguments of the call. + unsigned numArgs; + + /// TrailingObject utilities. + friend llvm::TrailingObjects; +}; + //===----------------------------------------------------------------------===// // DeclRefExpr //===----------------------------------------------------------------------===// @@ -738,6 +800,114 @@ Expr *typeExpr; }; +//===----------------------------------------------------------------------===// +// UserConstraintDecl +//===----------------------------------------------------------------------===// + +/// This decl represents a user defined constraint. This is either: +/// * an imported native constraint +/// - Similar to an external function declaration. This is a native +/// constraint defined externally, and imported into PDLL via a +/// declaration. +/// * a native constraint defined in PDLL +/// - This is a native constraint, i.e. a constraint whose implementation is +/// defined in C++ (or potentially some other non-PDLL language). The +/// implementation of this constraint is specified as a string code block +/// in PDLL. +/// * a PDLL constraint +/// - This is a constraint which is defined using only PDLL constructs. +class UserConstraintDecl final + : public Node::NodeBase, + llvm::TrailingObjects { +public: + /// Create a native constraint with the given optional code block. 
+ static UserConstraintDecl *createNative(Context &ctx, const Name &name, + ArrayRef inputs, + ArrayRef results, + Optional codeBlock, + Type resultType) { + return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr, + resultType); + } + + /// Create a PDLL constraint with the given body. + static UserConstraintDecl *createPDLL(Context &ctx, const Name &name, + ArrayRef inputs, + ArrayRef results, + const CompoundStmt *body, + Type resultType) { + return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None, + body, resultType); + } + + /// Return the name of the constraint. + const Name &getName() const { return *Decl::getName(); } + + /// Return the input arguments of this constraint. + MutableArrayRef getInputs() { + return {getTrailingObjects(), numInputs}; + } + ArrayRef getInputs() const { + return const_cast(this)->getInputs(); + } + + /// Return the explicit results of the constraint declaration. May be empty, + /// even if the constraint has results (e.g. in the case of inferred results). + MutableArrayRef getResults() { + return {getTrailingObjects() + numInputs, numResults}; + } + ArrayRef getResults() const { + return const_cast(this)->getResults(); + } + + /// Return the optional code block of this constraint, if this is a native + /// constraint with a provided implementation. + Optional getCodeBlock() const { return codeBlock; } + + /// Return the body of this constraint if this constraint is a PDLL + /// constraint, otherwise returns nullptr. + const CompoundStmt *getBody() const { return constraintBody; } + + /// Return the result type of this constraint. + Type getResultType() const { return resultType; } + + /// Returns true if this constraint is external (no body or code block). + bool isExternal() const { return !constraintBody && !codeBlock; } + +private: + /// Create either a PDLL constraint or a native constraint with the given + /// components. 
+ static UserConstraintDecl * + createImpl(Context &ctx, const Name &name, ArrayRef inputs, + ArrayRef results, Optional codeBlock, + const CompoundStmt *body, Type resultType); + + UserConstraintDecl(const Name &name, unsigned numInputs, unsigned numResults, + Optional codeBlock, const CompoundStmt *body, + Type resultType) + : Base(name.getLoc(), &name), numInputs(numInputs), + numResults(numResults), codeBlock(codeBlock), constraintBody(body), + resultType(resultType) {} + + /// The number of inputs to this constraint. + unsigned numInputs; + + /// The number of explicit results to this constraint. + unsigned numResults; + + /// The optional code block of this constraint. + Optional codeBlock; + + /// The optional body of this constraint. + const CompoundStmt *constraintBody; + + /// The result type of the constraint. + Type resultType; + + /// Allow access to various internals. + friend llvm::TrailingObjects; +}; + //===----------------------------------------------------------------------===// // NamedAttributeDecl //===----------------------------------------------------------------------===// @@ -826,6 +996,149 @@ const CompoundStmt *patternBody; }; +//===----------------------------------------------------------------------===// +// UserRewriteDecl +//===----------------------------------------------------------------------===// + +/// This decl represents a user defined rewrite. This is either: +/// * an imported native rewrite +/// - Similar to an external function declaration. This is a native +/// rewrite defined externally, and imported into PDLL via a declaration. +/// * a native rewrite defined in PDLL +/// - This is a native rewrite, i.e. a rewrite whose implementation is +/// defined in C++ (or potentially some other non-PDLL language). The +/// implementation of this rewrite is specified as a string code block +/// in PDLL. +/// * a PDLL rewrite +/// - This is a rewrite which is defined using only PDLL constructs. 
+class UserRewriteDecl final + : public Node::NodeBase, + llvm::TrailingObjects { +public: + /// Create a native rewrite with the given optional code block. + static UserRewriteDecl *createNative(Context &ctx, const Name &name, + ArrayRef inputs, + ArrayRef results, + Optional codeBlock, + Type resultType) { + return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr, + resultType); + } + + /// Create a PDLL rewrite with the given body. + static UserRewriteDecl *createPDLL(Context &ctx, const Name &name, + ArrayRef inputs, + ArrayRef results, + const CompoundStmt *body, + Type resultType) { + return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None, + body, resultType); + } + + /// Return the name of the rewrite. + const Name &getName() const { return *Decl::getName(); } + + /// Return the input arguments of this rewrite. + MutableArrayRef getInputs() { + return {getTrailingObjects(), numInputs}; + } + ArrayRef getInputs() const { + return const_cast(this)->getInputs(); + } + + /// Return the explicit results of the rewrite declaration. May be empty, + /// even if the rewrite has results (e.g. in the case of inferred results). + MutableArrayRef getResults() { + return {getTrailingObjects() + numInputs, numResults}; + } + ArrayRef getResults() const { + return const_cast(this)->getResults(); + } + + /// Return the optional code block of this rewrite, if this is a native + /// rewrite with a provided implementation. + Optional getCodeBlock() const { return codeBlock; } + + /// Return the body of this rewrite if this rewrite is a PDLL rewrite, + /// otherwise returns nullptr. + const CompoundStmt *getBody() const { return rewriteBody; } + + /// Return the result type of this rewrite. + Type getResultType() const { return resultType; } + + /// Returns true if this rewrite is external (no body or code block). 
+ bool isExternal() const { return !rewriteBody && !codeBlock; } + +private: + /// Create either a PDLL rewrite or a native rewrite with the given + /// components. + static UserRewriteDecl *createImpl(Context &ctx, const Name &name, + ArrayRef inputs, + ArrayRef results, + Optional codeBlock, + const CompoundStmt *body, Type resultType); + + UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults, + Optional codeBlock, const CompoundStmt *body, + Type resultType) + : Base(name.getLoc(), &name), numInputs(numInputs), + numResults(numResults), codeBlock(codeBlock), rewriteBody(body), + resultType(resultType) {} + + /// The number of inputs to this rewrite. + unsigned numInputs; + + /// The number of explicit results to this rewrite. + unsigned numResults; + + /// The optional code block of this rewrite. + Optional codeBlock; + + /// The optional body of this rewrite. + const CompoundStmt *rewriteBody; + + /// The result type of the rewrite. + Type resultType; + + /// Allow access to various internals. + friend llvm::TrailingObjects; +}; + +//===----------------------------------------------------------------------===// +// CallableDecl +//===----------------------------------------------------------------------===// + +/// This decl represents a shared interface for all callable decls. +class CallableDecl : public Decl { +public: + /// Return the callable type of this decl: "constraint" or "rewrite". + StringRef getCallableType() const { + if (isa(this)) + return "constraint"; + assert(isa(this) && "unknown callable type"); + return "rewrite"; + } + + /// Return the inputs of this decl. + ArrayRef getInputs() const { + if (const auto *cst = dyn_cast(this)) + return cst->getInputs(); + return cast(this)->getInputs(); + } + + /// Return the result type of this decl. + Type getResultType() const { + if (const auto *cst = dyn_cast(this)) + return cst->getResultType(); + return cast(this)->getResultType(); + } + + /// Support LLVM type casting facilities. 
+ static bool classof(const Node *decl) { + return isa(decl); + } +}; + //===----------------------------------------------------------------------===// // VariableDecl //===----------------------------------------------------------------------===// @@ -912,11 +1225,11 @@ inline bool Decl::classof(const Node *node) { return isa(node); + UserRewriteDecl, VariableDecl>(node); } inline bool ConstraintDecl::classof(const Node *node) { - return isa(node); + return isa(node); } inline bool CoreConstraintDecl::classof(const Node *node) { diff --git a/mlir/include/mlir/Tools/PDLL/AST/Types.h b/mlir/include/mlir/Tools/PDLL/AST/Types.h --- a/mlir/include/mlir/Tools/PDLL/AST/Types.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Types.h @@ -22,6 +22,7 @@ struct ConstraintTypeStorage; struct OperationTypeStorage; struct RangeTypeStorage; +struct RewriteTypeStorage; struct TupleTypeStorage; struct TypeTypeStorage; struct ValueTypeStorage; @@ -203,6 +204,20 @@ static ValueRangeType get(Context &context); }; +//===----------------------------------------------------------------------===// +// RewriteType +//===----------------------------------------------------------------------===// + +/// This class represents a PDLL type that corresponds to a rewrite reference. +/// This type has no MLIR C++ API correspondence. +class RewriteType : public Type::TypeBase { +public: + using Base::Base; + + /// Return an instance of the Rewrite type. 
+ static RewriteType get(Context &context); +}; + //===----------------------------------------------------------------------===// // TupleType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/AST/Context.cpp b/mlir/lib/Tools/PDLL/AST/Context.cpp --- a/mlir/lib/Tools/PDLL/AST/Context.cpp +++ b/mlir/lib/Tools/PDLL/AST/Context.cpp @@ -15,6 +15,7 @@ Context::Context() { typeUniquer.registerSingletonStorageType(); typeUniquer.registerSingletonStorageType(); + typeUniquer.registerSingletonStorageType(); typeUniquer.registerSingletonStorageType(); typeUniquer.registerSingletonStorageType(); diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp --- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp +++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp @@ -76,9 +76,11 @@ void printImpl(const EraseStmt *stmt); void printImpl(const LetStmt *stmt); void printImpl(const ReplaceStmt *stmt); + void printImpl(const ReturnStmt *stmt); void printImpl(const RewriteStmt *stmt); void printImpl(const AttributeExpr *expr); + void printImpl(const CallExpr *expr); void printImpl(const DeclRefExpr *expr); void printImpl(const MemberAccessExpr *expr); void printImpl(const OperationExpr *expr); @@ -89,11 +91,13 @@ void printImpl(const OpConstraintDecl *decl); void printImpl(const TypeConstraintDecl *decl); void printImpl(const TypeRangeConstraintDecl *decl); + void printImpl(const UserConstraintDecl *decl); void printImpl(const ValueConstraintDecl *decl); void printImpl(const ValueRangeConstraintDecl *decl); void printImpl(const NamedAttributeDecl *decl); void printImpl(const OpNameDecl *decl); void printImpl(const PatternDecl *decl); + void printImpl(const UserRewriteDecl *decl); void printImpl(const VariableDecl *decl); void printImpl(const Module *module); @@ -135,6 +139,7 @@ print(type.getElementType()); os << "Range"; }) + .Case([&](RewriteType) { os << "Rewrite"; }) .Case([&](TupleType type) { os 
<< "Tuple<"; llvm::interleaveComma( @@ -160,17 +165,19 @@ .Case< // Statements. const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt, - const RewriteStmt, + const ReturnStmt, const RewriteStmt, // Expressions. - const AttributeExpr, const DeclRefExpr, const MemberAccessExpr, - const OperationExpr, const TupleExpr, const TypeExpr, + const AttributeExpr, const CallExpr, const DeclRefExpr, + const MemberAccessExpr, const OperationExpr, const TupleExpr, + const TypeExpr, // Decls. const AttrConstraintDecl, const OpConstraintDecl, const TypeConstraintDecl, const TypeRangeConstraintDecl, - const ValueConstraintDecl, const ValueRangeConstraintDecl, - const NamedAttributeDecl, const OpNameDecl, const PatternDecl, + const UserConstraintDecl, const ValueConstraintDecl, + const ValueRangeConstraintDecl, const NamedAttributeDecl, + const OpNameDecl, const PatternDecl, const UserRewriteDecl, const VariableDecl, const Module>([&](auto derivedNode) { this->printImpl(derivedNode); }) @@ -199,6 +206,11 @@ printChildren("ReplValues", stmt->getReplExprs()); } +void NodePrinter::printImpl(const ReturnStmt *stmt) { + os << "ReturnStmt " << stmt << "\n"; + printChildren(stmt->getResultExpr()); +} + void NodePrinter::printImpl(const RewriteStmt *stmt) { os << "RewriteStmt " << stmt << "\n"; printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody()); @@ -208,6 +220,14 @@ os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n"; } +void NodePrinter::printImpl(const CallExpr *expr) { + os << "CallExpr " << expr << " Type<"; + print(expr->getType()); + os << ">\n"; + printChildren(expr->getCallableExpr()); + printChildren("Arguments", expr->getArguments()); +} + void NodePrinter::printImpl(const DeclRefExpr *expr) { os << "DeclRefExpr " << expr << " Type<"; print(expr->getType()); @@ -265,6 +285,21 @@ os << "TypeRangeConstraintDecl " << decl << "\n"; } +void NodePrinter::printImpl(const UserConstraintDecl *decl) { + os << "UserConstraintDecl " << decl 
<< " Name<" << decl->getName().getName() + << "> ResultType<" << decl->getResultType() << ">"; + if (Optional codeBlock = decl->getCodeBlock()) { + os << " Code<"; + llvm::printEscapedString(*codeBlock, os); + os << ">"; + } + os << "\n"; + printChildren("Inputs", decl->getInputs()); + printChildren("Results", decl->getResults()); + if (const CompoundStmt *body = decl->getBody()) + printChildren(body); +} + void NodePrinter::printImpl(const ValueConstraintDecl *decl) { os << "ValueConstraintDecl " << decl << "\n"; if (const auto *typeExpr = decl->getTypeExpr()) @@ -303,6 +338,21 @@ printChildren(decl->getBody()); } +void NodePrinter::printImpl(const UserRewriteDecl *decl) { + os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName() + << "> ResultType<" << decl->getResultType() << ">"; + if (Optional codeBlock = decl->getCodeBlock()) { + os << " Code<"; + llvm::printEscapedString(*codeBlock, os); + os << ">"; + } + os << "\n"; + printChildren("Inputs", decl->getInputs()); + printChildren("Results", decl->getResults()); + if (const CompoundStmt *body = decl->getBody()) + printChildren(body); +} + void NodePrinter::printImpl(const VariableDecl *decl) { os << "VariableDecl " << decl << " Name<" << decl->getName().getName() << "> Type<"; diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp --- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp +++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp @@ -108,6 +108,15 @@ RewriteStmt(loc, rootOp, rewriteBody); } +//===----------------------------------------------------------------------===// +// ReturnStmt +//===----------------------------------------------------------------------===// + +ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) { + return new (ctx.getAllocator().Allocate()) + ReturnStmt(loc, resultExpr); +} + //===----------------------------------------------------------------------===// // AttributeExpr 
//===----------------------------------------------------------------------===// @@ -118,6 +127,22 @@ AttributeExpr(ctx, loc, copyStringWithNull(ctx, value)); } +//===----------------------------------------------------------------------===// +// CallExpr +//===----------------------------------------------------------------------===// + +CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable, + ArrayRef arguments, Type resultType) { + unsigned allocSize = CallExpr::totalSizeToAlloc(arguments.size()); + void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr)); + + CallExpr *expr = + new (rawData) CallExpr(loc, resultType, callable, arguments.size()); + std::uninitialized_copy(arguments.begin(), arguments.end(), + expr->getArguments().begin()); + return expr; +} + //===----------------------------------------------------------------------===// // DeclRefExpr //===----------------------------------------------------------------------===// @@ -267,6 +292,30 @@ ValueRangeConstraintDecl(loc, typeExpr); } +//===----------------------------------------------------------------------===// +// UserConstraintDecl +//===----------------------------------------------------------------------===// + +UserConstraintDecl *UserConstraintDecl::createImpl( + Context &ctx, const Name &name, ArrayRef inputs, + ArrayRef results, Optional codeBlock, + const CompoundStmt *body, Type resultType) { + unsigned allocSize = UserConstraintDecl::totalSizeToAlloc( + inputs.size() + results.size()); + void *rawData = + ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl)); + if (codeBlock) + codeBlock = codeBlock->copy(ctx.getAllocator()); + + UserConstraintDecl *decl = new (rawData) UserConstraintDecl( + name, inputs.size(), results.size(), codeBlock, body, resultType); + std::uninitialized_copy(inputs.begin(), inputs.end(), + decl->getInputs().begin()); + std::uninitialized_copy(results.begin(), results.end(), + decl->getResults().begin()); + return 
decl; +} + //===----------------------------------------------------------------------===// // NamedAttributeDecl //===----------------------------------------------------------------------===// @@ -300,6 +349,32 @@ PatternDecl(loc, name, benefit, hasBoundedRecursion, body); } +//===----------------------------------------------------------------------===// +// UserRewriteDecl +//===----------------------------------------------------------------------===// + +UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name, + ArrayRef inputs, + ArrayRef results, + Optional codeBlock, + const CompoundStmt *body, + Type resultType) { + unsigned allocSize = UserRewriteDecl::totalSizeToAlloc( + inputs.size() + results.size()); + void *rawData = + ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl)); + if (codeBlock) + codeBlock = codeBlock->copy(ctx.getAllocator()); + + UserRewriteDecl *decl = new (rawData) UserRewriteDecl( + name, inputs.size(), results.size(), codeBlock, body, resultType); + std::uninitialized_copy(inputs.begin(), inputs.end(), + decl->getInputs().begin()); + std::uninitialized_copy(results.begin(), results.end(), + decl->getResults().begin()); + return decl; +} + //===----------------------------------------------------------------------===// // VariableDecl //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/AST/TypeDetail.h b/mlir/lib/Tools/PDLL/AST/TypeDetail.h --- a/mlir/lib/Tools/PDLL/AST/TypeDetail.h +++ b/mlir/lib/Tools/PDLL/AST/TypeDetail.h @@ -93,6 +93,12 @@ using Base::Base; }; +//===----------------------------------------------------------------------===// +// RewriteType +//===----------------------------------------------------------------------===// + +struct RewriteTypeStorage : public TypeStorageBase {}; + //===----------------------------------------------------------------------===// // TupleType 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/AST/Types.cpp b/mlir/lib/Tools/PDLL/AST/Types.cpp --- a/mlir/lib/Tools/PDLL/AST/Types.cpp +++ b/mlir/lib/Tools/PDLL/AST/Types.cpp @@ -107,6 +107,14 @@ .cast(); } +//===----------------------------------------------------------------------===// +// RewriteType +//===----------------------------------------------------------------------===// + +RewriteType RewriteType::get(Context &context) { + return context.getTypeUniquer().get(); +} + //===----------------------------------------------------------------------===// // TupleType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h --- a/mlir/lib/Tools/PDLL/Parser/Lexer.h +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h @@ -55,7 +55,9 @@ kw_OpName, kw_Pattern, kw_replace, + kw_return, kw_rewrite, + kw_Rewrite, kw_Type, kw_TypeRange, kw_Value, diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp --- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp @@ -298,7 +298,9 @@ .Case("OpName", Token::kw_OpName) .Case("Pattern", Token::kw_Pattern) .Case("replace", Token::kw_replace) + .Case("return", Token::kw_return) .Case("rewrite", Token::kw_rewrite) + .Case("Rewrite", Token::kw_Rewrite) .Case("type", Token::kw_type) .Case("Type", Token::kw_Type) .Case("TypeRange", Token::kw_TypeRange) diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -50,6 +50,9 @@ enum class ParserContext { /// The parser is in the global context. Global, + /// The parser is currently within a Constraint, which disallows all types + /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.). 
+ Constraint, /// The parser is currently within the matcher portion of a Pattern, which /// is allows a terminal operation rewrite statement but no other rewrite /// transformations. @@ -106,6 +109,77 @@ FailureOr parseTopLevelDecl(); FailureOr parseNamedAttributeDecl(); + + /// Parse an argument variable as part of the signature of a + /// UserConstraintDecl or UserRewriteDecl. + FailureOr parseArgumentDecl(); + + /// Parse a result variable as part of the signature of a UserConstraintDecl + /// or UserRewriteDecl. + FailureOr parseResultDecl(unsigned resultNum); + + /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being + /// defined in a non-global context. + FailureOr + parseUserConstraintDecl(bool isInline = false); + + /// Parse an inline UserConstraintDecl. An inline decl is one defined in a + /// non-global context, such as within a Pattern/Constraint/etc. + FailureOr parseInlineUserConstraintDecl(); + + /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined + /// using PDLL constructs. + FailureOr parseUserPDLLConstraintDecl( + const ast::Name &name, bool isInline, + ArrayRef arguments, ast::DeclScope *argumentScope, + ArrayRef results, ast::Type resultType); + + /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being + /// defined in a non-global context. + FailureOr parseUserRewriteDecl(bool isInline = false); + + /// Parse an inline UserRewriteDecl. An inline decl is one defined in a + /// non-global context, such as within a Pattern/Rewrite/etc. + FailureOr parseInlineUserRewriteDecl(); + + /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using + /// PDLL constructs. + FailureOr parseUserPDLLRewriteDecl( + const ast::Name &name, bool isInline, + ArrayRef arguments, ast::DeclScope *argumentScope, + ArrayRef results, ast::Type resultType); + + /// Parse either a UserConstraintDecl or UserRewriteDecl. 
These decls have + /// effectively the same syntax, and only differ on slight semantics (given + /// the different parsing contexts). + template + FailureOr parseUserConstraintOrRewriteDecl( + ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, + StringRef anonymousNamePrefix, bool isInline); + + /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl. + /// These decls have effectively the same syntax. + template + FailureOr parseUserNativeConstraintOrRewriteDecl( + const ast::Name &name, bool isInline, + ArrayRef arguments, + ArrayRef results, ast::Type resultType); + + /// Parse the functional signature (i.e. the arguments and results) of a + /// UserConstraintDecl or UserRewriteDecl. + LogicalResult parseUserConstraintOrRewriteSignature( + SmallVectorImpl &arguments, + SmallVectorImpl &results, + ast::DeclScope *&argumentScope, ast::Type &resultType); + + /// Validate the return (which if present is specified by bodyIt) of a + /// UserConstraintDecl or UserRewriteDecl. + LogicalResult validateUserConstraintOrRewriteReturn( + StringRef declType, ast::CompoundStmt *body, + ArrayRef::iterator bodyIt, + ArrayRef::iterator bodyE, + ArrayRef results, ast::Type &resultType); + FailureOr parseLambdaBody(function_ref processStatementFn, bool expectTerminalSemicolon = true); @@ -138,10 +212,17 @@ /// location of a previously parsed type constraint for the entity that will /// be constrained by the parsed constraint. `existingConstraints` are any /// existing constraints that have already been parsed for the same entity - /// that will be constrained by this constraint. + /// that will be constrained by this constraint. `allowInlineTypeConstraints` + /// allows the use of inline Type constraints, e.g. `Value`. 
FailureOr parseConstraint(Optional &typeConstraint, - ArrayRef existingConstraints); + ArrayRef existingConstraints, + bool allowInlineTypeConstraints); + + /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl + /// argument or result variable. The constraints for these variables do not + /// allow inline type constraints, and only permit a single constraint. + FailureOr parseArgOrResultConstraint(); //===--------------------------------------------------------------------===// // Exprs @@ -150,8 +231,11 @@ /// Identifier expressions. FailureOr parseAttributeExpr(); + FailureOr parseCallExpr(ast::Expr *parentExpr); FailureOr parseDeclRefExpr(StringRef name, SMRange loc); FailureOr parseIdentifierExpr(); + FailureOr parseInlineConstraintLambdaExpr(); + FailureOr parseInlineRewriteLambdaExpr(); FailureOr parseMemberAccessExpr(ast::Expr *parentExpr); FailureOr parseOperationName(bool allowEmptyName = false); FailureOr parseWrappedOperationName(bool allowEmptyName); @@ -168,6 +252,7 @@ FailureOr parseEraseStmt(); FailureOr parseLetStmt(); FailureOr parseReplaceStmt(); + FailureOr parseReturnStmt(); FailureOr parseRewriteStmt(); //===--------------------------------------------------------------------===// @@ -177,6 +262,10 @@ //===--------------------------------------------------------------------===// // Decls + /// Try to extract a callable from the given AST node. Returns nullptr on + /// failure. + ast::CallableDecl *tryExtractCallableDecl(ast::Node *node); + /// Try to create a pattern decl with the given components, returning the /// Pattern on success. FailureOr @@ -184,12 +273,30 @@ const ParsedPatternMetadata &metadata, ast::CompoundStmt *body); + /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set + /// of results, defined as part of the signature. + ast::Type + createUserConstraintRewriteResultType(ArrayRef results); + + /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl. 
+ template + FailureOr createUserPDLLConstraintOrRewriteDecl( + const ast::Name &name, ArrayRef arguments, + ArrayRef results, ast::Type resultType, + ast::CompoundStmt *body); + /// Try to create a variable decl with the given components, returning the /// Variable on success. FailureOr createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, ArrayRef constraints); + /// Create a variable for an argument or result defined as part of the + /// signature of a UserConstraintDecl/UserRewriteDecl. + FailureOr + createArgOrResultVariableDecl(StringRef name, SMRange loc, + const ast::ConstraintRef &constraint); + /// Validate the constraints used to constraint a variable decl. /// `inferredType` is the type of the variable inferred by the constraints /// within the list, and is updated to the most refined type as determined by @@ -201,23 +308,26 @@ /// Validate a single reference to a constraint. `inferredType` contains the /// currently inferred variabled type and is refined within the type defined /// by the constraint. Returns success if the constraint is valid, failure - /// otherwise. + /// otherwise. If `allowNonCoreConstraints` is true, then complex constraints + /// (e.g. user defined constraints) may be used with the variable. 
LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, - ast::Type &inferredType); + ast::Type &inferredType, + bool allowNonCoreConstraints = true); LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); //===--------------------------------------------------------------------===// // Exprs - FailureOr createDeclRefExpr(SMRange loc, - ast::Decl *decl); + FailureOr + createCallExpr(SMRange loc, ast::Expr *parentExpr, + MutableArrayRef arguments); + FailureOr createDeclRefExpr(SMRange loc, ast::Decl *decl); FailureOr createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, ArrayRef constraints); FailureOr - createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, - SMRange loc); + createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); /// Validate the member access `name` into the given parent expression. On /// success, this also returns the type of the member accessed. 
@@ -231,12 +341,10 @@ LogicalResult validateOperationOperands(SMRange loc, Optional name, MutableArrayRef operands); - LogicalResult validateOperationResults(SMRange loc, - Optional name, + LogicalResult validateOperationResults(SMRange loc, Optional name, MutableArrayRef results); LogicalResult - validateOperationOperandsOrResults(SMRange loc, - Optional name, + validateOperationOperandsOrResults(SMRange loc, Optional name, MutableArrayRef values, ast::Type singleTy, ast::Type rangeTy); FailureOr createTupleExpr(SMRange loc, @@ -246,8 +354,7 @@ //===--------------------------------------------------------------------===// // Stmts - FailureOr createEraseStmt(SMRange loc, - ast::Expr *rootOp); + FailureOr createEraseStmt(SMRange loc, ast::Expr *rootOp); FailureOr createReplaceStmt(SMRange loc, ast::Expr *rootOp, MutableArrayRef replValues); @@ -304,8 +411,8 @@ LogicalResult emitError(const Twine &msg) { return emitError(curToken.getLoc(), msg); } - LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, - SMRange noteLoc, const Twine &note) { + LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc, + const Twine &note) { lexer.emitErrorAndNote(loc, msg, noteLoc, note); return failure(); } @@ -333,6 +440,9 @@ /// Cached types to simplify verification and expression creation. ast::Type valueTy, valueRangeTy; ast::Type typeTy, typeRangeTy; + + /// A counter used when naming anonymous constraints and rewrites. 
+ unsigned anonymousDeclNameCounter = 0; }; } // namespace @@ -506,9 +616,15 @@ FailureOr Parser::parseTopLevelDecl() { FailureOr decl; switch (curToken.getKind()) { + case Token::kw_Constraint: + decl = parseUserConstraintDecl(); + break; case Token::kw_Pattern: decl = parsePatternDecl(); break; + case Token::kw_Rewrite: + decl = parseUserRewriteDecl(); + break; default: return emitError("expected top-level declaration, such as a `Pattern`"); } @@ -570,6 +686,363 @@ return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement); } +FailureOr Parser::parseArgumentDecl() { + // Ensure that the argument is named. + if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword()) + return emitError("expected identifier argument name"); + + // Parse the argument similarly to a normal variable. + StringRef name = curToken.getSpelling(); + SMRange nameLoc = curToken.getLoc(); + consumeToken(); + + if (failed( + parseToken(Token::colon, "expected `:` before argument constraint"))) + return failure(); + + FailureOr cst = parseArgOrResultConstraint(); + if (failed(cst)) + return failure(); + + return createArgOrResultVariableDecl(name, nameLoc, *cst); +} + +FailureOr Parser::parseResultDecl(unsigned resultNum) { + // Check to see if this result is named. + if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { + // Check to see if this name actually refers to a Constraint. + ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); + if (isa_and_nonnull(existingDecl)) { + // If yes, and this is a Rewrite, give a nice error message as non-Core + // constraints are not supported on Rewrite results. + if (parserContext == ParserContext::Rewrite) { + return emitError( + "`Rewrite` results are only permitted to use core constraints, " + "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); + } + + // Otherwise, parse this as an unnamed result variable. 
+ } else { + // If it wasn't a constraint, parse the result similarly to a variable. If + // there is already an existing decl, we will emit an error when defining + // this variable later. + StringRef name = curToken.getSpelling(); + SMRange nameLoc = curToken.getLoc(); + consumeToken(); + + if (failed(parseToken(Token::colon, + "expected `:` before result constraint"))) + return failure(); + + FailureOr cst = parseArgOrResultConstraint(); + if (failed(cst)) + return failure(); + + return createArgOrResultVariableDecl(name, nameLoc, *cst); + } + } + + // If it isn't named, we parse the constraint directly and create an unnamed + // result variable. + FailureOr cst = parseArgOrResultConstraint(); + if (failed(cst)) + return failure(); + + return createArgOrResultVariableDecl("", cst->referenceLoc, *cst); +} + +FailureOr +Parser::parseUserConstraintDecl(bool isInline) { + // Constraints and rewrites have very similar formats, dispatch to a shared + // interface for parsing. + return parseUserConstraintOrRewriteDecl( + [&](auto &&...args) { return parseUserPDLLConstraintDecl(args...); }, + ParserContext::Constraint, "constraint", isInline); +} + +FailureOr Parser::parseInlineUserConstraintDecl() { + FailureOr decl = + parseUserConstraintDecl(/*isInline=*/true); + if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) + return failure(); + + curDeclScope->add(*decl); + return decl; +} + +FailureOr Parser::parseUserPDLLConstraintDecl( + const ast::Name &name, bool isInline, + ArrayRef arguments, ast::DeclScope *argumentScope, + ArrayRef results, ast::Type resultType) { + // Push the argument scope back onto the list, so that the body can + // reference arguments. + pushDeclScope(argumentScope); + + // Parse the body of the constraint. The body is either defined as a compound + // block, i.e. `{ ... }`, or a lambda body, i.e. `=> `. 
+ ast::CompoundStmt *body; + if (curToken.is(Token::equal_arrow)) { + FailureOr bodyResult = parseLambdaBody( + [&](ast::Stmt *&stmt) -> LogicalResult { + ast::Expr *stmtExpr = dyn_cast(stmt); + if (!stmtExpr) { + return emitError(stmt->getLoc(), + "expected `Constraint` lambda body to contain a " + "single expression"); + } + stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr); + return success(); + }, + /*expectTerminalSemicolon=*/!isInline); + if (failed(bodyResult)) + return failure(); + body = *bodyResult; + } else { + FailureOr bodyResult = parseCompoundStmt(); + if (failed(bodyResult)) + return failure(); + body = *bodyResult; + + // Verify the structure of the body. + auto bodyIt = body->begin(), bodyE = body->end(); + for (; bodyIt != bodyE; ++bodyIt) + if (isa(*bodyIt)) + break; + if (failed(validateUserConstraintOrRewriteReturn( + "Constraint", body, bodyIt, bodyE, results, resultType))) + return failure(); + } + popDeclScope(); + + return createUserPDLLConstraintOrRewriteDecl( + name, arguments, results, resultType, body); +} + +FailureOr Parser::parseUserRewriteDecl(bool isInline) { + // Constraints and rewrites have very similar formats, dispatch to a shared + // interface for parsing. + return parseUserConstraintOrRewriteDecl( + [&](auto &&...args) { return parseUserPDLLRewriteDecl(args...); }, + ParserContext::Rewrite, "rewrite", isInline); +} + +FailureOr Parser::parseInlineUserRewriteDecl() { + FailureOr decl = + parseUserRewriteDecl(/*isInline=*/true); + if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName()))) + return failure(); + + curDeclScope->add(*decl); + return decl; +} + +FailureOr Parser::parseUserPDLLRewriteDecl( + const ast::Name &name, bool isInline, + ArrayRef arguments, ast::DeclScope *argumentScope, + ArrayRef results, ast::Type resultType) { + // Push the argument scope back onto the list, so that the body can + // reference arguments. 
+ curDeclScope = argumentScope; + ast::CompoundStmt *body; + if (curToken.is(Token::equal_arrow)) { + FailureOr bodyResult = parseLambdaBody( + [&](ast::Stmt *&statement) -> LogicalResult { + if (isa(statement)) + return success(); + + ast::Expr *statementExpr = dyn_cast(statement); + if (!statementExpr) { + return emitError( + statement->getLoc(), + "expected `Rewrite` lambda body to contain a single expression " + "or an operation rewrite statement; such as `erase`, " + "`replace`, or `rewrite`"); + } + statement = + ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr); + return success(); + }, + /*expectTerminalSemicolon=*/!isInline); + if (failed(bodyResult)) + return failure(); + body = *bodyResult; + } else { + FailureOr bodyResult = parseCompoundStmt(); + if (failed(bodyResult)) + return failure(); + body = *bodyResult; + } + popDeclScope(); + + // Verify the structure of the body. + auto bodyIt = body->begin(), bodyE = body->end(); + for (; bodyIt != bodyE; ++bodyIt) + if (isa(*bodyIt)) + break; + if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt, + bodyE, results, resultType))) + return failure(); + return createUserPDLLConstraintOrRewriteDecl( + name, arguments, results, resultType, body); +} + +template +FailureOr Parser::parseUserConstraintOrRewriteDecl( + ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext, + StringRef anonymousNamePrefix, bool isInline) { + SMRange loc = curToken.getLoc(); + consumeToken(); + llvm::SaveAndRestore saveCtx(parserContext, declContext); + + // Parse the name of the decl. + const ast::Name *name = nullptr; + if (curToken.isNot(Token::identifier)) { + // Only inline decls can be un-named. Inline decls are similar to "lambdas" + // in C++, so being unnamed is fine. + if (!isInline) + return emitError("expected identifier name"); + + // Create a unique anonymous name to use, as the name for this decl is not + // important. 
+ std::string anonName = + llvm::formatv("", anonymousNamePrefix, + anonymousDeclNameCounter++) + .str(); + name = &ast::Name::create(ctx, anonName, loc); + } else { + // If a name was provided, we can use it directly. + name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc()); + consumeToken(Token::identifier); + } + + // Parse the functional signature of the decl. + SmallVector arguments, results; + ast::DeclScope *argumentScope; + ast::Type resultType; + if (failed(parseUserConstraintOrRewriteSignature(arguments, results, + argumentScope, resultType))) + return failure(); + + // Check to see which type of constraint this is. If the constraint contains a + // compound body, this is a PDLL decl. + if (curToken.isAny(Token::l_brace, Token::equal_arrow)) + return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results, + resultType); + + // Otherwise, this is a native decl. + return parseUserNativeConstraintOrRewriteDecl(*name, isInline, arguments, + results, resultType); +} + +template +FailureOr Parser::parseUserNativeConstraintOrRewriteDecl( + const ast::Name &name, bool isInline, + ArrayRef arguments, + ArrayRef results, ast::Type resultType) { + // If followed by a string, the native code body has also been specified. + std::string codeStrStorage; + Optional optCodeStr; + if (curToken.isString()) { + codeStrStorage = curToken.getStringValue(); + optCodeStr = codeStrStorage; + consumeToken(); + } else if (isInline) { + return emitError(name.getLoc(), + "external declarations must be declared in global scope"); + } + if (failed(parseToken(Token::semicolon, + "expected `;` after native declaration"))) + return failure(); + return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); +} + +LogicalResult Parser::parseUserConstraintOrRewriteSignature( + SmallVectorImpl &arguments, + SmallVectorImpl &results, + ast::DeclScope *&argumentScope, ast::Type &resultType) { + // Parse the argument list of the decl. 
+ if (failed(parseToken(Token::l_paren, "expected `(` to start argument list"))) + return failure(); + + argumentScope = pushDeclScope(); + if (curToken.isNot(Token::r_paren)) { + do { + FailureOr argument = parseArgumentDecl(); + if (failed(argument)) + return failure(); + arguments.emplace_back(*argument); + } while (consumeIf(Token::comma)); + } + popDeclScope(); + if (failed(parseToken(Token::r_paren, "expected `)` to end argument list"))) + return failure(); + + // Parse the results of the decl. + pushDeclScope(); + if (consumeIf(Token::arrow)) { + auto parseResultFn = [&]() -> LogicalResult { + FailureOr result = parseResultDecl(results.size()); + if (failed(result)) + return failure(); + results.emplace_back(*result); + return success(); + }; + + // Check for a list of results. + if (consumeIf(Token::l_paren)) { + do { + if (failed(parseResultFn())) + return failure(); + } while (consumeIf(Token::comma)); + if (failed(parseToken(Token::r_paren, "expected `)` to end result list"))) + return failure(); + + // Otherwise, there is only one result. + } else if (failed(parseResultFn())) { + return failure(); + } + } + popDeclScope(); + + // Compute the result type of the decl. + resultType = createUserConstraintRewriteResultType(results); + + // Verify that results are only named if there are more than one. + if (results.size() == 1 && !results.front()->getName().getName().empty()) { + return emitError( + results.front()->getLoc(), + "cannot create a single-element tuple with an element label"); + } + return success(); +} + +LogicalResult Parser::validateUserConstraintOrRewriteReturn( + StringRef declType, ast::CompoundStmt *body, + ArrayRef::iterator bodyIt, + ArrayRef::iterator bodyE, + ArrayRef results, ast::Type &resultType) { + // Handle if a `return` was provided. + if (bodyIt != bodyE) { + // Emit an error if we have trailing statements after the return. 
+ if (std::next(bodyIt) != bodyE) { + return emitError( + (*std::next(bodyIt))->getLoc(), + llvm::formatv("`return` terminated the `{0}` body, but found " + "trailing statements afterwards", + declType)); + } + + // Otherwise if a return wasn't provided, check that no results are + // expected. + } else if (!results.empty()) { + return emitError( + {body->getLoc().End, body->getLoc().End}, + llvm::formatv("missing return in a `{0}` expected to return `{1}`", + declType, resultType)); + } + return success(); +} + FailureOr Parser::parsePatternLambdaBody() { return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult { if (isa(statement)) @@ -619,6 +1092,11 @@ // Verify the body of the pattern. auto bodyIt = body->begin(), bodyE = body->end(); for (; bodyIt != bodyE; ++bodyIt) { + if (isa(*bodyIt)) { + return emitError((*bodyIt)->getLoc(), + "`return` statements are only permitted within a " + "`Constraint` or `Rewrite` body"); + } // Break when we've found the rewrite statement. if (isa(*bodyIt)) break; @@ -719,8 +1197,8 @@ } FailureOr -Parser::defineVariableDecl(StringRef name, SMRange nameLoc, - ast::Type type, ast::Expr *initExpr, +Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, + ast::Expr *initExpr, ArrayRef constraints) { assert(curDeclScope && "defining variable outside of decl scope"); const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc); @@ -741,8 +1219,7 @@ } FailureOr -Parser::defineVariableDecl(StringRef name, SMRange nameLoc, - ast::Type type, +Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type, ArrayRef constraints) { return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr, constraints); @@ -752,8 +1229,8 @@ SmallVectorImpl &constraints) { Optional typeConstraint; auto parseSingleConstraint = [&] { - FailureOr constraint = - parseConstraint(typeConstraint, constraints); + FailureOr constraint = parseConstraint( + typeConstraint, constraints, 
/*allowInlineTypeConstraints=*/true); if (failed(constraint)) return failure(); constraints.push_back(*constraint); @@ -773,8 +1250,15 @@ FailureOr Parser::parseConstraint(Optional &typeConstraint, - ArrayRef existingConstraints) { + ArrayRef existingConstraints, + bool allowInlineTypeConstraints) { auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { + if (!allowInlineTypeConstraints) { + return emitError( + curToken.getLoc(), + "inline `Attr`, `Value`, and `ValueRange` type constraints are not " + "permitted on arguments or results"); + } if (typeConstraint) return emitErrorAndNote( curToken.getLoc(), @@ -842,6 +1326,14 @@ return ast::ConstraintRef( ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc); } + + case Token::kw_Constraint: { + // Handle an inline constraint. + FailureOr decl = parseInlineUserConstraintDecl(); + if (failed(decl)) + return failure(); + return ast::ConstraintRef(*decl, loc); + } case Token::identifier: { StringRef constraintName = curToken.getSpelling(); consumeToken(Token::identifier); @@ -867,6 +1359,12 @@ return emitError(loc, "expected identifier constraint"); } +FailureOr Parser::parseArgOrResultConstraint() { + Optional typeConstraint; + return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, + /*allowInlineTypeConstraints=*/false); +} + //===----------------------------------------------------------------------===// // Exprs @@ -880,12 +1378,18 @@ case Token::kw_attr: lhsExpr = parseAttributeExpr(); break; + case Token::kw_Constraint: + lhsExpr = parseInlineConstraintLambdaExpr(); + break; case Token::identifier: lhsExpr = parseIdentifierExpr(); break; case Token::kw_op: lhsExpr = parseOperationExpr(); break; + case Token::kw_Rewrite: + lhsExpr = parseInlineRewriteLambdaExpr(); + break; case Token::kw_type: lhsExpr = parseTypeExpr(); break; @@ -904,6 +1408,9 @@ case Token::dot: lhsExpr = parseMemberAccessExpr(*lhsExpr); break; + case Token::l_paren: + lhsExpr = 
parseCallExpr(*lhsExpr); + break; default: return lhsExpr; } @@ -934,8 +1441,28 @@ return ast::AttributeExpr::create(ctx, loc, attrExpr); } -FailureOr Parser::parseDeclRefExpr(StringRef name, - SMRange loc) { +FailureOr Parser::parseCallExpr(ast::Expr *parentExpr) { + SMRange loc = curToken.getLoc(); + consumeToken(Token::l_paren); + + // Parse the arguments of the call. + SmallVector arguments; + if (curToken.isNot(Token::r_paren)) { + do { + FailureOr argument = parseExpr(); + if (failed(argument)) + return failure(); + arguments.push_back(*argument); + } while (consumeIf(Token::comma)); + } + loc.End = curToken.getEndLoc(); + if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) + return failure(); + + return createCallExpr(loc, parentExpr, arguments); +} + +FailureOr Parser::parseDeclRefExpr(StringRef name, SMRange loc) { ast::Decl *decl = curDeclScope->lookup(name); if (!decl) return emitError(loc, "undefined reference to `" + name + "`"); @@ -963,6 +1490,24 @@ return parseDeclRefExpr(name, nameLoc); } +FailureOr Parser::parseInlineConstraintLambdaExpr() { + FailureOr decl = parseInlineUserConstraintDecl(); + if (failed(decl)) + return failure(); + + return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, + ast::ConstraintType::get(ctx)); +} + +FailureOr Parser::parseInlineRewriteLambdaExpr() { + FailureOr decl = parseInlineUserRewriteDecl(); + if (failed(decl)) + return failure(); + + return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl, + ast::RewriteType::get(ctx)); +} + FailureOr Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { SMRange loc = curToken.getLoc(); consumeToken(Token::dot); @@ -1202,6 +1747,9 @@ case Token::kw_replace: stmt = parseReplaceStmt(); break; + case Token::kw_return: + stmt = parseReturnStmt(); + break; case Token::kw_rewrite: stmt = parseRewriteStmt(); break; @@ -1239,6 +1787,8 @@ } FailureOr Parser::parseEraseStmt() { + if (parserContext == ParserContext::Constraint) + return 
emitError("`erase` cannot be used within a Constraint"); SMRange loc = curToken.getLoc(); consumeToken(Token::kw_erase); @@ -1311,6 +1861,8 @@ } FailureOr Parser::parseReplaceStmt() { + if (parserContext == ParserContext::Constraint) + return emitError("`replace` cannot be used within a Constraint"); SMRange loc = curToken.getLoc(); consumeToken(Token::kw_replace); @@ -1356,7 +1908,21 @@ return createReplaceStmt(loc, *rootOp, replValues); } +FailureOr Parser::parseReturnStmt() { + SMRange loc = curToken.getLoc(); + consumeToken(Token::kw_return); + + // Parse the result value. + FailureOr resultExpr = parseExpr(); + if (failed(resultExpr)) + return failure(); + + return ast::ReturnStmt::create(ctx, loc, *resultExpr); +} + FailureOr Parser::parseRewriteStmt() { + if (parserContext == ParserContext::Constraint) + return emitError("`rewrite` cannot be used within a Constraint"); SMRange loc = curToken.getLoc(); consumeToken(Token::kw_rewrite); @@ -1379,6 +1945,15 @@ if (failed(rewriteBody)) return failure(); + // Verify the rewrite body. + for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) { + if (isa(stmt)) { + return emitError(stmt->getLoc(), + "`return` statements are only permitted within a " + "`Constraint` or `Rewrite` body"); + } + } + return createRewriteStmt(loc, *rootOp, *rewriteBody); } @@ -1389,6 +1964,13 @@ //===----------------------------------------------------------------------===// // Decls +ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) { + // Unwrap reference expressions. + if (auto *init = dyn_cast(node)) + node = init->getDecl(); + return dyn_cast(node); +} + FailureOr Parser::createPatternDecl(SMRange loc, const ast::Name *name, const ParsedPatternMetadata &metadata, @@ -1397,9 +1979,47 @@ metadata.hasBoundedRecursion, body); } +ast::Type Parser::createUserConstraintRewriteResultType( + ArrayRef results) { + // Single result decls use the type of the single result. 
+ if (results.size() == 1) + return results[0]->getType(); + + // Multiple results use a tuple type, with the types and names grabbed from + // the result variable decls. + auto resultTypes = llvm::map_range( + results, [&](const auto *result) { return result->getType(); }); + auto resultNames = llvm::map_range( + results, [&](const auto *result) { return result->getName().getName(); }); + return ast::TupleType::get(ctx, llvm::to_vector(resultTypes), + llvm::to_vector(resultNames)); +} + +template +FailureOr Parser::createUserPDLLConstraintOrRewriteDecl( + const ast::Name &name, ArrayRef arguments, + ArrayRef results, ast::Type resultType, + ast::CompoundStmt *body) { + if (!body->getChildren().empty()) { + if (auto *retStmt = dyn_cast(body->getChildren().back())) { + ast::Expr *resultExpr = retStmt->getResultExpr(); + + // Process the result of the decl. If no explicit signature results + // were provided, check for return type inference. Otherwise, check that + // the return expression can be converted to the expected type. + if (results.empty()) + resultType = resultExpr->getType(); + else if (failed(convertExpressionTo(resultExpr, resultType))) + return failure(); + else + retStmt->setResultExpr(resultExpr); + } + } + return T::createPDLL(ctx, name, arguments, results, body, resultType); +} + FailureOr -Parser::createVariableDecl(StringRef name, SMRange loc, - ast::Expr *initializer, +Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer, ArrayRef constraints) { // The type of the variable, which is expected to be inferred by either a // constraint or an initializer expression. @@ -1426,6 +2046,12 @@ "list or the initializer"); } + // Constraint types cannot be used when defining variables. + if (type.isa()) { + return emitError( + loc, llvm::formatv("unable to define variable of `{0}` type", type)); + } + // Try to define a variable with the given name. 
FailureOr varDecl = defineVariableDecl(name, loc, type, initializer, constraints); @@ -1435,6 +2061,18 @@ return *varDecl; } +FailureOr +Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, + const ast::ConstraintRef &constraint) { + // Constraint arguments may apply more complex constraints via the arguments. + bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; + ast::Type argType; + if (failed(validateVariableConstraint(constraint, argType, + allowNonCoreConstraints))) + return failure(); + return defineVariableDecl(name, loc, argType, constraint); +} + LogicalResult Parser::validateVariableConstraints(ArrayRef constraints, ast::Type &inferredType) { @@ -1445,7 +2083,8 @@ } LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, - ast::Type &inferredType) { + ast::Type &inferredType, + bool allowNonCoreConstraints) { ast::Type constraintType; if (const auto *cst = dyn_cast(ref.constraint)) { if (const ast::Expr *typeExpr = cst->getTypeExpr()) { @@ -1474,6 +2113,25 @@ return failure(); } constraintType = valueRangeTy; + } else if (const auto *cst = + dyn_cast(ref.constraint)) { + if (!allowNonCoreConstraints) { + return emitError(ref.referenceLoc, + "`Rewrite` arguments and results are only permitted to " + "use core constraints, such as `Attr`, `Op`, `Type`, " + "`TypeRange`, `Value`, `ValueRange`"); + } + + ArrayRef inputs = cst->getInputs(); + if (inputs.size() != 1) { + return emitErrorAndNote(ref.referenceLoc, + "`Constraint`s applied via a variable constraint " + "list must take a single input, but got " + + Twine(inputs.size()), + cst->getLoc(), + "see definition of constraint here"); + } + constraintType = inputs.front()->getType(); } else { llvm_unreachable("unknown constraint type"); } @@ -1515,11 +2173,66 @@ //===----------------------------------------------------------------------===// // Exprs +FailureOr +Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, + MutableArrayRef 
arguments) { + ast::Type parentType = parentExpr->getType(); + + ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); + if (!callableDecl) { + return emitError(loc, + llvm::formatv("expected a reference to a callable " + "`Constraint` or `Rewrite`, but got: `{0}`", + parentType)); + } + if (parserContext == ParserContext::Rewrite) { + if (isa(callableDecl)) + return emitError( + loc, "unable to invoke `Constraint` within a rewrite section"); + } else if (isa(callableDecl)) { + return emitError(loc, "unable to invoke `Rewrite` within a match section"); + } + + // Verify the arguments of the call. + /// Handle size mismatch. + ArrayRef callArgs = callableDecl->getInputs(); + if (callArgs.size() != arguments.size()) { + return emitErrorAndNote( + loc, + llvm::formatv("invalid number of arguments for {0} call; expected " + "{1}, but got {2}", + callableDecl->getCallableType(), callArgs.size(), + arguments.size()), + callableDecl->getLoc(), + llvm::formatv("see the definition of {0} here", + callableDecl->getName()->getName())); + } + + /// Handle argument type mismatch. + auto attachDiagFn = [&](ast::Diagnostic &diag) { + diag.attachNote(llvm::formatv("see the definition of `{0}` here", + callableDecl->getName()->getName()), + callableDecl->getLoc()); + }; + for (auto it : llvm::zip(callArgs, arguments)) { + if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(), + attachDiagFn))) + return failure(); + } + + return ast::CallExpr::create(ctx, loc, parentExpr, arguments, + callableDecl->getResultType()); +} + FailureOr Parser::createDeclRefExpr(SMRange loc, ast::Decl *decl) { // Check the type of decl being referenced. 
ast::Type declType; - if (auto *varDecl = dyn_cast(decl)) + if (isa(decl)) + declType = ast::ConstraintType::get(ctx); + else if (isa(decl)) + declType = ast::RewriteType::get(ctx); + else if (auto *varDecl = dyn_cast(decl)) declType = varDecl->getType(); else return emitError(loc, "invalid reference to `" + @@ -1529,8 +2242,7 @@ } FailureOr -Parser::createInlineVariableExpr(ast::Type type, StringRef name, - SMRange loc, +Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, ArrayRef constraints) { FailureOr decl = defineVariableDecl(name, loc, type, constraints); @@ -1551,8 +2263,7 @@ } FailureOr Parser::validateMemberAccess(ast::Expr *parentExpr, - StringRef name, - SMRange loc) { + StringRef name, SMRange loc) { ast::Type parentType = parentExpr->getType(); if (parentType.isa()) { if (name == ast::AllResultsMemberAccessExpr::getMemberName()) @@ -1622,9 +2333,8 @@ } LogicalResult Parser::validateOperationOperandsOrResults( - SMRange loc, Optional name, - MutableArrayRef values, ast::Type singleTy, - ast::Type rangeTy) { + SMRange loc, Optional name, MutableArrayRef values, + ast::Type singleTy, ast::Type rangeTy) { // All operation types accept a single range parameter. 
if (values.size() == 1) { if (failed(convertExpressionTo(values[0], rangeTy))) @@ -1665,7 +2375,7 @@ ArrayRef elementNames) { for (const ast::Expr *element : elements) { ast::Type eleTy = element->getType(); - if (eleTy.isa()) { + if (eleTy.isa()) { return emitError( element->getLoc(), llvm::formatv("unable to build a tuple with `{0}` element", eleTy)); diff --git a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll @@ -0,0 +1,160 @@ +// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s + +//===----------------------------------------------------------------------===// +// Constraint Structure +//===----------------------------------------------------------------------===// + +// CHECK: expected identifier name +Constraint {} + +// ----- + +// CHECK: :6:12: error: `Foo` has already been defined +// CHECK: :5:12: note: see previous definition here +Constraint Foo() { op<>; } +Constraint Foo() { op<>; } + +// ----- + +Constraint Foo() { + // CHECK: `erase` cannot be used within a Constraint + erase op<>; +} + +// ----- + +Constraint Foo() { + // CHECK: `replace` cannot be used within a Constraint + replace; +} + +// ----- + +Constraint Foo() { + // CHECK: `rewrite` cannot be used within a Constraint + rewrite; +} + +// ----- + +Constraint Foo() -> Value { + // CHECK: `return` terminated the `Constraint` body, but found trailing statements afterwards + return _: Value; + return _: Value; +} + +// ----- + +// CHECK: missing return in a `Constraint` expected to return `Value` +Constraint Foo() -> Value { + let value: Value; +} + +// ----- + +// CHECK: expected `Constraint` lambda body to contain a single expression +Constraint Foo() -> Value => let foo: Value; + +// ----- + +// CHECK: unable to convert expression of type `Op` to the expected type of `Attr` +Constraint Foo() -> Attr => op<>; + +// ----- + +Rewrite 
SomeRewrite(); + +// CHECK: unable to invoke `Rewrite` within a match section +Constraint Foo() { + SomeRewrite(); +} + +// ----- + +Constraint Foo() { + Constraint Foo() {}; +} + +// ----- + +//===----------------------------------------------------------------------===// +// Arguments +//===----------------------------------------------------------------------===// + +// CHECK: expected `(` to start argument list +Constraint Foo {} + +// ----- + +// CHECK: expected identifier argument name +Constraint Foo(10{} + +// ----- + +// CHECK: expected `:` before argument constraint +Constraint Foo(arg{} + +// ----- + +// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results +Constraint Foo(arg: Value){} + +// ----- + +// CHECK: expected `)` to end argument list +Constraint Foo(arg: Value{} + +// ----- + +//===----------------------------------------------------------------------===// +// Results +//===----------------------------------------------------------------------===// + +// CHECK: expected identifier constraint +Constraint Foo() -> {} + +// ----- + +// CHECK: cannot create a single-element tuple with an element label +Constraint Foo() -> result: Value; + +// ----- + +// CHECK: cannot create a single-element tuple with an element label +Constraint Foo() -> (result: Value); + +// ----- + +// CHECK: expected identifier constraint +Constraint Foo() -> (); + +// ----- + +// CHECK: expected `:` before result constraint +Constraint Foo() -> (result{}; + +// ----- + +// CHECK: expected `)` to end result list +Constraint Foo() -> (Op{}; + +// ----- + +// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results +Constraint Foo() -> Value){} + +// ----- + +//===----------------------------------------------------------------------===// +// Native Constraints +//===----------------------------------------------------------------------===// + +Pattern { + // CHECK: external 
declarations must be declared in global scope + Constraint ExternalConstraint(); +} + +// ----- + +// CHECK: expected `;` after native declaration +Constraint Foo() [{}] diff --git a/mlir/test/mlir-pdll/Parser/constraint.pdll b/mlir/test/mlir-pdll/Parser/constraint.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/constraint.pdll @@ -0,0 +1,74 @@ +// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s + +// CHECK: Module +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType> +Constraint Foo(); + +// ----- + +// CHECK: Module +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType> Code< /* Native Code */ > +Constraint Foo() [{ /* Native Code */ }]; + +// ----- + +// CHECK: Module +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Results` +// CHECK: `-VariableDecl {{.*}} Name<> Type +// CHECK: `-CompoundStmt {{.*}} +// CHECK: `-ReturnStmt {{.*}} +// CHECK: `-DeclRefExpr {{.*}} Type +// CHECK: `-VariableDecl {{.*}} Name Type +Constraint Foo(arg: Value) -> Value => arg; + +// ----- + +// CHECK: Module +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType> +// CHECK: `Results` +// CHECK: |-VariableDecl {{.*}} Name Type +// CHECK: | `Constraints` +// CHECK: | `-ValueConstraintDecl {{.*}} +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-AttrConstraintDecl {{.*}} +// CHECK: `-CompoundStmt {{.*}} +// CHECK: `-ReturnStmt {{.*}} +// CHECK: `-TupleExpr {{.*}} Type> +// CHECK: |-MemberAccessExpr {{.*}} Member<0> Type +// CHECK: | `-TupleExpr {{.*}} Type> +// CHECK: `-MemberAccessExpr {{.*}} Member<1> Type +// CHECK: `-TupleExpr {{.*}} Type> +Constraint Foo() -> (result1: Value, result2: Attr) => (_: Value, attr<"10">); + +// ----- + +// CHECK: Module +// CHECK: |-UserConstraintDecl {{.*}} Name ResultType> +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: 
`Constraints` +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType> +// CHECK: `Results` +// CHECK: `-VariableDecl {{.*}} Name<> Type +// CHECK: `Constraints` +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType> +Constraint Bar(input: Value); + +Constraint Foo(arg: Bar) -> Bar => arg; + +// ----- + +// Test that anonymous constraints are uniquely named. + +// CHECK: Module +// CHECK: UserConstraintDecl {{.*}} Name<> ResultType> +// CHECK: UserConstraintDecl {{.*}} Name<> ResultType +Constraint Outer() { + Constraint() {}; + Constraint() => attr<"10">; +} diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -43,6 +43,45 @@ // ----- +//===----------------------------------------------------------------------===// +// Call Expr +//===----------------------------------------------------------------------===// + +Constraint foo(value: Value); + +Pattern { + // CHECK: expected `)` after argument list + foo(_: Value{}; +} + +// ----- + +Pattern { + // CHECK: expected a reference to a callable `Constraint` or `Rewrite`, but got: `Op` + let foo: Op; + foo(); +} + +// ----- + +Constraint Foo(); + +Pattern { + // CHECK: invalid number of arguments for constraint call; expected 0, but got 1 + Foo(_: Value); +} + +// ----- + +Constraint Foo(arg: Value); + +Pattern { + // CHECK: unable to convert expression of type `Attr` to the expected type of `Value` + Foo(attr<"i32">); +} + +// ----- + //===----------------------------------------------------------------------===// // Member Access Expr //===----------------------------------------------------------------------===// @@ -105,6 +144,26 @@ // ----- +Constraint Foo(); + +Pattern { + // CHECK: unable to build a tuple with `Constraint` element + let tuple = (Foo); + erase op<>; +} + +// ----- + +Rewrite Foo(); + +Pattern { + // CHECK: unable to build a tuple with 
`Rewrite` element + let tuple = (Foo); + erase op<>; +} + +// ----- + Pattern { // CHECK: expected expression let tuple = (10 = _: Value); diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -14,6 +14,42 @@ // ----- +//===----------------------------------------------------------------------===// +// CallExpr +//===----------------------------------------------------------------------===// + +// CHECK: Module +// CHECK: |-UserConstraintDecl {{.*}} Name ResultType> +// CHECK: `-CallExpr {{.*}} Type> +// CHECK: `-DeclRefExpr {{.*}} Type +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType> +Constraint MakeRootOp() => op; + +Pattern { + erase MakeRootOp(); +} + +// ----- + +// CHECK: Module +// CHECK: |-UserRewriteDecl {{.*}} Name ResultType> +// CHECK: `-PatternDecl {{.*}} +// CHECK: `-CallExpr {{.*}} Type> +// CHECK: `-DeclRefExpr {{.*}} Type +// CHECK: `-UserRewriteDecl {{.*}} Name ResultType> +// CHECK: `Arguments` +// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type +// CHECK: `-DeclRefExpr {{.*}} Type> +// CHECK: `-VariableDecl {{.*}} Name Type> +Rewrite CreateNewOp(inputs: ValueRange) => op(inputs); + +Pattern { + let inputOp = op; + replace op(inputOp) with CreateNewOp(inputOp); +} + +// ----- + //===----------------------------------------------------------------------===// // MemberAccessExpr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll --- a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll @@ -1,4 +1,4 @@ -// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s +// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s // CHECK: expected `{` or `=>` to start pattern body Pattern } @@ -12,6 +12,13 @@ // ----- +// 
CHECK: `return` statements are only permitted within a `Constraint` or `Rewrite` body +Pattern { + return _: Value; +} + +// ----- + // CHECK: expected Pattern body to terminate with an operation rewrite statement Pattern { let value: Value; @@ -32,6 +39,15 @@ // ----- +Rewrite SomeRewrite(); + +// CHECK: unable to invoke `Rewrite` within a match section +Pattern { + SomeRewrite(); +} + +// ----- + //===----------------------------------------------------------------------===// // Metadata //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll b/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll @@ -0,0 +1,161 @@ +// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s + +//===----------------------------------------------------------------------===// +// Rewrite Structure +//===----------------------------------------------------------------------===// + +// CHECK: expected identifier name +Rewrite {} + +// ----- + +// CHECK: :6:9: error: `Foo` has already been defined +// CHECK: :5:9: note: see previous definition here +Rewrite Foo(); +Rewrite Foo(); + +// ----- + +Rewrite Foo() -> Value { + // CHECK: `return` terminated the `Rewrite` body, but found trailing statements afterwards + return _: Value; + return _: Value; +} + +// ----- + +// CHECK: missing return in a `Rewrite` expected to return `Value` +Rewrite Foo() -> Value { + let value: Value; +} + +// ----- + +// CHECK: missing return in a `Rewrite` expected to return `Value` +Rewrite Foo() -> Value => erase op; + +// ----- + +// CHECK: unable to convert expression of type `Op` to the expected type of `Attr` +Rewrite Foo() -> Attr => op; + +// ----- + +// CHECK: expected `Rewrite` lambda body to contain a single expression or an operation rewrite statement; such as `erase`, `replace`, or `rewrite` +Rewrite Foo() => let foo = op; + 
+// ----- + +Constraint ValueConstraint(value: Value); + +// CHECK: unable to invoke `Constraint` within a rewrite section +Rewrite Foo(value: Value) { + ValueConstraint(value); +} + +// ----- + +Rewrite Bar(); + +// CHECK: `Bar` has already been defined +Rewrite Foo() { + Rewrite Bar() {}; +} + +// ----- + +//===----------------------------------------------------------------------===// +// Arguments +//===----------------------------------------------------------------------===// + +// CHECK: expected `(` to start argument list +Rewrite Foo {} + +// ----- + +// CHECK: expected identifier argument name +Rewrite Foo(10{} + +// ----- + +// CHECK: expected `:` before argument constraint +Rewrite Foo(arg{} + +// ----- + +// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results +Rewrite Foo(arg: Value){} + +// ----- + +Constraint ValueConstraint(value: Value); + +// CHECK: arguments and results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange` +Rewrite Foo(arg: ValueConstraint); + +// ----- + +// CHECK: expected `)` to end argument list +Rewrite Foo(arg: Value{} + +// ----- + +//===----------------------------------------------------------------------===// +// Results +//===----------------------------------------------------------------------===// + +// CHECK: expected identifier constraint +Rewrite Foo() -> {} + +// ----- + +// CHECK: cannot create a single-element tuple with an element label +Rewrite Foo() -> result: Value; + +// ----- + +// CHECK: cannot create a single-element tuple with an element label +Rewrite Foo() -> (result: Value); + +// ----- + +// CHECK: expected identifier constraint +Rewrite Foo() -> (); + +// ----- + +// CHECK: expected `:` before result constraint +Rewrite Foo() -> (result{}; + +// ----- + +// CHECK: expected `)` to end result list +Rewrite Foo() -> (Op{}; + +// ----- + +// CHECK: inline `Attr`, `Value`, and `ValueRange` 
type constraints are not permitted on arguments or results +Rewrite Foo() -> Value){} + +// ----- + +Constraint ValueConstraint(value: Value); + +// CHECK: results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange` +Rewrite Foo() -> ValueConstraint; + +// ----- + +//===----------------------------------------------------------------------===// +// Native Rewrites +//===----------------------------------------------------------------------===// + +Pattern { + // CHECK: external declarations must be declared in global scope + Rewrite ExternalConstraint(); +} + +// ----- + +// CHECK: expected `;` after native declaration +Rewrite Foo() [{}] diff --git a/mlir/test/mlir-pdll/Parser/rewrite.pdll b/mlir/test/mlir-pdll/Parser/rewrite.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/rewrite.pdll @@ -0,0 +1,58 @@ +// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s + +// CHECK: Module +// CHECK: `-UserRewriteDecl {{.*}} Name ResultType> +Rewrite Foo(); + +// ----- + +// CHECK: Module +// CHECK: `-UserRewriteDecl {{.*}} Name ResultType> Code< /* Native Code */ > +Rewrite Foo() [{ /* Native Code */ }]; + +// ----- + +// CHECK: Module +// CHECK: `-UserRewriteDecl {{.*}} Name ResultType +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Results` +// CHECK: `-VariableDecl {{.*}} Name<> Type +// CHECK: `-CompoundStmt {{.*}} +// CHECK: `-ReturnStmt {{.*}} +// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type +// CHECK: `-DeclRefExpr {{.*}} Type +// CHECK: `-VariableDecl {{.*}} Name Type +Rewrite Foo(arg: Op) -> Value => arg; + +// ----- + +// CHECK: Module +// CHECK: `-UserRewriteDecl {{.*}} Name ResultType> +// CHECK: `Results` +// CHECK: |-VariableDecl {{.*}} Name Type +// CHECK: | `Constraints` +// CHECK: | `-ValueConstraintDecl {{.*}} +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-AttrConstraintDecl {{.*}} +// CHECK: 
`-CompoundStmt {{.*}} +// CHECK: `-ReturnStmt {{.*}} +// CHECK: `-TupleExpr {{.*}} Type> +// CHECK: |-MemberAccessExpr {{.*}} Member<0> Type +// CHECK: | `-TupleExpr {{.*}} Type> +// CHECK: `-MemberAccessExpr {{.*}} Member<1> Type +// CHECK: `-TupleExpr {{.*}} Type> +Rewrite Foo() -> (result1: Value, result2: Attr) => (_: Value, attr<"10">); + +// ----- + +// Test that anonymous Rewrites are uniquely named. + +// CHECK: Module +// CHECK: UserRewriteDecl {{.*}} Name<> ResultType> +// CHECK: UserRewriteDecl {{.*}} Name<> ResultType +Rewrite Outer() { + Rewrite() {}; + Rewrite() => attr<"10">; +} diff --git a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll --- a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll @@ -223,6 +223,33 @@ // ----- +Constraint Foo(); + +Pattern { + // CHECK: unable to define variable of `Constraint` type + let foo = Foo; +} + +// ----- + +Rewrite Foo(); + +Pattern { + // CHECK: unable to define variable of `Rewrite` type + let foo = Foo; +} + +// ----- + +Constraint MultiConstraint(arg1: Value, arg2: Value); + +Pattern { + // CHECK: `Constraint`s applied via a variable constraint list must take a single input, but got 2 + let foo: MultiConstraint; +} + +// ----- + //===----------------------------------------------------------------------===// // `replace` //===----------------------------------------------------------------------===// @@ -276,6 +303,17 @@ // ----- +//===----------------------------------------------------------------------===// +// `return` +//===----------------------------------------------------------------------===// + +// CHECK: expected `;` after statement +Constraint Foo(arg: Value) -> Value { + return arg +} + +// ----- + //===----------------------------------------------------------------------===// // `rewrite` //===----------------------------------------------------------------------===// @@ -307,3 +345,12 @@ op<>; }; } + 
+// ----- + +Pattern { + // CHECK: `return` statements are only permitted within a `Constraint` or `Rewrite` body + rewrite root: Op with { + return root; + }; +}