[mlir][PDLL] Add support for the `replace` statement

Adds a `ReplaceStmt` AST node (its replacement expressions are stored as
trailing objects), AST printing for it, and parser support for:

  replace <root-op-expr> with <expr>;
  replace <root-op-expr> with (<expr>, <expr>, ...);

The root must be an `Op` expression. Each replacement must be an `Op`,
`Value`, or `ValueRange` expression; when more than one replacement is
given, `Op` replacements are converted to their value form (the `$results`
member access shown in the AST dump test). An empty parenthesized list is
rejected with a hint to use `erase` instead.

The replacement values are parsed in a new `Rewrite` parser context, in
which an explicitly empty operation name (`op<>`) is rejected with
"expected dialect namespace".

NOTE(review): an `op` expression written without angle brackets still
produces an unnamed operation in the `Rewrite` context, because
parseWrappedOperationName returns an empty OpNameDecl before it consults
`allowEmptyName`. Consider rejecting that case too if rewrite-context
operations are meant to always be named.

diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -250,6 +250,37 @@
   EraseStmt(llvm::SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
 };
 
+//===----------------------------------------------------------------------===//
+// ReplaceStmt
+
+/// This statement represents the `replace` statement in PDLL. This statement
+/// replaces the given root operation with a set of values, corresponding
+/// roughly to the PatternRewriter::replaceOp API.
+class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
+                          private llvm::TrailingObjects<ReplaceStmt, Expr *> {
+public:
+  static ReplaceStmt *create(Context &ctx, llvm::SMRange loc, Expr *rootOp,
+                             ArrayRef<Expr *> replExprs);
+
+  /// Return the replacement values of this statement.
+  MutableArrayRef<Expr *> getReplExprs() {
+    return {getTrailingObjects<Expr *>(), numReplExprs};
+  }
+  ArrayRef<Expr *> getReplExprs() const {
+    return const_cast<ReplaceStmt *>(this)->getReplExprs();
+  }
+
+private:
+  ReplaceStmt(llvm::SMRange loc, Expr *rootOp, unsigned numReplExprs)
+      : Base(loc, rootOp), numReplExprs(numReplExprs) {}
+
+  /// The number of replacement values within this statement.
+  unsigned numReplExprs;
+
+  /// TrailingObject utilities.
+  friend class llvm::TrailingObjects<ReplaceStmt, Expr *>;
+};
+
 //===----------------------------------------------------------------------===//
 // Expr
 //===----------------------------------------------------------------------===//
@@ -878,7 +909,7 @@
 }
 
 inline bool OpRewriteStmt::classof(const Node *node) {
-  return isa<EraseStmt>(node);
+  return isa<EraseStmt, ReplaceStmt>(node);
 }
 
 inline bool Stmt::classof(const Node *node) {
diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
--- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -75,6 +75,7 @@
   void printImpl(const CompoundStmt *stmt);
   void printImpl(const EraseStmt *stmt);
   void printImpl(const LetStmt *stmt);
+  void printImpl(const ReplaceStmt *stmt);
 
   void printImpl(const AttributeExpr *expr);
   void printImpl(const DeclRefExpr *expr);
@@ -157,7 +158,7 @@
   TypeSwitch<const Node *>(node)
       .Case<
           // Statements.
-          const CompoundStmt, const EraseStmt, const LetStmt,
+          const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
 
           // Expressions.
           const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
@@ -190,6 +191,12 @@
   printChildren(stmt->getVarDecl());
 }
 
+void NodePrinter::printImpl(const ReplaceStmt *stmt) {
+  os << "ReplaceStmt " << stmt << "\n";
+  printChildren(stmt->getRootOpExpr());
+  printChildren("ReplValues", stmt->getReplExprs());
+}
+
 void NodePrinter::printImpl(const AttributeExpr *expr) {
   os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
 }
diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -84,6 +84,20 @@
   return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
 }
 
+//===----------------------------------------------------------------------===//
+// ReplaceStmt
+
+ReplaceStmt *ReplaceStmt::create(Context &ctx, llvm::SMRange loc, Expr *rootOp,
+                                 ArrayRef<Expr *> replExprs) {
+  unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size());
+  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt));
+
+  ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size());
+  std::uninitialized_copy(replExprs.begin(), replExprs.end(),
+                          stmt->getReplExprs().begin());
+  return stmt;
+}
+
 //===----------------------------------------------------------------------===//
 // AttributeExpr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -55,6 +55,9 @@
     /// is allows a terminal operation rewrite statement but no other rewrite
     /// transformations.
     PatternMatch,
+    /// The parser is currently within a Rewrite, which disallows calls to
+    /// constraints, requires operation expressions to have names, etc.
+    Rewrite,
   };
 
   //===--------------------------------------------------------------------===//
