diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h --- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h @@ -390,7 +390,8 @@ private llvm::TrailingObjects { public: static CallExpr *create(Context &ctx, SMRange loc, Expr *callable, - ArrayRef arguments, Type resultType); + ArrayRef arguments, Type resultType, + bool isNegated = false); /// Return the callable of this call. Expr *getCallableExpr() const { return callable; } @@ -403,9 +404,14 @@ return const_cast(this)->getArguments(); } + /// Returns whether the result of this call is to be negated. + bool getIsNegated() const { return isNegated; } + private: - CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs) - : Base(loc, type), callable(callable), numArgs(numArgs) {} + CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs, + bool isNegated) + : Base(loc, type), callable(callable), numArgs(numArgs), + isNegated(isNegated) {} /// The callable of this call. Expr *callable; @@ -415,6 +421,9 @@ /// TrailingObject utilities. friend llvm::TrailingObjects; + + /// Whether the result of this call is to be negated. 
+ bool isNegated; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp --- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp +++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp @@ -225,7 +225,10 @@ void NodePrinter::printImpl(const CallExpr *expr) { os << "CallExpr " << expr << " Type<"; print(expr->getType()); - os << ">\n"; + os << ">"; + if (expr->getIsNegated()) + os << " Negated"; + os << "\n"; printChildren(expr->getCallableExpr()); printChildren("Arguments", expr->getArguments()); } diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp --- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp +++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp @@ -266,12 +266,13 @@ //===----------------------------------------------------------------------===// CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable, - ArrayRef arguments, Type resultType) { + ArrayRef arguments, Type resultType, + bool isNegated) { unsigned allocSize = CallExpr::totalSizeToAlloc(arguments.size()); void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr)); - CallExpr *expr = - new (rawData) CallExpr(loc, resultType, callable, arguments.size()); + CallExpr *expr = new (rawData) + CallExpr(loc, resultType, callable, arguments.size(), isNegated); std::uninitialized_copy(arguments.begin(), arguments.end(), expr->getArguments().begin()); return expr; diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -103,12 +103,14 @@ Value genExprImpl(const ast::TypeExpr *expr); SmallVector genConstraintCall(const ast::UserConstraintDecl *decl, - Location loc, ValueRange inputs); + Location loc, ValueRange inputs, + bool isNegated = false); SmallVector genRewriteCall(const ast::UserRewriteDecl *decl, Location loc, ValueRange inputs); 
template SmallVector genConstraintOrRewriteCall(const T *decl, Location loc, - ValueRange inputs); + ValueRange inputs, + bool isNegated = false); //===--------------------------------------------------------------------===// // Fields @@ -419,7 +421,7 @@ // Generate the PDL based on the type of callable. const ast::Decl *callable = callableExpr->getDecl(); if (const auto *decl = dyn_cast(callable)) - return genConstraintCall(decl, loc, arguments); + return genConstraintCall(decl, loc, arguments, expr->getIsNegated()); if (const auto *decl = dyn_cast(callable)) return genRewriteCall(decl, loc, arguments); llvm_unreachable("unhandled CallExpr callable"); @@ -547,15 +549,15 @@ SmallVector CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc, - ValueRange inputs) { + ValueRange inputs, bool isNegated) { // Apply any constraints defined on the arguments to the input values. for (auto it : llvm::zip(decl->getInputs(), inputs)) applyVarConstraints(std::get<0>(it), std::get<1>(it)); // Generate the constraint call. SmallVector results = - genConstraintOrRewriteCall(decl, loc, - inputs); + genConstraintOrRewriteCall( + decl, loc, inputs, isNegated); // Apply any constraints defined on the results of the constraint. for (auto it : llvm::zip(decl->getResults(), results)) @@ -570,9 +572,9 @@ } template -SmallVector CodeGen::genConstraintOrRewriteCall(const T *decl, - Location loc, - ValueRange inputs) { +SmallVector +CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc, + ValueRange inputs, bool isNegated) { const ast::CompoundStmt *cstBody = decl->getBody(); // If the decl doesn't have a statement body, it is a native decl. 
@@ -585,8 +587,12 @@ } else { resultTypes.push_back(genType(declResultType)); } - Operation *pdlOp = builder.create( + PDLOpT pdlOp = builder.create( loc, resultTypes, decl->getName().getName(), inputs); + // Only the constraint op carries `isNegated`; select at compile time so the + // setter is never instantiated for the rewrite op. + if constexpr (std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>) + if (isNegated) + pdlOp.setIsNegated(true); return pdlOp->getResults(); } diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h --- a/mlir/lib/Tools/PDLL/Parser/Lexer.h +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h @@ -57,6 +57,7 @@ kw_erase, kw_let, kw_Constraint, + kw_not, kw_Op, kw_OpName, kw_Pattern, diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp --- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp @@ -315,6 +315,7 @@ .Case("erase", Token::kw_erase) .Case("let", Token::kw_let) .Case("Constraint", Token::kw_Constraint) + .Case("not", Token::kw_not) .Case("op", Token::kw_op) .Case("Op", Token::kw_Op) .Case("OpName", Token::kw_OpName) diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -315,12 +315,14 @@ /// Identifier expressions. 
FailureOr parseAttributeExpr(); - FailureOr parseCallExpr(ast::Expr *parentExpr); + FailureOr parseCallExpr(ast::Expr *parentExpr, + bool isNegated = false); FailureOr parseDeclRefExpr(StringRef name, SMRange loc); FailureOr parseIdentifierExpr(); FailureOr parseInlineConstraintLambdaExpr(); FailureOr parseInlineRewriteLambdaExpr(); FailureOr parseMemberAccessExpr(ast::Expr *parentExpr); + FailureOr parseNegatedExpr(); FailureOr parseOperationName(bool allowEmptyName = false); FailureOr parseWrappedOperationName(bool allowEmptyName); FailureOr @@ -405,7 +407,8 @@ FailureOr createCallExpr(SMRange loc, ast::Expr *parentExpr, - MutableArrayRef arguments); + MutableArrayRef arguments, + bool isNegated = false); FailureOr createDeclRefExpr(SMRange loc, ast::Decl *decl); FailureOr createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, @@ -1805,6 +1808,9 @@ case Token::kw_Constraint: lhsExpr = parseInlineConstraintLambdaExpr(); break; + case Token::kw_not: + lhsExpr = parseNegatedExpr(); + break; case Token::identifier: lhsExpr = parseIdentifierExpr(); break; @@ -1866,7 +1872,8 @@ return ast::AttributeExpr::create(ctx, loc, attrExpr); } -FailureOr Parser::parseCallExpr(ast::Expr *parentExpr) { +FailureOr Parser::parseCallExpr(ast::Expr *parentExpr, + bool isNegated) { consumeToken(Token::l_paren); // Parse the arguments of the call. 
@@ -1890,7 +1897,7 @@ if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) return failure(); - return createCallExpr(loc, parentExpr, arguments); + return createCallExpr(loc, parentExpr, arguments, isNegated); } FailureOr Parser::parseDeclRefExpr(StringRef name, SMRange loc) { @@ -1959,6 +1966,20 @@ return createMemberAccessExpr(parentExpr, memberName, loc); } +FailureOr Parser::parseNegatedExpr() { + consumeToken(Token::kw_not); + // Only native constraints are supported after negation. + if (!curToken.is(Token::identifier)) + return emitError("expected native constraint"); + FailureOr identifierExpr = parseIdentifierExpr(); + if (failed(identifierExpr)) + return failure(); + // `parseCallExpr` consumes the `(` unconditionally, so check for it here. + if (!curToken.is(Token::l_paren)) + return emitError("expected `(` after function name"); + return parseCallExpr(*identifierExpr, /*isNegated=*/true); +} + FailureOr Parser::parseOperationName(bool allowEmptyName) { SMRange loc = curToken.getLoc(); @@ -2672,7 +2693,7 @@ FailureOr Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, - MutableArrayRef arguments) { + MutableArrayRef arguments, bool isNegated) { ast::Type parentType = parentExpr->getType(); ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); @@ -2686,8 +2707,14 @@ if (isa(callableDecl)) return emitError( loc, "unable to invoke `Constraint` within a rewrite section"); - } else if (isa(callableDecl)) { - return emitError(loc, "unable to invoke `Rewrite` within a match section"); + if (isNegated) + return emitError(loc, "unable to negate a Rewrite"); + } else { + if (isa(callableDecl)) + return emitError(loc, + "unable to invoke `Rewrite` within a match section"); + if (isNegated && cast(callableDecl)->getBody()) + return emitError(loc, "unable to negate non native constraints"); } // Verify the arguments of the call. 
@@ -2718,7 +2745,7 @@ } return ast::CallExpr::create(ctx, loc, parentExpr, arguments, - callableDecl->getResultType()); + callableDecl->getResultType(), isNegated); } FailureOr Parser::createDeclRefExpr(SMRange loc, diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -36,6 +36,20 @@ // ----- +// CHECK: pdl.pattern @TestExternalNegatedCall +// CHECK: %[[ROOT:.*]] = operation +// CHECK: apply_native_constraint "TestConstraint"(%[[ROOT]] : !pdl.operation) {isNegated = true} +// CHECK: rewrite %[[ROOT]] +// CHECK: erase %[[ROOT]] +Constraint TestConstraint(op: Op); +Pattern TestExternalNegatedCall { + let root = op : Op; + not TestConstraint(root); + erase root; +} + +// ----- + //===----------------------------------------------------------------------===// // MemberAccessExpr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -173,6 +173,25 @@ // ----- +Constraint Foo(op: Op) {} + +Pattern { + // CHECK: unable to negate non native constraints + let root = op<>; + not Foo(root); +} + +// ----- + +Constraint Foo(op: Op); + +Pattern { + // CHECK: expected `(` after function name + not Foo; +} + +// ----- + Rewrite Foo(); Pattern { @@ -183,6 +202,18 @@ // ----- +Rewrite Foo(op: Op); + +Pattern { + // CHECK: unable to negate a Rewrite + let root = op<>; + rewrite root with { + not Foo(root); + } +} + +// ----- + Pattern { // CHECK: expected expression let tuple = (10 = _: Value); diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -50,6 +50,22 @@ // ----- +// CHECK: Module {{.*}} +// CHECK: -UserConstraintDecl {{.*}} Name ResultType> +// CHECK: `-PatternDecl {{.*}} +// 
CHECK: -CallExpr {{.*}} Type> Negated +// CHECK: `-DeclRefExpr {{.*}} Type +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType> +Constraint TestConstraint(op: Op); + +Pattern { + let inputOp = op; + not TestConstraint(inputOp); + erase inputOp; +} + +// ----- + //===----------------------------------------------------------------------===// // MemberAccessExpr //===----------------------------------------------------------------------===//