diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h --- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h @@ -278,6 +278,28 @@ friend class llvm::TrailingObjects; }; +//===----------------------------------------------------------------------===// +// RewriteStmt + +/// This statement represents an operation rewrite that contains a block of +/// nested rewrite commands. This allows for building more complex operation +/// rewrites that span across multiple statements, which may be unconnected. +class RewriteStmt final : public Node::NodeBase { +public: + static RewriteStmt *create(Context &ctx, llvm::SMRange loc, Expr *rootOp, + CompoundStmt *rewriteBody); + + /// Return the compound rewrite body. + CompoundStmt *getRewriteBody() const { return rewriteBody; } + +private: + RewriteStmt(llvm::SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody) + : Base(loc, rootOp), rewriteBody(rewriteBody) {} + + /// The body of nested rewriters within this statement. + CompoundStmt *rewriteBody; +}; + //===----------------------------------------------------------------------===// // Expr //===----------------------------------------------------------------------===// @@ -909,7 +931,7 @@ } inline bool OpRewriteStmt::classof(const Node *node) { - return isa(node); + return isa(node); } inline bool Stmt::classof(const Node *node) { diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp --- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp +++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp @@ -76,6 +76,7 @@ void printImpl(const EraseStmt *stmt); void printImpl(const LetStmt *stmt); void printImpl(const ReplaceStmt *stmt); + void printImpl(const RewriteStmt *stmt); void printImpl(const AttributeExpr *expr); void printImpl(const DeclRefExpr *expr); @@ -159,6 +160,7 @@ .Case< // Statements. const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt, + const RewriteStmt, // Expressions. const AttributeExpr, const DeclRefExpr, const MemberAccessExpr, @@ -197,6 +199,11 @@ printChildren("ReplValues", stmt->getReplExprs()); } +void NodePrinter::printImpl(const RewriteStmt *stmt) { + os << "RewriteStmt " << stmt << "\n"; + printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody()); +} + void NodePrinter::printImpl(const AttributeExpr *expr) { os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n"; } diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp --- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp +++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp @@ -99,6 +99,15 @@ return stmt; } +//===----------------------------------------------------------------------===// +// RewriteStmt + +RewriteStmt *RewriteStmt::create(Context &ctx, llvm::SMRange loc, Expr *rootOp, + CompoundStmt *rewriteBody) { + return new (ctx.getAllocator().Allocate()) + RewriteStmt(loc, rootOp, rewriteBody); +} + //===----------------------------------------------------------------------===// // AttributeExpr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -164,6 +164,7 @@ FailureOr parseEraseStmt(); FailureOr parseLetStmt(); FailureOr parseReplaceStmt(); + FailureOr parseRewriteStmt(); //===--------------------------------------------------------------------===// // Creation+Analysis @@ -246,6 +247,9 @@ FailureOr createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp, MutableArrayRef replValues); + FailureOr + createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp, + ast::CompoundStmt *rewriteBody); //===--------------------------------------------------------------------===// // Lexer Utilities @@ -1156,6 +1160,9 @@ case Token::kw_replace: stmt = parseReplaceStmt(); break; + case Token::kw_rewrite: + stmt = parseRewriteStmt(); + break; default: stmt = parseExpr(); break; @@ -1307,6 +1314,32 @@ return createReplaceStmt(loc, *rootOp, replValues); } +FailureOr Parser::parseRewriteStmt() { + llvm::SMRange loc = curToken.getLoc(); + consumeToken(Token::kw_rewrite); + + // Parse the root operation. + FailureOr rootOp = parseExpr(); + if (failed(rootOp)) + return failure(); + + if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body"))) + return failure(); + + if (curToken.isNot(Token::l_brace)) + return emitError("expected `{` to start rewrite body"); + + // The rewrite body of this statement is within a rewrite context. + llvm::SaveAndRestore saveCtx(parserContext, + ParserContext::Rewrite); + + FailureOr rewriteBody = parseCompoundStmt(); + if (failed(rewriteBody)) + return failure(); + + return createRewriteStmt(loc, *rootOp, *rewriteBody); +} + //===----------------------------------------------------------------------===// // Creation+Analysis //===----------------------------------------------------------------------===// @@ -1647,6 +1680,20 @@ return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues); } +FailureOr +Parser::createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp, + ast::CompoundStmt *rewriteBody) { + // Check that root is an Operation. + ast::Type rootType = rootOp->getType(); + if (!rootType.isa()) { + return emitError( + rootOp->getLoc(), + llvm::formatv("expected `Op` expression, but got `{0}`", rootType)); + } + + return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody); +} + //===----------------------------------------------------------------------===// // Parser //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll --- a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll @@ -273,3 +273,37 @@ // CHECK: expected dialect namespace replace op<>(input: Value) with op<>; } + +// ----- + +//===----------------------------------------------------------------------===// +// `rewrite` +//===----------------------------------------------------------------------===// + +Pattern { + // CHECK: expected `Op` expression + rewrite attr<""> with { op; }; +} + +// ----- + +Pattern { + // CHECK: expected `with` before rewrite body + rewrite op<>; +} + +// ----- + +Pattern { + // CHECK: expected `{` to start rewrite body + rewrite op<> with; +} + +// ----- + +Pattern { + // CHECK: expected dialect namespace + rewrite root: Op with { + op<>; + }; +} diff --git a/mlir/test/mlir-pdll/Parser/stmt.pdll b/mlir/test/mlir-pdll/Parser/stmt.pdll --- a/mlir/test/mlir-pdll/Parser/stmt.pdll +++ b/mlir/test/mlir-pdll/Parser/stmt.pdll @@ -182,3 +182,25 @@ Pattern { replace _: Op with (_: Value, _: ValueRange, op); } + +// ----- + +//===----------------------------------------------------------------------===// +// RewriteStmt +//===----------------------------------------------------------------------===// + +// CHECK: Module +// CHECK: `-RewriteStmt +// CHECK: |-DeclRefExpr {{.*}} Type +// CHECK: `-CompoundStmt +// CHECK: |-OperationExpr {{.*}} Type> +// CHECK: `-ReplaceStmt {{.*}} +// CHECK: `-DeclRefExpr {{.*}} Type +// CHECK: `ReplValues` +// CHECK: `-OperationExpr {{.*}} Type> +Pattern { + rewrite root: Op with { + op; + replace root with op; + }; +}