@@ -147,8 +150,8 @@
   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, llvm::SMRange loc);
   FailureOr<ast::Expr *> parseIdentifierExpr();
   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
-  FailureOr<ast::OpNameDecl *> parseOperationName();
-  FailureOr<ast::OpNameDecl *> parseWrappedOperationName();
+  FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
+  FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
   FailureOr<ast::Expr *> parseOperationExpr();
   FailureOr<ast::Expr *> parseTupleExpr();
   FailureOr<ast::Expr *> parseTypeExpr();
@@ -161,6 +164,7 @@
   FailureOr<ast::CompoundStmt *> parseCompoundStmt();
   FailureOr<ast::EraseStmt *> parseEraseStmt();
   FailureOr<ast::LetStmt *> parseLetStmt();
+  FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
 
   //===--------------------------------------------------------------------===//
   // Creation+Analysis
@@ -240,6 +244,9 @@
 
   FailureOr<ast::EraseStmt *> createEraseStmt(llvm::SMRange loc,
                                               ast::Expr *rootOp);
+  FailureOr<ast::ReplaceStmt *>
+  createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
+                    MutableArrayRef<ast::Expr *> replValues);
 
   //===--------------------------------------------------------------------===//
   // Lexer Utilities
@@ -754,8 +761,10 @@
   case Token::kw_Op: {
     consumeToken(Token::kw_Op);
 
-    // Parse an optional operation name.
-    FailureOr<ast::OpNameDecl *> opName = parseWrappedOperationName();
+    // Parse an optional operation name. If the name isn't provided, this refers
+    // to "any" operation.
+    FailureOr<ast::OpNameDecl *> opName =
+        parseWrappedOperationName(/*allowEmptyName=*/true);
     if (failed(opName))
       return failure();
 
@@ -927,13 +936,15 @@
   return createMemberAccessExpr(parentExpr, memberName, loc);
 }
 
-FailureOr<ast::OpNameDecl *> Parser::parseOperationName() {
+FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
   llvm::SMRange loc = curToken.getLoc();
 
   // Handle the case of an no operation name.
-  if (curToken.isNot(Token::identifier) && !curToken.isKeyword())
-    return ast::OpNameDecl::create(ctx, llvm::SMRange());
-
+  if (curToken.isNot(Token::identifier) && !curToken.isKeyword()) {
+    if (allowEmptyName)
+      return ast::OpNameDecl::create(ctx, llvm::SMRange());
+    return emitError("expected dialect namespace");
+  }
   StringRef name = curToken.getSpelling();
   consumeToken();
 
@@ -954,11 +965,12 @@
   return ast::OpNameDecl::create(ctx, ast::Name::create(ctx, name, loc));
 }
 
-FailureOr<ast::OpNameDecl *> Parser::parseWrappedOperationName() {
+FailureOr<ast::OpNameDecl *>
+Parser::parseWrappedOperationName(bool allowEmptyName) {
   if (!consumeIf(Token::less))
     return ast::OpNameDecl::create(ctx, llvm::SMRange());
 
-  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName();
+  FailureOr<ast::OpNameDecl *> opNameDecl = parseOperationName(allowEmptyName);
   if (failed(opNameDecl))
     return failure();
 
@@ -980,8 +992,10 @@
 
   // Parse the operation name. The name may be elided, in which case the
   // operation refers to "any" operation(i.e. a difference between `MyOp` and
-  // `Operation*`).
-  FailureOr<ast::OpNameDecl *> opNameDecl = parseWrappedOperationName();
+  // `Operation*`). Operation names within a rewrite context must be named.
+  bool allowEmptyName = parserContext != ParserContext::Rewrite;
+  FailureOr<ast::OpNameDecl *> opNameDecl =
+      parseWrappedOperationName(allowEmptyName);
   if (failed(opNameDecl))
     return failure();
 
@@ -1143,6 +1157,9 @@
   case Token::kw_let:
     stmt = parseLetStmt();
     break;
+  case Token::kw_replace:
+    stmt = parseReplaceStmt();
+    break;
   default:
     stmt = parseExpr();
     break;
@@ -1248,6 +1265,52 @@
   return ast::LetStmt::create(ctx, loc, *varDecl);
 }
 
+FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
+  llvm::SMRange loc = curToken.getLoc();
+  consumeToken(Token::kw_replace);
+
+  // Parse the root operation expression.
+  FailureOr<ast::Expr *> rootOp = parseExpr();
+  if (failed(rootOp))
+    return failure();
+
+  if (failed(
+          parseToken(Token::kw_with, "expected `with` after root operation")))
+    return failure();
+
+  // The replacement portion of this statement is within a rewrite context.
+  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
+                                              ParserContext::Rewrite);
+
+  // Parse the replacement values.
+  SmallVector<ast::Expr *> replValues;
+  if (consumeIf(Token::l_paren)) {
+    if (consumeIf(Token::r_paren)) {
+      return emitError(
+          loc, "expected at least one replacement value, consider using "
+               "`erase` if no replacement values are desired");
+    }
+
+    do {
+      FailureOr<ast::Expr *> replExpr = parseExpr();
+      if (failed(replExpr))
+        return failure();
+      replValues.emplace_back(*replExpr);
+    } while (consumeIf(Token::comma));
+
+    if (failed(parseToken(Token::r_paren,
+                          "expected `)` after replacement values")))
+      return failure();
+  } else {
+    FailureOr<ast::Expr *> replExpr = parseExpr();
+    if (failed(replExpr))
+      return failure();
+    replValues.emplace_back(*replExpr);
+  }
+
+  return createReplaceStmt(loc, *rootOp, replValues);
+}
+
 //===----------------------------------------------------------------------===//
 // Creation+Analysis
 //===----------------------------------------------------------------------===//
