Index: mlir/include/mlir/Dialect/PDL/IR/PDLOps.td =================================================================== --- mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -35,7 +35,7 @@ let description = [{ `pdl.apply_native_constraint` operations apply a native C++ constraint, that has been registered externally with the consumer of PDL, to a given set of - entities. + entities and optionally return a number of values. Example: @@ -46,7 +46,8 @@ }]; let arguments = (ins StrAttr:$name, Variadic:$args); - let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict"; + let results = (outs Variadic:$results); + let assemblyFormat = "$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict"; let hasVerifier = 1; } Index: mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td =================================================================== --- mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -89,8 +89,9 @@ let description = [{ `pdl_interp.apply_constraint` operations apply a generic constraint, that has been registered with the interpreter, with a given set of positional - values. On success, this operation branches to the true destination, - otherwise the false destination is taken. + values. The constraint function may return any number of results. + On success, this operation branches to the true destination, otherwise + the false destination is taken. Example: @@ -102,8 +103,9 @@ }]; let arguments = (ins StrAttr:$name, Variadic:$args); + let results = (outs Variadic:$results); let assemblyFormat = [{ - $name `(` $args `:` type($args) `)` attr-dict `->` successors + $name `(` $args `:` type($args) `)` (`:` type($results)^)? 
attr-dict `->` successors }]; } Index: mlir/include/mlir/IR/PatternMatch.h =================================================================== --- mlir/include/mlir/IR/PatternMatch.h +++ mlir/include/mlir/IR/PatternMatch.h @@ -1572,6 +1572,19 @@ std::forward(constraintFn))); } + /// Register a constraint function that produces results with PDL. A + /// constraint function with results uses the same registry as rewrite + /// functions. It may be specified as follows: + /// + /// * `LogicalResult (PatternRewriter &, PDLResultList &, + /// ArrayRef)` + /// + /// The arguments of the constraint function are passed via the low-level + /// PDLValue form, and the results are manually appended to the given result + /// list. + void registerConstraintFunctionWithResults(StringRef name, + PDLRewriteFunction constraintFn); + /// Register a rewrite function with PDL. A rewrite function may be specified /// in one of two ways: /// Index: mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp =================================================================== --- mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -148,6 +148,12 @@ /// A mapping between pattern operations and the corresponding configuration /// set. DenseMap *configMap; + + /// A mapping from a constraint question and result index that together + /// refer to a value created by a constraint to the temporary placeholder + /// values created for them. + std::multimap, Value> + constraintResultMap; }; } // namespace @@ -364,6 +370,21 @@ loc, cast(rawTypeAttr)); break; } + case Predicates::ConstraintResultPos: { + // The corresponding pdl.ApplyNativeConstraint op has already been deleted + // and the new pdl_interp.ApplyConstraint has not been created yet. To + // enable referring to results created by this operation we build a + // placeholder value that will be replaced when the actual + // pdl_interp.ApplyConstraint operation is created. 
+ auto *constrResPos = cast(pos); + Value placeholderValue = builder.create( + loc, StringAttr::get(builder.getContext(), "placeholder")); + constraintResultMap.insert( + {{constrResPos->getQuestion(), constrResPos->getIndex()}, + placeholderValue}); + value = placeholderValue; + break; + } default: llvm_unreachable("Generating unknown Position getter"); break; @@ -447,8 +468,25 @@ } case Predicates::ConstraintQuestion: { auto *cstQuestion = cast(question); - builder.create(loc, cstQuestion->getName(), - args, success, failure); + auto applyConstraintOp = builder.create( + loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args, + success, failure); + // Replace the generated placeholders with the results of the constraint and + // erase them + for (auto result : llvm::enumerate(applyConstraintOp.getResults())) { + std::pair substitutionKey = { + cstQuestion, result.index()}; + // Check if there are substitutions to perform. If the result is never + // used or multiple calls to the same constraint have been merged, + // no substitutions will have been generated for this specific op. 
+ auto range = constraintResultMap.equal_range(substitutionKey); + std::for_each(range.first, range.second, [&](const auto &elem) { + Value placeholder = elem.second; + placeholder.replaceAllUsesWith(result.value()); + placeholder.getDefiningOp()->erase(); + }); + constraintResultMap.erase(substitutionKey); + } break; } default: Index: mlir/lib/Conversion/PDLToPDLInterp/Predicate.h =================================================================== --- mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -47,6 +47,7 @@ OperandPos, OperandGroupPos, AttributePos, + ConstraintResultPos, ResultPos, ResultGroupPos, TypePos, @@ -187,6 +188,28 @@ using PredicateBase::PredicateBase; }; +//===----------------------------------------------------------------------===// +// ConstraintPosition + +struct ConstraintQuestion; + +/// A position describing the result of a native constraint. It saves the +/// corresponding ConstraintQuestion and result index to enable referring +/// back to them +struct ConstraintPosition + : public PredicateBase, + Predicates::ConstraintResultPos> { + using PredicateBase::PredicateBase; + + /// Returns the ConstraintQuestion to enable keeping track of the native + /// constraint this position stems from. + ConstraintQuestion *getQuestion() const { return key.first; } + + // Returns the result index of this position + unsigned getIndex() const { return key.second; } +}; + //===----------------------------------------------------------------------===// // ForEachPosition @@ -447,11 +470,13 @@ : public PredicateBase {}; -/// Apply a parameterized constraint to multiple position values. +/// Apply a parameterized constraint to multiple position values and possibly +/// produce results. 
struct ConstraintQuestion - : public PredicateBase>, - Predicates::ConstraintQuestion> { + : public PredicateBase< + ConstraintQuestion, Qualifier, + std::tuple, ArrayRef>, + Predicates::ConstraintQuestion> { using Base::Base; /// Return the name of the constraint. @@ -460,11 +485,15 @@ /// Return the arguments of the constraint. ArrayRef getArgs() const { return std::get<1>(key); } + /// Return the result types of the constraint. + ArrayRef getResultTypes() const { return std::get<2>(key); } + /// Construct an instance with the given storage allocator. static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, KeyTy key) { return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), - alloc.copyInto(std::get<1>(key))}); + alloc.copyInto(std::get<1>(key)), + alloc.copyInto(std::get<2>(key))}); } }; @@ -517,6 +546,7 @@ // Register the types of Positions with the uniquer. registerParametricStorageType(); registerParametricStorageType(); + registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); registerParametricStorageType(); @@ -579,6 +609,12 @@ return OperationPosition::get(uniquer, p); } + // Returns a position for a new value created by a constraint. + ConstraintPosition *getConstraintPosition(ConstraintQuestion *q, + unsigned index) { + return ConstraintPosition::get(uniquer, std::make_pair(q, index)); + } + /// Returns an attribute position for an attribute of the given operation. Position *getAttribute(OperationPosition *p, StringRef name) { return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); @@ -664,8 +700,10 @@ } /// Create a predicate that applies a generic constraint. 
- Predicate getConstraint(StringRef name, ArrayRef pos) { - return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)), + Predicate getConstraint(StringRef name, ArrayRef args, + ArrayRef resultTypes) { + return {ConstraintQuestion::get(uniquer, + std::make_tuple(name, args, resultTypes)), TrueAnswer::get(uniquer)}; } Index: mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp =================================================================== --- mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -272,8 +272,16 @@ // Push the constraint to the furthest position. Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), comparePosDepth); - PredicateBuilder::Predicate pred = - builder.getConstraint(op.getName(), allPositions); + ResultRange results = op.getResults(); + PredicateBuilder::Predicate pred = builder.getConstraint( + op.getName(), allPositions, SmallVector(results.getTypes())); + + // for each result register a position so it can be used later + for (auto result : llvm::enumerate(results)) { + ConstraintQuestion *q = cast(pred.first); + ConstraintPosition *pos = builder.getConstraintPosition(q, result.index()); + inputs[result.value()] = pos; + } predList.emplace_back(pos, pred); } Index: mlir/lib/Dialect/PDL/IR/PDL.cpp =================================================================== --- mlir/lib/Dialect/PDL/IR/PDL.cpp +++ mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -94,6 +94,12 @@ LogicalResult ApplyNativeConstraintOp::verify() { if (getNumOperands() == 0) return emitOpError("expected at least one argument"); + if (llvm::any_of(getResults(), [](OpResult result) { + return result.getType().isa(); + })) { + return emitOpError( + "returning an operation from a constraint is not supported"); + } return success(); } Index: mlir/lib/IR/PatternMatch.cpp =================================================================== --- mlir/lib/IR/PatternMatch.cpp +++ 
mlir/lib/IR/PatternMatch.cpp @@ -204,6 +204,15 @@ constraintFunctions.try_emplace(name, std::move(constraintFn)); } +void PDLPatternModule::registerConstraintFunctionWithResults( + StringRef name, PDLRewriteFunction constraintFn) { + // TODO: Is it possible to diagnose when `name` is already registered to + // a function that is not equivalent to `constraintFn`? + // Allow existing mappings in the case multiple patterns depend on the same + // constraint. + registerRewriteFunction(name, std::move(constraintFn)); +} + void PDLPatternModule::registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) { // TODO: Is it possible to diagnose when `name` is already registered to Index: mlir/lib/Rewrite/ByteCode.cpp =================================================================== --- mlir/lib/Rewrite/ByteCode.cpp +++ mlir/lib/Rewrite/ByteCode.cpp @@ -769,10 +769,35 @@ void Generator::generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer) { - assert(constraintToMemIndex.count(op.getName()) && - "expected index for constraint function"); - writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); + // Constraints that should return a value have to be registered as rewrites. + // The table that is consulted depends only on whether the op has results: + // ops without results use the constraints, ops with results the rewrites. + ResultRange results = op.getResults(); + if (results.size() == 0 && constraintToMemIndex.count(op.getName()) != 0) { + writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); + } else if (results.size() > 0 && + externalRewriterToMemIndex.count(op.getName()) != 0) { + writer.append(OpCode::ApplyConstraint, + externalRewriterToMemIndex[op.getName()]); + } else { + assert(false && "expected index for constraint function, make sure it is " + "registered properly. 
Note that native constraints with " + "results have to be registered using " + "PDLPatternModule::registerConstraintFunctionWithResults."); + } writer.appendPDLValueList(op.getArgs()); + writer.append(ByteCodeField(results.size())); + for (Value result : results) { + // We record the expected kind of the result, so that we can provide extra + // verification of the native rewrite function and handle the failure case + // of constraints accordingly. + writer.appendPDLValueKind(result); + + // Range results also need to append the range storage index. + if (result.getType().isa()) + writer.append(getRangeStorageIndex(result)); + writer.append(result); + } writer.append(op.getSuccessors()); } void Generator::generate(pdl_interp::ApplyRewriteOp op, @@ -785,14 +810,12 @@ ResultRange results = op.getResults(); writer.append(ByteCodeField(results.size())); for (Value result : results) { - // In debug mode we also record the expected kind of the result, so that we + // We record the expected kind of the result, so that we // can provide extra verification of the native rewrite function. -#ifndef NDEBUG writer.appendPDLValueKind(result); -#endif // Range results also need to append the range storage index. - if (isa(result.getType())) + if (result.getType().isa()) writer.append(getRangeStorageIndex(result)); writer.append(result); } @@ -1075,6 +1098,28 @@ // ByteCode Execution namespace { +/// This class is an instantiation of the PDLResultList that provides access to +/// the returned results. This API is not on `PDLResultList` to avoid +/// overexposing access to information specific solely to the ByteCode. +class ByteCodeRewriteResultList : public PDLResultList { +public: + ByteCodeRewriteResultList(unsigned maxNumResults) + : PDLResultList(maxNumResults) {} + + /// Return the list of PDL results. + MutableArrayRef getResults() { return results; } + + /// Return the type ranges allocated by this list. 
+ MutableArrayRef> getAllocatedTypeRanges() { + return allocatedTypeRanges; + } + + /// Return the value ranges allocated by this list. + MutableArrayRef> getAllocatedValueRanges() { + return allocatedValueRanges; + } +}; + /// This class provides support for executing a bytecode stream. class ByteCodeExecutor { public: @@ -1152,6 +1197,11 @@ void executeSwitchType(); void executeSwitchTypes(); + // Helper for native constraints and rewrites + void processNativeFunResults(ByteCodeRewriteResultList &results, + unsigned numResults, + LogicalResult &rewriteResult); + /// Pushes a code iterator to the stack. void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } @@ -1224,6 +1274,9 @@ return T::getFromOpaquePointer(pointer); } + /// Skip a list of values on the bytecode buffer. + void skip(size_t skipN) { curCodeIt += skipN; } + /// Jump to a specific successor based on a predicate value. void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } /// Jump to a specific successor based on a destination index. @@ -1380,43 +1433,49 @@ ArrayRef constraintFunctions; ArrayRef rewriteFunctions; }; - -/// This class is an instantiation of the PDLResultList that provides access to -/// the returned results. This API is not on `PDLResultList` to avoid -/// overexposing access to information specific solely to the ByteCode. -class ByteCodeRewriteResultList : public PDLResultList { -public: - ByteCodeRewriteResultList(unsigned maxNumResults) - : PDLResultList(maxNumResults) {} - - /// Return the list of PDL results. - MutableArrayRef getResults() { return results; } - - /// Return the type ranges allocated by this list. - MutableArrayRef> getAllocatedTypeRanges() { - return allocatedTypeRanges; - } - - /// Return the value ranges allocated by this list. 
- MutableArrayRef> getAllocatedValueRanges() { - return allocatedValueRanges; - } -}; } // namespace void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); - const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; + ByteCodeField fun_idx = read(); SmallVector args; readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n"; }); - // Invoke the constraint and jump to the proper destination. - selectJump(succeeded(constraintFn(rewriter, args))); + ByteCodeField numResults = read(); + if (numResults == 0) { + const PDLConstraintFunction &constraintFn = constraintFunctions[fun_idx]; + LogicalResult rewriteResult = constraintFn(rewriter, args); + // Depending on the constraint jump to the proper destination. + selectJump(succeeded(rewriteResult)); + return; + } + const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx]; + ByteCodeRewriteResultList results(numResults); + LogicalResult rewriteResult = constraintFn(rewriter, results, args); + ArrayRef constraintResults = results.getResults(); + LLVM_DEBUG({ + if (succeeded(rewriteResult)) { + llvm::dbgs() << " * Constraint succeeded\n"; + llvm::dbgs() << " * Results: "; + llvm::interleaveComma(constraintResults, llvm::dbgs()); + llvm::dbgs() << "\n"; + } else { + llvm::dbgs() << " * Constraint failed\n"; + } + }); + assert((failed(rewriteResult) || constraintResults.size() == numResults) && + "native PDL rewrite function succeeded but returned " + "unexpected number of results"); + processNativeFunResults(results, numResults, rewriteResult); + + // Depending on the constraint jump to the proper destination. 
+ selectJump(succeeded(rewriteResult)); } LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { @@ -1438,15 +1497,39 @@ assert(results.getResults().size() == numResults && "native PDL rewrite function returned unexpected number of results"); - // Store the results in the bytecode memory. - for (PDLValue &result : results.getResults()) { - LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); + processNativeFunResults(results, numResults, rewriteResult); -// In debug mode we also verify the expected kind of the result. -#ifndef NDEBUG - assert(result.getKind() == read() && - "native PDL rewrite function returned an unexpected type of result"); -#endif + if (failed(rewriteResult)) { + LLVM_DEBUG(llvm::dbgs() << " - Failed"); + return failure(); + } + return success(); +} + +void ByteCodeExecutor::processNativeFunResults( + ByteCodeRewriteResultList &results, unsigned numResults, + LogicalResult &rewriteResult) { + // Store the results in the bytecode memory or handle missing results on + // failure. + for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) { + PDLValue::Kind resultKind = read(); + + // On failure there are no result values to store: skip this result's + // remaining bytecode fields and move on so every result entry is consumed. + if (failed(rewriteResult)) { + if (resultKind == PDLValue::Kind::TypeRange || + resultKind == PDLValue::Kind::ValueRange) { + skip(2); + } else { + skip(1); + } + continue; + } + PDLValue result = results.getResults()[resultIdx]; + LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); + assert(result.getKind() == resultKind && + "native PDL rewrite function returned an unexpected type of " + "result"); // If the result is a range, we need to copy it over to the bytecodes // range memory. @@ -1469,13 +1552,6 @@ allocatedTypeRangeMemory.push_back(std::move(it)); for (auto &it : results.getAllocatedValueRanges()) allocatedValueRangeMemory.push_back(std::move(it)); - - // Process the result of the rewrite. 
- if (failed(rewriteResult)) { - LLVM_DEBUG(llvm::dbgs() << " - Failed"); - return failure(); - } - return success(); } void ByteCodeExecutor::executeAreEqual() { Index: mlir/lib/Tools/PDLL/Parser/Parser.cpp =================================================================== --- mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -1359,12 +1359,6 @@ if (failed(parseToken(Token::semicolon, "expected `;` after native declaration"))) return failure(); - // TODO: PDL should be able to support constraint results in certain - // situations, we should revise this. - if (std::is_same::value && !results.empty()) { - return emitError( - "native Constraints currently do not support returning results"); - } return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); } Index: mlir/python/mlir/dialects/_pdl_ops_ext.py =================================================================== --- mlir/python/mlir/dialects/_pdl_ops_ext.py +++ mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -21,6 +21,7 @@ def __init__( self, name: Union[str, StringAttr], args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + results: Optional[Sequence[Type]] = None, *, loc=None, @@ -28,8 +29,11 @@ ): if args is None: args = [] + if results is None: + results = [] args = _get_values(args) - super().__init__(name, args, loc=loc, ip=ip) + results = _get_values(results) + super().__init__(results, name, args, loc=loc, ip=ip) class ApplyNativeRewriteOp: Index: mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir =================================================================== --- mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -79,6 +79,57 @@ // ----- +// CHECK-LABEL: module @constraint_with_result +module @constraint_with_result { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint 
"check_op_and_get_attr_constr"(%[[ROOT]] + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute) + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter"(%attr : !pdl.attribute) + } +} + +// ----- + +// CHECK-LABEL: module @constraint_with_unused_result +module @constraint_with_unused_result { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]] + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation) + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter" + } +} + +// ----- + +// CHECK-LABEL: module @constraint_with_result_multiple +module @constraint_with_result_multiple { + // check that native constraints work as expected even when multiple identical constraints are fused + + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]] + // CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr" + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute) + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute) + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter"(%attr : !pdl.attribute) + } + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : 
!pdl.operation) : !pdl.attribute + rewrite %root with "rewriter"(%attr : !pdl.attribute) + } +} + +// ----- + // CHECK-LABEL: module @inputs module @inputs { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) Index: mlir/test/Dialect/PDL/ops.mlir =================================================================== --- mlir/test/Dialect/PDL/ops.mlir +++ mlir/test/Dialect/PDL/ops.mlir @@ -134,6 +134,24 @@ // ----- +pdl.pattern @apply_constraint_with_no_results : benefit(1) { + %root = operation + apply_native_constraint "NativeConstraint"(%root : !pdl.operation) + rewrite %root with "rewriter" +} + +// ----- + +pdl.pattern @apply_constraint_with_results : benefit(1) { + %root = operation + %attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute + rewrite %root { + apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute) + } +} + +// ----- + pdl.pattern @attribute_with_dict : benefit(1) { %root = operation rewrite %root { Index: mlir/test/Rewrite/pdl-bytecode.mlir =================================================================== --- mlir/test/Rewrite/pdl-bytecode.mlir +++ mlir/test/Rewrite/pdl-bytecode.mlir @@ -72,6 +72,109 @@ // ----- +// Test returning an attribute from a native constraint +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end + + ^pat: + %attr = pdl_interp.apply_constraint "op_constr_return_attr"(%root : !pdl.operation) : !pdl.attribute -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root, %attr : !pdl.operation, !pdl.attribute) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation, %attr : !pdl.attribute) { + %op = pdl_interp.create_operation "test.replaced_by_pattern" {"attr" = %attr} + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_constraint_3 +// 
CHECK-NOT: "test.replaced_by_pattern" +// CHECK: "test.replaced_by_pattern"() {attr = "test.success"} +module @ir attributes { test.apply_constraint_3 } { + "test.failure_op"() : () -> () + "test.success_op"() : () -> () +} + +// ----- + +// Test returning a type from a native constraint. +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end + + ^pat: + %new_type = pdl_interp.apply_constraint "op_constr_return_type"(%root : !pdl.operation) : !pdl.type -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root, %new_type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) { + %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type) + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_constraint_4 +// CHECK-NOT: "test.replaced_by_pattern" +// CHECK: "test.replaced_by_pattern"() : () -> f32 +module @ir attributes { test.apply_constraint_4 } { + "test.failure_op"() : () -> () + "test.success_op"() : () -> () +} + +// ----- + +// Test success and failure cases of native constraints with pdl.range results. 
+module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end + + ^pat: + %num_results = pdl_interp.create_attribute 2 : i32 + %types = pdl_interp.apply_constraint "op_constr_return_type_range"(%root, %num_results : !pdl.operation, !pdl.attribute) : !pdl.range -> ^pat1, ^end + + ^pat1: + pdl_interp.record_match @rewriters::@success(%root, %types : !pdl.operation, !pdl.range) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation, %types : !pdl.range) { + %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%types : !pdl.range) + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_constraint_5 +// CHECK-NOT: "test.replaced_by_pattern" +// CHECK: "test.replaced_by_pattern"() : () -> (f32, f32) +module @ir attributes { test.apply_constraint_5 } { + "test.failure_op"() : () -> () + "test.success_op"() : () -> () +} + +// ----- + //===----------------------------------------------------------------------===// // pdl_interp::ApplyRewriteOp //===----------------------------------------------------------------------===// Index: mlir/test/lib/Rewrite/TestPDLByteCode.cpp =================================================================== --- mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -30,6 +30,50 @@ return success(); } +// Custom constraint that returns a value if the op is named test.success_op +static LogicalResult customValueResultConstraint(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args) { + auto *op = args[0].cast(); + if (op->getName().getStringRef() == "test.success_op") { + StringAttr customAttr = rewriter.getStringAttr("test.success"); + results.push_back(customAttr); + return success(); + } + return failure(); +} + +// Custom constraint that returns a type if the op is named 
test.success_op +static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args) { + auto *op = args[0].cast(); + if (op->getName().getStringRef() == "test.success_op") { + results.push_back(rewriter.getF32Type()); + return success(); + } + return failure(); +} + +// Custom constraint that returns a type range of variable length if the op is +// named test.success_op +static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args) { + auto *op = args[0].cast(); + int numTypes = args[1].cast().cast().getInt(); + + if (op->getName().getStringRef() == "test.success_op") { + SmallVector types; + for (int i = 0; i < numTypes; i++) { + types.push_back(rewriter.getF32Type()); + } + results.push_back(TypeRange(types)); + return success(); + } + return failure(); +} + // Custom creator invoked from PDL. static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { return rewriter.create(OperationState(op->getLoc(), "test.success")); @@ -102,6 +146,12 @@ customMultiEntityConstraint); pdlPattern.registerConstraintFunction("multi_entity_var_constraint", customMultiEntityVariadicConstraint); + pdlPattern.registerConstraintFunctionWithResults( + "op_constr_return_attr", customValueResultConstraint); + pdlPattern.registerConstraintFunctionWithResults( + "op_constr_return_type", customTypeResultConstraint); + pdlPattern.registerConstraintFunctionWithResults( + "op_constr_return_type_range", customTypeRangeResultConstraint); pdlPattern.registerRewriteFunction("creator", customCreate); pdlPattern.registerRewriteFunction("var_creator", customVariadicResultCreate); Index: mlir/test/mlir-pdll/Parser/constraint-failure.pdll =================================================================== --- mlir/test/mlir-pdll/Parser/constraint-failure.pdll +++ mlir/test/mlir-pdll/Parser/constraint-failure.pdll @@ -158,8 +158,3 @@ // CHECK: expected `;` after native 
declaration Constraint Foo() [{}] - -// ----- - -// CHECK: native Constraints currently do not support returning results -Constraint Foo() -> Op; Index: mlir/test/mlir-pdll/Parser/constraint.pdll =================================================================== --- mlir/test/mlir-pdll/Parser/constraint.pdll +++ mlir/test/mlir-pdll/Parser/constraint.pdll @@ -12,6 +12,14 @@ // ----- +// Test that native constraints support returning results. + +// CHECK: Module +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType +Constraint Foo() -> Attr; + +// ----- + // CHECK: Module // CHECK: `-UserConstraintDecl {{.*}} Name ResultType // CHECK: `Inputs`