diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -249,6 +249,37 @@
   EraseStmt(llvm::SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
 };
 
+//===----------------------------------------------------------------------===//
+// ReplaceStmt
+
+/// This statement represents the `replace` statement in PDLL. This statement
+/// replaces the given root operation with a set of values, corresponding
+/// roughly to the PatternRewriter::replaceOp API.
+class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
+                          private llvm::TrailingObjects<ReplaceStmt, Expr *> {
+public:
+  static ReplaceStmt *create(Context &ctx, llvm::SMRange loc, Expr *rootOp,
+                             ArrayRef<Expr *> replExprs);
+
+  /// Return the replacement values of this statement.
+  MutableArrayRef<Expr *> getReplExprs() {
+    return {getTrailingObjects<Expr *>(), numReplExprs};
+  }
+  ArrayRef<Expr *> getReplExprs() const {
+    return const_cast<ReplaceStmt *>(this)->getReplExprs();
+  }
+
+private:
+  ReplaceStmt(llvm::SMRange loc, Expr *rootOp, unsigned numReplExprs)
+      : Base(loc, rootOp), numReplExprs(numReplExprs) {}
+
+  /// The number of replacement values within this statement.
+  unsigned numReplExprs;
+
+  /// TrailingObject utilities.
+  friend class llvm::TrailingObjects<ReplaceStmt, Expr *>;
+};
+
 //===----------------------------------------------------------------------===//
 // Expr
 //===----------------------------------------------------------------------===//
@@ -862,7 +893,7 @@
 }
 
 inline bool OpRewriteStmt::classof(const Node *node) {
-  return isa<EraseStmt>(node);
+  return isa<EraseStmt, ReplaceStmt>(node);
 }
 
 inline bool Stmt::classof(const Node *node) {
diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
--- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -75,6 +75,7 @@
   void printImpl(const CompoundStmt *stmt);
   void printImpl(const EraseStmt *stmt);
   void printImpl(const LetStmt *stmt);
+  void printImpl(const ReplaceStmt *stmt);
 
   void printImpl(const AttributeExpr *expr);
   void printImpl(const DeclRefExpr *expr);
@@ -157,7 +158,7 @@
   TypeSwitch<const Node *>(node)
       .Case<
           // Statements.
-          const CompoundStmt, const EraseStmt, const LetStmt,
+          const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
 
           // Expressions.
           const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
@@ -190,6 +191,12 @@
   printChildren(stmt->getVarDecl());
 }
 
+void NodePrinter::printImpl(const ReplaceStmt *stmt) {
+  os << "ReplaceStmt " << stmt << "\n";
+  printChildren(stmt->getRootOpExpr());
+  printChildren("ReplValues", stmt->getReplExprs());
+}
+
 void NodePrinter::printImpl(const AttributeExpr *expr) {
   os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
 }
diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -81,6 +81,20 @@
   return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
 }
 
+//===----------------------------------------------------------------------===//
+// ReplaceStmt
+
+ReplaceStmt *ReplaceStmt::create(Context &ctx, llvm::SMRange loc, Expr *rootOp,
+                                 ArrayRef<Expr *> replExprs) {
+  unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size());
+  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt));
+
+  ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size());
+  std::uninitialized_copy(replExprs.begin(), replExprs.end(),
+                          stmt->getReplExprs().begin());
+  return stmt;
+}
+
 //===----------------------------------------------------------------------===//
 // AttributeExpr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -138,6 +138,7 @@
   FailureOr<ast::CompoundStmt *> parseCompoundStmt();
   FailureOr<ast::EraseStmt *> parseEraseStmt();
   FailureOr<ast::LetStmt *> parseLetStmt();
+  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
 
   //===--------------------------------------------------------------------===//
   // Creation+Analysis
@@ -203,6 +204,9 @@
 
   FailureOr<ast::EraseStmt *> createEraseStmt(llvm::SMRange loc,
                                               ast::Expr *rootOp);
+  FailureOr<ast::ReplaceStmt *>
+  createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
+                    MutableArrayRef<ast::Expr *> replValues);
 
   //===--------------------------------------------------------------------===//
   // Lexer Utilities
@@ -1096,6 +1100,9 @@
   case Token::kw_let:
     stmt = parseLetStmt();
     break;
+  case Token::kw_replace:
+    stmt = parseReplaceStmt();
+    break;
   default:
     stmt = parseExpr();
     break;
@@ -1201,6 +1208,48 @@
   return ast::LetStmt::create(context, loc, *varDecl);
 }
 
+FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
+  llvm::SMRange loc = curToken.getLoc();
+  consumeToken(Token::kw_replace);
+
+  // Parse the root operation expression.
+  FailureOr<ast::Expr *> rootOp = parseExpr();
+  if (failed(rootOp))
+    return failure();
+
+  if (failed(
+          parseToken(Token::kw_with, "expected `with` after root operation")))
+    return failure();
+
+  // Parse the replacement values.
+  SmallVector<ast::Expr *> replValues;
+  if (consumeIf(Token::l_paren)) {
+    if (consumeIf(Token::r_paren)) {
+      return emitError(
+          loc, "expected at least one replacement value, consider using "
+               "`erase` if no replacement values are desired");
+    }
+
+    do {
+      FailureOr<ast::Expr *> replExpr = parseExpr();
+      if (failed(replExpr))
+        return failure();
+      replValues.emplace_back(*replExpr);
+    } while (consumeIf(Token::comma));
+
+    if (failed(parseToken(Token::r_paren,
+                          "expected `)` after replacement values")))
+      return failure();
+  } else {
+    FailureOr<ast::Expr *> replExpr = parseExpr();
+    if (failed(replExpr))
+      return failure();
+    replValues.emplace_back(*replExpr);
+  }
+
+  return createReplaceStmt(loc, *rootOp, replValues);
+}
+
 //===----------------------------------------------------------------------===//
 // Creation+Analysis
 //===----------------------------------------------------------------------===//
@@ -1513,6 +1562,41 @@
   return ast::EraseStmt::create(context, loc, rootOp);
 }
 
+FailureOr<ast::ReplaceStmt *>
+Parser::createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
+                          MutableArrayRef<ast::Expr *> replValues) {
+  // Check that root is an Operation.
+  ast::Type rootType = rootOp->getType();
+  if (!rootType.isa<ast::OperationType>()) {
+    return emitError(
+        rootOp->getLoc(),
+        llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
+  }
+
+  // If there are multiple replacement values, we implicitly convert any Op
+  // expressions to the value form.
+  bool shouldConvertOpToValues = replValues.size() > 1;
+  for (ast::Expr *&replExpr : replValues) {
+    ast::Type replType = replExpr->getType();
+
+    // Check that replExpr is an Operation, Value, or ValueRange.
+    if (auto opType = replType.dyn_cast<ast::OperationType>()) {
+      if (shouldConvertOpToValues)
+        replExpr = convertOpToValue(replExpr, opType);
+      continue;
+    }
+
+    if (replType != valueTy && replType != valueRangeTy) {
+      return emitError(replExpr->getLoc(),
+                       llvm::formatv("expected `Op`, `Value` or `ValueRange` "
+                                     "expression, but got `{0}`",
+                                     replType));
+    }
+  }
+
+  return ast::ReplaceStmt::create(context, loc, rootOp, replValues);
+}
+
 //===----------------------------------------------------------------------===//
 // Parser
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
--- a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
@@ -220,3 +220,49 @@
   let foo: Attr;
   let foo: Attr;
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// `replace`
+//===----------------------------------------------------------------------===//
+
+Pattern {
+  // CHECK: expected `Op` expression
+  replace attr<""> with attr<"">;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `with` after root operation
+  replace op<>;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `Op`, `Value` or `ValueRange` expression
+  replace op<> with attr<"">;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `Op`, `Value` or `ValueRange` expression
+  replace op<> with (attr<"">);
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `)` after replacement values
+  replace op<>(input: Value) with (input;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected at least one replacement value, consider using `erase` if no replacement values are desired
+  replace op<>(input: Value) with ();
+}
diff --git a/mlir/test/mlir-pdll/Parser/stmt.pdll b/mlir/test/mlir-pdll/Parser/stmt.pdll
--- a/mlir/test/mlir-pdll/Parser/stmt.pdll
+++ b/mlir/test/mlir-pdll/Parser/stmt.pdll
@@ -153,3 +153,32 @@
   let var: ValueRange<_: TypeRange>;
   erase _: Op;
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// ReplaceStmt
+//===----------------------------------------------------------------------===//
+
+// CHECK: Module
+// CHECK: `-ReplaceStmt
+// CHECK: `-DeclRefExpr {{.*}} Type<Op>
+// CHECK: ReplValues
+// CHECK: `-OperationExpr {{.*}} Type<Op>
+Pattern {
+  replace _: Op with op<>;
+}
+
+// -----
+
+// CHECK: Module
+// CHECK: `-ReplaceStmt
+// CHECK: `-DeclRefExpr {{.*}} Type<Op>
+// CHECK: ReplValues
+// CHECK: |-DeclRefExpr {{.*}} Type<Value>
+// CHECK: |-DeclRefExpr {{.*}} Type<ValueRange>
+// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type<ValueRange>
+// CHECK: `-OperationExpr {{.*}} Type<Op>
+Pattern {
+  replace _: Op with (_: Value, _: ValueRange, op<>);
+}