diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -270,6 +270,29 @@
   Type type;
 };
 
+//===----------------------------------------------------------------------===//
+// AttributeExpr
+//===----------------------------------------------------------------------===//
+
+/// This expression represents a literal MLIR Attribute, and contains the
+/// textual assembly format of that attribute.
+class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {
+public:
+  static AttributeExpr *create(Context &ctx, llvm::SMRange loc,
+                               StringRef value);
+
+  /// Get the raw value of this expression. This is the textual assembly format
+  /// of the MLIR Attribute.
+  StringRef getValue() const { return value; }
+
+private:
+  AttributeExpr(Context &ctx, llvm::SMRange loc, StringRef value)
+      : Base(loc, AttributeType::get(ctx)), value(value) {}
+
+  /// The raw textual assembly format of the attribute literal.
+  StringRef value;
+};
+
 //===----------------------------------------------------------------------===//
 // DeclRefExpr
 //===----------------------------------------------------------------------===//
@@ -321,6 +344,28 @@
   StringRef memberName;
 };
 
+//===----------------------------------------------------------------------===//
+// TypeExpr
+//===----------------------------------------------------------------------===//
+
+/// This expression represents a literal MLIR Type, and contains the textual
+/// assembly format of that type.
+class TypeExpr : public Node::NodeBase<TypeExpr, Expr> {
+public:
+  static TypeExpr *create(Context &ctx, llvm::SMRange loc, StringRef value);
+
+  /// Get the raw value of this expression. This is the textual assembly format
+  /// of the MLIR Type.
+  StringRef getValue() const { return value; }
+
+private:
+  TypeExpr(Context &ctx, llvm::SMRange loc, StringRef value)
+      : Base(loc, TypeType::get(ctx)), value(value) {}
+
+  /// The raw textual assembly format of the type literal.
+  StringRef value;
+};
+
 //===----------------------------------------------------------------------===//
 // Decl
 //===----------------------------------------------------------------------===//
@@ -678,7 +723,7 @@
 }
 
 inline bool Expr::classof(const Node *node) {
-  return isa<DeclRefExpr, MemberAccessExpr>(node);
+  return isa<AttributeExpr, DeclRefExpr, MemberAccessExpr, TypeExpr>(node);
 }
 
 inline bool OpRewriteStmt::classof(const Node *node) {
diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
--- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -76,8 +76,10 @@
   void printImpl(const EraseStmt *stmt);
   void printImpl(const LetStmt *stmt);
 
+  void printImpl(const AttributeExpr *expr);
   void printImpl(const DeclRefExpr *expr);
   void printImpl(const MemberAccessExpr *expr);
+  void printImpl(const TypeExpr *expr);
 
   void printImpl(const AttrConstraintDecl *decl);
   void printImpl(const OpConstraintDecl *decl);
@@ -144,7 +146,8 @@
           const CompoundStmt, const EraseStmt, const LetStmt,
 
           // Expressions.
-          const DeclRefExpr, const MemberAccessExpr,
+          const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
+          const TypeExpr,
 
           // Decls.
           const AttrConstraintDecl, const OpConstraintDecl,
@@ -172,6 +175,10 @@
   printChildren(stmt->getVarDecl());
 }
 
+void NodePrinter::printImpl(const AttributeExpr *expr) {
+  os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
+}
+
 void NodePrinter::printImpl(const DeclRefExpr *expr) {
   os << "DeclRefExpr " << expr << " Type<";
   print(expr->getType());
@@ -187,6 +194,10 @@
   printChildren(expr->getParentExpr());
 }
 
+void NodePrinter::printImpl(const TypeExpr *expr) {
+  os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
+}
+
 void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
   os << "AttrConstraintDecl " << decl << "\n";
   if (const auto *typeExpr = decl->getTypeExpr())
diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -81,6 +81,16 @@
   return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
 }
 
+//===----------------------------------------------------------------------===//
+// AttributeExpr
+//===----------------------------------------------------------------------===//
+
+AttributeExpr *AttributeExpr::create(Context &ctx, llvm::SMRange loc,
+                                     StringRef value) {
+  return new (ctx.getAllocator().Allocate<AttributeExpr>())
+      AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
+}
+
 //===----------------------------------------------------------------------===//
 // DeclRefExpr
 //===----------------------------------------------------------------------===//
@@ -102,6 +112,15 @@
       loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
 }
 
