diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -89,6 +89,10 @@
 
   FailureOr<ast::Decl *> parseTopLevelDecl();
   FailureOr<ast::Decl *> parsePatternDecl();
+  /// Parse the comma-separated metadata list that follows `with` in a pattern
+  /// declaration, i.e. `benefit(<integer>)` and/or `recursion`.
+  LogicalResult parsePatternDeclMetadata(Optional<uint16_t> &benefit,
+                                         bool &hasBoundedRecursion);
   LogicalResult checkDefineNamedDecl(const ast::Name &name);
 
   FailureOr<ast::VariableDecl *>
@@ -417,9 +421,12 @@
     consumeToken(Token::identifier);
   }
 
-  // TODO: Parse any pattern metadata.
+  // Parse any pattern metadata.
   Optional<uint16_t> benefit;
   bool hasBoundedRecursion = false;
+  if (consumeIf(Token::kw_with) &&
+      failed(parsePatternDeclMetadata(benefit, hasBoundedRecursion)))
+    return failure();
 
   // Parse the pattern body.
   ast::CompoundStmt *body;
@@ -452,6 +459,65 @@
   return createPatternDecl(loc, name, benefit, hasBoundedRecursion, body);
 }
 
+LogicalResult Parser::parsePatternDeclMetadata(Optional<uint16_t> &benefit,
+                                               bool &hasBoundedRecursion) {
+  Optional<llvm::SMRange> benefitLoc;
+  Optional<llvm::SMRange> hasBoundedRecursionLoc;
+
+  do {
+    if (curToken.isNot(Token::identifier))
+      return emitError("expected pattern metadata identifier");
+    StringRef metadata = curToken.getSpelling();
+    llvm::SMRange metadataLoc = curToken.getLoc();
+    consumeToken(Token::identifier);
+
+    // Parse the benefit metadata: benefit(<integer-value>)
+    if (metadata == "benefit") {
+      if (benefitLoc) {
+        return emitErrorAndNote(metadataLoc,
+                                "pattern benefit has already been specified",
+                                *benefitLoc, "see previous definition here");
+      }
+      if (failed(parseToken(Token::l_paren,
+                            "expected `(` before pattern benefit")))
+        return failure();
+
+      uint16_t benefitValue = 0;
+      if (curToken.isNot(Token::integer))
+        return emitError("expected integral pattern benefit");
+      if (curToken.getSpelling().getAsInteger(/*Radix=*/10, benefitValue))
+        return emitError(
+            "expected pattern benefit to fit within a 16-bit integer");
+      consumeToken(Token::integer);
+
+      benefit = benefitValue;
+      benefitLoc = metadataLoc;
+
+      if (failed(
+              parseToken(Token::r_paren, "expected `)` after pattern benefit")))
+        return failure();
+      continue;
+    }
+
+    // Parse the bounded recursion metadata: recursion
+    if (metadata == "recursion") {
+      if (hasBoundedRecursionLoc) {
+        return emitErrorAndNote(
+            metadataLoc,
+            "pattern recursion metadata has already been specified",
+            *hasBoundedRecursionLoc, "see previous definition here");
+      }
+      hasBoundedRecursion = true;
+      hasBoundedRecursionLoc = metadataLoc;
+      continue;
+    }
+
+    return emitError(metadataLoc, "unknown pattern metadata");
+  } while (consumeIf(Token::comma));
+
+  return success();
+}
+
 FailureOr<ast::Expr *> Parser::parseTypeConstraintExpr() {
   consumeToken(Token::less);
 
diff --git a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
--- a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
@@ -24,3 +24,49 @@
   erase root: Op;
   let value: Value;
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Metadata
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected pattern metadata identifier
+Pattern with {}
+
+// -----
+
+// CHECK: unknown pattern metadata
+Pattern with unknown {}
+
+// -----
+
+// CHECK: expected `(` before pattern benefit
+Pattern with benefit) {}
+
+// -----
+
+// CHECK: expected integral pattern benefit
+Pattern with benefit(foo) {}
+
+// -----
+
+// CHECK: expected pattern benefit to fit within a 16-bit integer
+Pattern with benefit(65536) {}
+
+// -----
+
+// CHECK: expected `)` after pattern benefit
+Pattern with benefit(1( {}
+
+// -----
+
+// CHECK: pattern benefit has already been specified
+// CHECK: see previous definition here
+Pattern with benefit(1), benefit(1) {}
+
+// -----
+
+// CHECK: pattern recursion metadata has already been specified
+// CHECK: see previous definition here
+Pattern with recursion, recursion {}
diff --git a/mlir/test/mlir-pdll/Parser/pattern.pdll b/mlir/test/mlir-pdll/Parser/pattern.pdll
--- a/mlir/test/mlir-pdll/Parser/pattern.pdll
+++ b/mlir/test/mlir-pdll/Parser/pattern.pdll
@@ -15,3 +15,11 @@
 Pattern NamedPattern {
   erase _: Op;
 }
+
+// -----
+
+// CHECK: Module
+// CHECK: `-PatternDecl {{.*}} Name<NamedPattern> Benefit<10> Recursion
+Pattern NamedPattern with benefit(10), recursion {
+  erase _: Op;
+}