@@ -1553,6 +1616,41 @@
   return ast::EraseStmt::create(ctx, loc, rootOp);
 }
 
+FailureOr<ast::ReplaceStmt *>
+Parser::createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
+                          MutableArrayRef<ast::Expr *> replValues) {
+  // Check that root is an Operation.
+  ast::Type rootType = rootOp->getType();
+  if (!rootType.isa<ast::OperationType>()) {
+    return emitError(
+        rootOp->getLoc(),
+        llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
+  }
+
+  // If there are multiple replacement values, we implicitly convert any Op
+  // expressions to the value form.
+  bool shouldConvertOpToValues = replValues.size() > 1;
+  for (ast::Expr *&replExpr : replValues) {
+    ast::Type replType = replExpr->getType();
+
+    // Check that replExpr is an Operation, Value, or ValueRange.
+    if (replType.isa<ast::OperationType>()) {
+      if (shouldConvertOpToValues)
+        replExpr = convertOpToValue(replExpr);
+      continue;
+    }
+
+    if (replType != valueTy && replType != valueRangeTy) {
+      return emitError(replExpr->getLoc(),
+                       llvm::formatv("expected `Op`, `Value` or `ValueRange` "
+                                     "expression, but got `{0}`",
+                                     replType));
+    }
+  }
+
+  return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
+}
+
 //===----------------------------------------------------------------------===//
 // Parser
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
--- a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
@@ -220,3 +220,56 @@
   let foo: Attr;
   let foo: Attr;
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// `replace`
+//===----------------------------------------------------------------------===//
+
+Pattern {
+  // CHECK: expected `Op` expression
+  replace attr<""> with attr<"">;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `with` after root operation
+  replace op<>;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `Op`, `Value` or `ValueRange` expression
+  replace op<> with attr<"">;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `Op`, `Value` or `ValueRange` expression
+  replace op<> with (attr<"">);
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `)` after replacement values
+  replace op<>(input: Value) with (input;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected at least one replacement value, consider using `erase` if no replacement values are desired
+  replace op<>(input: Value) with ();
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected dialect namespace
+  replace op<>(input: Value) with op<>;
+}
diff --git a/mlir/test/mlir-pdll/Parser/stmt.pdll b/mlir/test/mlir-pdll/Parser/stmt.pdll
--- a/mlir/test/mlir-pdll/Parser/stmt.pdll
+++ b/mlir/test/mlir-pdll/Parser/stmt.pdll
@@ -153,3 +153,32 @@
   let var: ValueRange<_: TypeRange>;
   erase _: Op;
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// ReplaceStmt
+//===----------------------------------------------------------------------===//
+
+// CHECK: Module
+// CHECK:   `-ReplaceStmt
+// CHECK:     `-DeclRefExpr {{.*}} Type<Op>
+// CHECK:     ReplValues
+// CHECK:       `-OperationExpr {{.*}} Type<Op<test.op>>
+Pattern {
+  replace _: Op with op<test.op>;
+}
+
+// -----
+
+// CHECK: Module
+// CHECK:   `-ReplaceStmt
+// CHECK:     `-DeclRefExpr {{.*}} Type<Op>
+// CHECK:     ReplValues
+// CHECK:       |-DeclRefExpr {{.*}} Type<Value>
+// CHECK:       |-DeclRefExpr {{.*}} Type<ValueRange>
+// CHECK:       `-MemberAccessExpr {{.*}} Member<$results> Type<ValueRange>
+// CHECK:         `-OperationExpr {{.*}} Type<Op<test.op>>
+Pattern {
+  replace _: Op with (_: Value, _: ValueRange, op<test.op>);
+}