+//===----------------------------------------------------------------------===//
+// TypeExpr
+//===----------------------------------------------------------------------===//
+
+TypeExpr *TypeExpr::create(Context &ctx, llvm::SMRange loc, StringRef value) {
+  return new (ctx.getAllocator().Allocate<TypeExpr>())
+      TypeExpr(ctx, loc, copyStringWithNull(ctx, value));
+}
+
 //===----------------------------------------------------------------------===//
 // AttrConstraintDecl
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -114,11 +114,13 @@
   FailureOr<ast::Expr *> parseExpr();
 
   /// Identifier expressions.
+  FailureOr<ast::Expr *> parseAttributeExpr();
   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, llvm::SMRange loc);
   FailureOr<ast::Expr *> parseIdentifierExpr();
   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
   FailureOr<ast::OpNameDecl *> parseWrappedOperationName();
+  FailureOr<ast::Expr *> parseTypeExpr();
   FailureOr<ast::Expr *> parseUnderscoreExpr();
 
   //===--------------------------------------------------------------------===//
@@ -203,6 +205,12 @@
     consumeToken();
   }
 
+  /// Reset the lexer to the location at the given position.
+  void resetToken(llvm::SMRange tokLoc) {
+    lexer.resetPointer(tokLoc.Start.getPointer());
+    curToken = lexer.lexToken();
+  }
+
   /// Consume the specified token if present and return success. On failure,
   /// output a diagnostic and return failure.
   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
@@ -692,9 +700,15 @@
   // Parse the LHS expression.
   FailureOr<ast::Expr *> lhsExpr;
   switch (curToken.getKind()) {
+  case Token::kw_attr:
+    lhsExpr = parseAttributeExpr();
+    break;
   case Token::identifier:
     lhsExpr = parseIdentifierExpr();
     break;
+  case Token::kw_type:
+    lhsExpr = parseTypeExpr();
+    break;
   default:
     return emitError("expected expression");
   }
@@ -715,6 +729,28 @@
   }
 }
 
+FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
+  llvm::SMRange loc = curToken.getLoc();
+  consumeToken(Token::kw_attr);
+
+  // If we aren't followed by a `<`, the `attr` keyword is treated as a normal
+  // identifier.
+  if (!consumeIf(Token::less)) {
+    resetToken(loc);
+    return parseIdentifierExpr();
+  }
+
+  if (!curToken.isString())
+    return emitError("expected string literal containing MLIR attribute");
+  std::string attrExpr = curToken.getStringValue();
+  consumeToken();
+
+  if (failed(
+          parseToken(Token::greater, "expected `>` after attribute literal")))
+    return failure();
+  return ast::AttributeExpr::create(context, loc, attrExpr);
+}
+
 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name,
                                                 llvm::SMRange loc) {
   ast::Decl *decl = curDeclScope->lookup(name);
@@ -802,6 +838,27 @@
   return opNameDecl;
 }
 
+FailureOr<ast::Expr *> Parser::parseTypeExpr() {
+  llvm::SMRange loc = curToken.getLoc();
+  consumeToken(Token::kw_type);
+
+  // If we aren't followed by a `<`, the `type` keyword is treated as a normal
+  // identifier.
+  if (!consumeIf(Token::less)) {
+    resetToken(loc);
+    return parseIdentifierExpr();
+  }
+
+  if (!curToken.isString())
+    return emitError("expected string literal containing MLIR type");
+  std::string typeExpr = curToken.getStringValue();
+  consumeToken();
+
+  if (failed(parseToken(Token::greater, "expected `>` after type literal")))
+    return failure();
+  return ast::TypeExpr::create(context, loc, typeExpr);
+}
+
 FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
   StringRef name = curToken.getSpelling();
   llvm::SMRange nameLoc = curToken.getLoc();
diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
--- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
@@ -60,3 +60,39 @@
   let root: Op;
   erase root.unknown_result;
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// `attr` Expr
+//===----------------------------------------------------------------------===//
+
+Pattern {
+  // CHECK: expected string literal containing MLIR attribute
+  let foo = attr<foo>;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `>` after attribute literal
+  let foo = attr<""<>;
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// `type` Expr
+//===----------------------------------------------------------------------===//
+
+Pattern {
+  // CHECK: expected string literal containing MLIR type
+  let foo = type<foo;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `>` after type literal
+  let foo = type<"";
+}
diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/expr.pdll
@@ -0,0 +1,27 @@
+// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// AttributeExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: Module
+// CHECK: `-AttributeExpr {{.*}} Value<"10: i32">
+Pattern {
+  let attr = attr<"10: i32">;
+
+  erase _: Op;
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// TypeExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: Module
+// CHECK: `-TypeExpr {{.*}} Value<"i64">
+Pattern {
+  let type = type<"i64">;
+
+  erase _: Op;
+}