diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -45,7 +45,9 @@ ``` }]; - let arguments = (ins StrAttr:$name, Variadic:$args); + let arguments = (ins StrAttr:$name, + Variadic:$args, + DefaultValuedAttr:$isNegated); let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict"; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -90,7 +90,8 @@ `pdl_interp.apply_constraint` operations apply a generic constraint, that has been registered with the interpreter, with a given set of positional values. On success, this operation branches to the true destination, - otherwise the false destination is taken. + otherwise the false destination is taken. This behavior can be reversed + by setting the attribute `isNegated` to true. Example: @@ -101,7 +102,9 @@ ``` }]; - let arguments = (ins StrAttr:$name, Variadic:$args); + let arguments = (ins StrAttr:$name, + Variadic:$args, + DefaultValuedAttr:$isNegated); let assemblyFormat = [{ $name `(` $args `:` type($args) `)` attr-dict `->` successors }]; diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -447,8 +447,9 @@ } case Predicates::ConstraintQuestion: { auto *cstQuestion = cast(question); - builder.create(loc, cstQuestion->getName(), - args, success, failure); + builder.create( + loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success, + failure); break; } default: diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -450,7 +450,7 @@ /// Apply a parameterized constraint to multiple position values. struct ConstraintQuestion : public PredicateBase>, + std::tuple, bool>, Predicates::ConstraintQuestion> { using Base::Base; @@ -460,11 +460,20 @@ /// Return the arguments of the constraint. ArrayRef getArgs() const { return std::get<1>(key); } + /// Return the negation status of the constraint. + bool getIsNegated() const { return std::get<2>(key); } + /// Construct an instance with the given storage allocator. static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, KeyTy key) { return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), - alloc.copyInto(std::get<1>(key))}); + alloc.copyInto(std::get<1>(key)), + std::get<2>(key)}); + } + + /// Returns a hash suitable for the given keytype. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); } }; @@ -664,9 +673,11 @@ } /// Create a predicate that applies a generic constraint. - Predicate getConstraint(StringRef name, ArrayRef pos) { - return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)), - TrueAnswer::get(uniquer)}; + Predicate getConstraint(StringRef name, ArrayRef pos, + bool isNegated) { + return { + ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)), + TrueAnswer::get(uniquer)}; } /// Create a predicate comparing a value with null. diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -273,7 +273,7 @@ Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), comparePosDepth); PredicateBuilder::Predicate pred = - builder.getConstraint(op.getName(), allPositions); + builder.getConstraint(op.getName(), allPositions, op.getIsNegated()); predList.emplace_back(pos, pred); } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -79,6 +79,20 @@ // ----- +// CHECK-LABEL: module @negated_constraint +module @negated_constraint { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: pdl_interp.apply_constraint "constraint"(%[[ROOT]] : !pdl.operation) {isNegated = true} + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation) + pdl.pattern : benefit(1) { + %root = operation + pdl.apply_native_constraint "constraint"(%root : !pdl.operation) {isNegated = true} + rewrite %root with "rewriter" + } +} + +// ----- + // CHECK-LABEL: module @inputs module @inputs { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)