diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -90,8 +90,16 @@
   //===--------------------------------------------------------------------===//
   // Decls
 
+  /// This structure contains the set of pattern metadata that may be parsed.
+  struct ParsedPatternMetadata {
+    Optional<uint16_t> benefit;
+    bool hasBoundedRecursion = false;
+  };
+
   FailureOr<ast::Decl *> parseTopLevelDecl();
   FailureOr<ast::Decl *> parsePatternDecl();
+  /// Parse the comma-separated pattern metadata list that follows `with`.
+  LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
 
   /// Check to see if a decl has already been defined with the given name, if
   /// one has emit and error and return failure. Returns success otherwise.
@@ -153,11 +161,10 @@
 
   /// Try to create a pattern decl with the given components, returning the
   /// Pattern on success.
-  FailureOr<ast::PatternDecl *> createPatternDecl(llvm::SMRange loc,
-                                                  const ast::Name *name,
-                                                  Optional<uint16_t> benefit,
-                                                  bool hasBoundedRecursion,
-                                                  ast::CompoundStmt *body);
+  FailureOr<ast::PatternDecl *>
+  createPatternDecl(llvm::SMRange loc, const ast::Name *name,
+                    const ParsedPatternMetadata &metadata,
+                    ast::CompoundStmt *body);
 
   /// Try to create a variable decl with the given components, returning the
   /// Variable on success.
@@ -450,9 +457,10 @@
     consumeToken(Token::identifier);
   }
 
-  // TODO: Parse any pattern metadata.
-  Optional<uint16_t> benefit;
-  bool hasBoundedRecursion = false;
+  // Parse any pattern metadata.
+  ParsedPatternMetadata metadata;
+  if (consumeIf(Token::kw_with) && failed(parsePatternDeclMetadata(metadata)))
+    return failure();
 
   // Parse the pattern body.
   ast::CompoundStmt *body;
@@ -482,7 +490,66 @@
                      "rewrite statement, but found trailing statements");
   }
 
-  return createPatternDecl(loc, name, benefit, hasBoundedRecursion, body);
+  return createPatternDecl(loc, name, metadata, body);
+}
+
+LogicalResult
+Parser::parsePatternDeclMetadata(ParsedPatternMetadata &metadata) {
+  Optional<llvm::SMRange> benefitLoc;
+  Optional<llvm::SMRange> hasBoundedRecursionLoc;
+
+  do {
+    if (curToken.isNot(Token::identifier))
+      return emitError("expected pattern metadata identifier");
+    StringRef metadataStr = curToken.getSpelling();
+    llvm::SMRange metadataLoc = curToken.getLoc();
+    consumeToken(Token::identifier);
+
+    // Parse the benefit metadata: benefit(<integer-value>)
+    if (metadataStr == "benefit") {
+      if (benefitLoc) {
+        return emitErrorAndNote(metadataLoc,
+                                "pattern benefit has already been specified",
+                                *benefitLoc, "see previous definition here");
+      }
+      if (failed(parseToken(Token::l_paren,
+                            "expected `(` before pattern benefit")))
+        return failure();
+
+      uint16_t benefitValue = 0;
+      if (curToken.isNot(Token::integer))
+        return emitError("expected integral pattern benefit");
+      if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
+        return emitError(
+            "expected pattern benefit to fit within a 16-bit integer");
+      consumeToken(Token::integer);
+
+      metadata.benefit = benefitValue;
+      benefitLoc = metadataLoc;
+
+      if (failed(
+              parseToken(Token::r_paren, "expected `)` after pattern benefit")))
+        return failure();
+      continue;
+    }
+
+    // Parse the bounded recursion metadata: recursion
+    if (metadataStr == "recursion") {
+      if (hasBoundedRecursionLoc) {
+        return emitErrorAndNote(
+            metadataLoc,
+            "pattern recursion metadata has already been specified",
+            *hasBoundedRecursionLoc, "see previous definition here");
+      }
+      metadata.hasBoundedRecursion = true;
+      hasBoundedRecursionLoc = metadataLoc;
+      continue;
+    }
+
+    return emitError(metadataLoc, "unknown pattern metadata");
+  } while (consumeIf(Token::comma));
+
+  return success();
 }
 
 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
@@ -916,10 +983,10 @@
 
 FailureOr<ast::PatternDecl *>
 Parser::createPatternDecl(llvm::SMRange loc, const ast::Name *name,
-                          Optional<uint16_t> benefit, bool hasBoundedRecursion,
+                          const ParsedPatternMetadata &metadata,
                           ast::CompoundStmt *body) {
-  return ast::PatternDecl::create(ctx, loc, name, benefit, hasBoundedRecursion,
-                                  body);
+  return ast::PatternDecl::create(ctx, loc, name, metadata.benefit,
+                                  metadata.hasBoundedRecursion, body);
 }
 
 FailureOr<ast::VariableDecl *>
diff --git a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
--- a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
@@ -24,3 +24,49 @@
   erase root: Op;
   let value: Value;
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Metadata
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected pattern metadata identifier
+Pattern with {}
+
+// -----
+
+// CHECK: unknown pattern metadata
+Pattern with unknown {}
+
+// -----
+
+// CHECK: expected `(` before pattern benefit
+Pattern with benefit) {}
+
+// -----
+
+// CHECK: expected integral pattern benefit
+Pattern with benefit(foo) {}
+
+// -----
+
+// CHECK: expected pattern benefit to fit within a 16-bit integer
+Pattern with benefit(65536) {}
+
+// -----
+
+// CHECK: expected `)` after pattern benefit
+Pattern with benefit(1( {}
+
+// -----
+
+// CHECK: pattern benefit has already been specified
+// CHECK: see previous definition here
+Pattern with benefit(1), benefit(1) {}
+
+// -----
+
+// CHECK: pattern recursion metadata has already been specified
+// CHECK: see previous definition here
+Pattern with recursion, recursion {}
diff --git a/mlir/test/mlir-pdll/Parser/pattern.pdll b/mlir/test/mlir-pdll/Parser/pattern.pdll
--- a/mlir/test/mlir-pdll/Parser/pattern.pdll
+++ b/mlir/test/mlir-pdll/Parser/pattern.pdll
@@ -15,3 +15,11 @@
 Pattern NamedPattern {
   erase _: Op;
 }
+
+// -----
+
+// CHECK: Module
+// CHECK: `-PatternDecl {{.*}} Name<NamedPattern> Benefit<10> Recursion
+Pattern NamedPattern with benefit(10), recursion {
+  erase _: Op;
+}