diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -214,20 +214,24 @@
   /// within that pattern.
   DenseMap<Operation *, Qualifier *> patternToAnswer;
 
-  /// Returns true if this predicate is ordered before `other`, based on the
-  /// cost model.
-  bool operator<(const OrderedPredicate &other) const {
+  /// Returns true if this predicate is ordered before `rhs`, based on the cost
+  /// model.
+  bool operator<(const OrderedPredicate &rhs) const {
     // Sort by:
-    // * first and secondary order sums
+    // * higher first and secondary order sums
     // * lower depth
-    // * position dependency
-    // * predicate dependency.
-    auto *otherPos = other.position;
-    return std::make_tuple(other.primary, other.secondary,
-                           otherPos->getIndex().size(), otherPos->getKind(),
-                           other.question->getKind()) >
-           std::make_tuple(primary, secondary, position->getIndex().size(),
-                           position->getKind(), question->getKind());
+    // * lower position dependency
+    // * lower predicate dependency
+    //
+    // The depth and kind entries are deliberately taken from `rhs` in the
+    // left-hand tuple (and from `this` in the right-hand one), so a single
+    // `>` is descending on the sums but ascending on everything else.
+    auto *rhsPos = rhs.position;
+    return std::make_tuple(primary, secondary, rhsPos->getIndex().size(),
+                           rhsPos->getKind(), rhs.question->getKind()) >
+           std::make_tuple(rhs.primary, rhs.secondary,
+                           position->getIndex().size(), position->getKind(),
+                           question->getKind());
   }
 };
 
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -143,3 +143,31 @@
     pdl.rewrite %root with "rewriter"
   }
 }
+
+// -----
+
+// CHECK-LABEL: module @predicate_ordering
+module @predicate_ordering {
+  // Check that the result is checked for null first, before applying the
+  // constraint. The null check is prevalent in more patterns, so should be
+  // prioritized first.
+
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
+  // CHECK-NEXT: pdl_interp.is_not_null %[[RESULT]]
+  // CHECK: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
+  // CHECK: pdl_interp.apply_constraint "typeConstraint" [](%[[RESULT_TYPE]]
+
+  pdl.pattern : benefit(1) {
+    %resultType = pdl.type
+    pdl.apply_constraint "typeConstraint"[](%resultType : !pdl.type)
+    %root, %result = pdl.operation -> %resultType
+    pdl.rewrite %root with "rewriter"
+  }
+
+  pdl.pattern : benefit(1) {
+    %resultType = pdl.type
+    %apply, %applyRes = pdl.operation -> %resultType
+    pdl.rewrite %apply with "rewriter"
+  }
+}