diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -29,17 +29,17 @@ } //===----------------------------------------------------------------------===// -// pdl::ApplyConstraintOp +// pdl::ApplyNativeConstraintOp //===----------------------------------------------------------------------===// -def PDL_ApplyConstraintOp - : PDL_Op<"apply_constraint", [HasParent<"pdl::PatternOp">]> { - let summary = "Apply a generic constraint to a set of provided entities"; +def PDL_ApplyNativeConstraintOp + : PDL_Op<"apply_native_constraint", [HasParent<"pdl::PatternOp">]> { + let summary = "Apply a native constraint to a set of provided entities"; let description = [{ - `apply_constraint` operations apply a generic constraint, that has been - registered externally with the consumer of PDL, to a given set of entities. - The constraint is permitted to accept any number of constant valued - parameters. + `pdl.apply_native_constraint` operations apply a native C++ constraint, that + has been registered externally with the consumer of PDL, to a given set of + entities. The constraint is permitted to accept any number of constant + valued parameters. Example: @@ -47,7 +47,7 @@ // Apply `myConstraint` to the entities defined by `input`, `attr`, and // `op`. `42`, `"abc"`, and `i32` are constant parameters passed to the // constraint. 
- pdl.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) + pdl.apply_native_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) ``` }]; @@ -67,6 +67,39 @@ ]; } +//===----------------------------------------------------------------------===// +// pdl::ApplyNativeRewriteOp +//===----------------------------------------------------------------------===// + +def PDL_ApplyNativeRewriteOp + : PDL_Op<"apply_native_rewrite", [HasParent<"pdl::RewriteOp">]> { + let summary = "Apply a native rewrite method inside of a pdl.rewrite region"; + let description = [{ + `pdl.apply_native_rewrite` operations apply a native C++ function, that has + been registered externally with the consumer of PDL, to perform a rewrite + and optionally return a number of values. The native function may accept any + number of arguments and constant attribute parameters. This operation is + used within a pdl.rewrite region to enable the interleaving of native + rewrite methods with other pdl constructs. + + Example: + + ```mlir + // Apply a native rewrite method that returns an attribute. + %ret = pdl.apply_native_rewrite "myNativeFunc"[42, "gt"](%arg0, %arg1 : !pdl.value, !pdl.value) : !pdl.attribute + ``` + }]; + + let arguments = (ins StrAttr:$name, + Variadic:$args, + OptionalAttr:$constParams); + let results = (outs Variadic:$results); + let assemblyFormat = [{ + $name ($constParams^)? (`(` $args^ `:` type($args) `)`)?
`:` type($results) + attr-dict + }]; +} + //===----------------------------------------------------------------------===// // pdl::AttributeOp //===----------------------------------------------------------------------===// @@ -113,39 +146,6 @@ ]; } -//===----------------------------------------------------------------------===// -// pdl::CreateNativeOp -//===----------------------------------------------------------------------===// - -def PDL_CreateNativeOp - : PDL_Op<"create_native", [HasParent<"pdl::RewriteOp">]> { - let summary = "Call a native creation method to construct an `Attribute`, " - "`Operation`, `Type`, or `Value`"; - let description = [{ - `pdl.create_native` operations invoke a native C++ function, that has been - registered externally with the consumer of PDL, to create an `Attribute`, - `Operation`, `Type`, or `Value`. The native function must produce a value - of the specified return type, and may accept any number of positional - arguments and constant attribute parameters. - - Example: - - ```mlir - %ret = pdl.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute - ``` - }]; - - let arguments = (ins StrAttr:$name, - Variadic:$args, - OptionalAttr:$constParams); - let results = (outs PDL_AnyType:$result); - let assemblyFormat = [{ - $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result) - attr-dict - }]; - let verifier = ?; -} - //===----------------------------------------------------------------------===// // pdl::EraseOp //===----------------------------------------------------------------------===// @@ -233,9 +233,10 @@ `pdl.rewrite`, all of the result types must be "inferable". This means that the type must be attributable to either a constant type value or the result type of another entity, such as an attribute, the result of a - `createNative`, or the result type of another operation. 
If the result type - value does not meet any of these criteria, the operation must provide the - `InferTypeOpInterface` to ensure that the result types can be inferred. + `apply_native_rewrite`, or the result type of another operation. If the + result type value does not meet any of these criteria, the operation must + implement the `InferTypeOpInterface` to ensure that the result types can be + inferred. Example: @@ -408,13 +409,14 @@ let summary = "Specify the rewrite of a matched pattern"; let description = [{ `pdl.rewrite` operations terminate the region of a `pdl.pattern` and specify - the rewrite of a `pdl.pattern`, on the specified root operation. The + the main rewrite of a `pdl.pattern`, on the specified root operation. The rewrite is specified either via a string name (`name`) to an external rewrite function, or via the region body. The rewrite region, if specified, must contain a single block and terminate via the `pdl.rewrite_end` operation. If the rewrite is external, it also takes a set of constant parameters and a set of additional positional values defined within the - matcher as arguments. + matcher as arguments. If the rewrite is external, the root operation is + passed to the native function as the first argument. Example: diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -130,32 +130,35 @@ let description = [{ `pdl_interp.apply_rewrite` operations invoke an external rewriter that has been registered with the interpreter to perform the rewrite after a - successful match. The rewrite is passed the root operation being matched, a - set of additional positional arguments generated within the matcher, and a - set of constant parameters. + successful match. The rewrite is passed a set of positional arguments + and a set of constant parameters.
The rewrite function may return any + number of results. Example: ```mlir // Rewriter operating solely on the root operation. - pdl_interp.apply_rewrite "rewriter" on %root + pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation) + + // Rewriter operating solely on the root operation, returning an attribute. + %attr = pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation) : !pdl.attribute // Rewriter operating on the root operation along with additional arguments // from the matcher. - pdl_interp.apply_rewrite "rewriter"(%value : !pdl.value) on %root + pdl_interp.apply_rewrite "rewriter"(%root, %value : !pdl.operation, !pdl.value) // Rewriter operating on the root operation along with additional arguments // and constant parameters. - pdl_interp.apply_rewrite "rewriter"[42](%value : !pdl.value) on %root + pdl_interp.apply_rewrite "rewriter"[42](%root, %value : !pdl.operation, !pdl.value) ``` }]; let arguments = (ins StrAttr:$name, - PDL_Operation:$root, Variadic:$args, OptionalAttr:$constParams); + let results = (outs Variadic:$results); let assemblyFormat = [{ - $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `on` $root - attr-dict + $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? + (`:` type($results)^)? attr-dict }]; } @@ -351,38 +354,6 @@ }]>]; } -//===----------------------------------------------------------------------===// -// pdl_interp::CreateNativeOp -//===----------------------------------------------------------------------===// - -def PDLInterp_CreateNativeOp : PDLInterp_Op<"create_native"> { - let summary = "Call a native creation method to construct an `Attribute`, " - "`Operation`, `Type`, or `Value`"; - let description = [{ - `pdl_interp.create_native` operations invoke a native C++ function, that has - been registered externally with the consumer of PDL, to create an - `Attribute`, `Operation`, `Type`, or `Value`.
The native function must - produce a value of the specified return type, and may accept any number of - positional arguments and constant attribute parameters. - - Example: - - ```mlir - %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1 : !pdl.value, !pdl.value) : !pdl.attribute - ``` - }]; - - let arguments = (ins StrAttr:$name, - Variadic:$args, - OptionalAttr:$constParams); - let results = (outs PDL_AnyType:$result); - let assemblyFormat = [{ - $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result) - attr-dict - }]; - let verifier = ?; -} - //===----------------------------------------------------------------------===// // pdl_interp::CreateOperationOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -302,6 +302,33 @@ return os; } +//===----------------------------------------------------------------------===// +// PDLResultList + +/// The class represents a list of PDL results, returned by a native rewrite +/// method. It provides the mechanism with which to pass PDLValues back to the +/// PDL bytecode. +class PDLResultList { +public: + /// Push a new Attribute value onto the result list. + void push_back(Attribute value) { results.push_back(value); } + + /// Push a new Operation onto the result list. + void push_back(Operation *value) { results.push_back(value); } + + /// Push a new Type onto the result list. + void push_back(Type value) { results.push_back(value); } + + /// Push a new Value onto the result list. + void push_back(Value value) { results.push_back(value); } + +protected: + PDLResultList() = default; + + /// The PDL results held by this list. 
+ SmallVector results; +}; + //===----------------------------------------------------------------------===// // PDLPatternModule @@ -311,16 +338,13 @@ /// success if the constraint successfully held, failure otherwise. using PDLConstraintFunction = std::function, ArrayAttr, PatternRewriter &)>; -/// A native PDL creation function. This function creates a new PDLValue given -/// a set of existing PDL values, a set of constant parameters specified in -/// Attribute form, and a PatternRewriter. Returns the newly created PDLValue. -using PDLCreateFunction = - std::function, ArrayAttr, PatternRewriter &)>; -/// A native PDL rewrite function. This function rewrites the given root -/// operation using the provided PatternRewriter. This method is only invoked -/// when the corresponding match was successful. -using PDLRewriteFunction = std::function, - ArrayAttr, PatternRewriter &)>; +/// A native PDL rewrite function. This function performs a rewrite on the +/// given set of values and constant parameters. Any results from this rewrite +/// that should be passed back to PDL should be added to the provided result +/// list. This method is only invoked when the corresponding match was +/// successful. +using PDLRewriteFunction = std::function, ArrayAttr, PatternRewriter &, PDLResultList &)>; /// A generic PDL pattern constraint function. This function applies a /// constraint to a given opaque PDLValue entity. The second parameter is a set /// of constant value parameters specified in Attribute form. Returns success if @@ -367,9 +391,6 @@ }); } - /// Register a creation function. - void registerCreateFunction(StringRef name, PDLCreateFunction createFn); - /// Register a rewrite function. void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn); @@ -380,13 +401,6 @@ llvm::StringMap takeConstraintFunctions() { return constraintFunctions; } - /// Return the set of the registered create functions. 
- const llvm::StringMap &getCreateFunctions() const { - return createFunctions; - } - llvm::StringMap takeCreateFunctions() { - return createFunctions; - } /// Return the set of the registered rewrite functions. const llvm::StringMap &getRewriteFunctions() const { return rewriteFunctions; @@ -399,7 +413,6 @@ void clear() { pdlModule = nullptr; constraintFunctions.clear(); - createFunctions.clear(); rewriteFunctions.clear(); } @@ -409,7 +422,6 @@ /// The external functions referenced from within the PDL module. llvm::StringMap constraintFunctions; - llvm::StringMap createFunctions; llvm::StringMap rewriteFunctions; }; diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -70,6 +70,9 @@ SmallVectorImpl &usedMatchValues); /// Generate the rewriter code for the given operation. + void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, + DenseMap &rewriteValues, + function_ref mapRewriteValue); void generateRewriter(pdl::AttributeOp attrOp, DenseMap &rewriteValues, function_ref mapRewriteValue); @@ -79,9 +82,6 @@ void generateRewriter(pdl::OperationOp operationOp, DenseMap &rewriteValues, function_ref mapRewriteValue); - void generateRewriter(pdl::CreateNativeOp createNativeOp, - DenseMap &rewriteValues, - function_ref mapRewriteValue); void generateRewriter(pdl::ReplaceOp replaceOp, DenseMap &rewriteValues, function_ref mapRewriteValue); @@ -449,17 +449,17 @@ // method. 
pdl::RewriteOp rewriter = pattern.getRewriter(); if (StringAttr rewriteName = rewriter.nameAttr()) { - Value root = mapRewriteValue(rewriter.root()); - SmallVector args = llvm::to_vector<4>( - llvm::map_range(rewriter.externalArgs(), mapRewriteValue)); + auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); + SmallVector args(1, mapRewriteValue(rewriter.root())); + args.append(mappedArgs.begin(), mappedArgs.end()); builder.create( - rewriter.getLoc(), rewriteName, root, args, + rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, rewriter.externalConstParamsAttr()); } else { // Otherwise this is a dag rewriter defined using PDL operations. for (Operation &rewriteOp : *rewriter.getBody()) { llvm::TypeSwitch(&rewriteOp) - .Case( [&](auto op) { this->generateRewriter(op, rewriteValues, mapRewriteValue); @@ -478,6 +478,19 @@ builder.getSymbolRefAttr(rewriterFunc)); } +void PatternLowering::generateRewriter( + pdl::ApplyNativeRewriteOp rewriteOp, DenseMap &rewriteValues, + function_ref mapRewriteValue) { + SmallVector arguments; + for (Value argument : rewriteOp.args()) + arguments.push_back(mapRewriteValue(argument)); + auto interpOp = builder.create( + rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), + arguments, rewriteOp.constParamsAttr()); + for (auto it : llvm::zip(rewriteOp.results(), interpOp.results())) + rewriteValues[std::get<0>(it)] = std::get<1>(it); +} + void PatternLowering::generateRewriter( pdl::AttributeOp attrOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { @@ -527,18 +540,6 @@ } } -void PatternLowering::generateRewriter( - pdl::CreateNativeOp createNativeOp, DenseMap &rewriteValues, - function_ref mapRewriteValue) { - SmallVector arguments; - for (Value argument : createNativeOp.args()) - arguments.push_back(mapRewriteValue(argument)); - Value result = builder.create( - createNativeOp.getLoc(), createNativeOp.result().getType(), - createNativeOp.nameAttr(), arguments, 
createNativeOp.constParamsAttr()); - rewriteValues[createNativeOp] = result; -} - void PatternLowering::generateRewriter( pdl::ReplaceOp replaceOp, DenseMap &rewriteValues, function_ref mapRewriteValue) { diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -153,7 +153,7 @@ /// Collect all of the predicates related to constraints within the given /// pattern operation. -static void getConstraintPredicates(pdl::ApplyConstraintOp op, +static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, std::vector &predList, PredicateBuilder &builder, DenseMap &inputs) { @@ -192,7 +192,7 @@ PredicateBuilder &builder, DenseMap &inputs) { for (Operation &op : pattern.body().getOps()) { - if (auto constraintOp = dyn_cast(&op)) + if (auto constraintOp = dyn_cast(&op)) getConstraintPredicates(constraintOp, predList, builder, inputs); else if (auto resultOp = dyn_cast(&op)) getResultPredicates(resultOp, predList, builder, inputs); diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -67,15 +67,25 @@ } //===----------------------------------------------------------------------===// -// pdl::ApplyConstraintOp +// pdl::ApplyNativeConstraintOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ApplyConstraintOp op) { +static LogicalResult verify(ApplyNativeConstraintOp op) { if (op.getNumOperands() == 0) return op.emitOpError("expected at least one argument"); return success(); } +//===----------------------------------------------------------------------===// +// pdl::ApplyNativeRewriteOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ApplyNativeRewriteOp op) { + 
if (op.getNumOperands() == 0 && op.getNumResults() == 0) + return op.emitOpError("expected at least one argument or result"); + return success(); +} + //===----------------------------------------------------------------------===// // pdl::AttributeOp //===----------------------------------------------------------------------===// @@ -165,9 +175,9 @@ Operation *resultTypeOp = it.value().getDefiningOp(); assert(resultTypeOp && "expected valid result type operation"); - // If the op was defined by a `create_native`, it is guaranteed to be + // If the op was defined by an `apply_native_rewrite` it is guaranteed to be // usable. - if (isa(resultTypeOp)) + if (isa(resultTypeOp)) continue; // If the type is already constrained, there is nothing to do. diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -102,7 +102,6 @@ // Steal the other state if we have no patterns. if (!pdlModule) { constraintFunctions = std::move(other.constraintFunctions); - createFunctions = std::move(other.createFunctions); rewriteFunctions = std::move(other.rewriteFunctions); pdlModule = std::move(other.pdlModule); return; @@ -110,8 +109,6 @@ // Steal the functions of the other module.
for (auto &it : constraintFunctions) registerConstraintFunction(it.first(), std::move(it.second)); - for (auto &it : createFunctions) - registerCreateFunction(it.first(), std::move(it.second)); for (auto &it : rewriteFunctions) registerRewriteFunction(it.first(), std::move(it.second)); @@ -132,13 +129,7 @@ assert(it.second && "constraint with the given name has already been registered"); } -void PDLPatternModule::registerCreateFunction(StringRef name, - PDLCreateFunction createFn) { - auto it = createFunctions.try_emplace(name, std::move(createFn)); - (void)it; - assert(it.second && "native create function with the given name has " - "already been registered"); -} + void PDLPatternModule::registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) { auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn)); diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -114,7 +114,6 @@ /// the PDL interpreter dialect. PDLByteCode(ModuleOp module, llvm::StringMap constraintFns, - llvm::StringMap createFns, llvm::StringMap rewriteFns); /// Return the patterns held by the bytecode. @@ -160,7 +159,6 @@ /// A set of user defined functions invoked via PDL. std::vector constraintFunctions; - std::vector createFunctions; std::vector rewriteFunctions; /// The maximum memory index used by a value. diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -80,8 +80,6 @@ CheckOperationName, /// Compare the result count of an operation with a constant. CheckResultCount, - /// Invoke a native creation method. - CreateNative, /// Create an operation. CreateOperation, /// Erase an operation. 
@@ -148,15 +146,12 @@ SmallVectorImpl &patterns, ByteCodeField &maxValueMemoryIndex, llvm::StringMap &constraintFns, - llvm::StringMap &createFns, llvm::StringMap &rewriteFns) : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex) { for (auto it : llvm::enumerate(constraintFns)) constraintToMemIndex.try_emplace(it.value().first(), it.index()); - for (auto it : llvm::enumerate(createFns)) - nativeCreateToMemIndex.try_emplace(it.value().first(), it.index()); for (auto it : llvm::enumerate(rewriteFns)) externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); } @@ -203,7 +198,6 @@ void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); - void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); @@ -235,10 +229,6 @@ /// in the bytecode registry. llvm::StringMap constraintToMemIndex; - /// Mapping from the name of an externally registered creation method to its - /// index in the bytecode registry. - llvm::StringMap nativeCreateToMemIndex; - /// Mapping from rewriter function name to the bytecode address of the /// rewriter function in byte. 
llvm::StringMap rewriterToAddr; @@ -492,16 +482,16 @@ pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, - pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp, - pdl_interp::CreateTypeOp, pdl_interp::EraseOp, - pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp, - pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, - pdl_interp::GetOperandOp, pdl_interp::GetResultOp, - pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp, - pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, - pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, - pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp, - pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( + pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, + pdl_interp::EraseOp, pdl_interp::FinalizeOp, + pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, + pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, + pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp, + pdl_interp::InferredTypeOp, pdl_interp::IsNotNullOp, + pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, + pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, + pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, + pdl_interp::SwitchResultCountOp>( [&](auto interpOp) { this->generate(interpOp, writer); }) .Default([](Operation *) { llvm_unreachable("unknown `pdl_interp` operation"); @@ -522,8 +512,16 @@ assert(externalRewriterToMemIndex.count(op.name()) && "expected index for rewrite function"); writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], - op.constParamsAttr(), op.root()); + op.constParamsAttr()); writer.appendPDLValueList(op.args()); + +#ifndef NDEBUG + // In debug mode we also append the number of results so that we can assert + // that the native rewrite function gave us the correct number of results.
+ writer.append(ByteCodeField(op.results().size())); +#endif + for (Value result : op.results()) + writer.append(result); } void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); @@ -559,14 +557,6 @@ // Simply repoint the memory index of the result to the constant. getMemIndex(op.attribute()) = getMemIndex(op.value()); } -void Generator::generate(pdl_interp::CreateNativeOp op, - ByteCodeWriter &writer) { - assert(nativeCreateToMemIndex.count(op.name()) && - "expected index for creation function"); - writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()], - op.result(), op.constParamsAttr()); - writer.appendPDLValueList(op.args()); -} void Generator::generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer) { writer.append(OpCode::CreateOperation, op.operation(), @@ -678,18 +668,15 @@ PDLByteCode::PDLByteCode(ModuleOp module, llvm::StringMap constraintFns, - llvm::StringMap createFns, llvm::StringMap rewriteFns) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, - constraintFns, createFns, rewriteFns); + constraintFns, rewriteFns); generator.generate(module); // Initialize the external functions. 
for (auto &it : constraintFns) constraintFunctions.push_back(std::move(it.second)); - for (auto &it : createFns) - createFunctions.push_back(std::move(it.second)); for (auto &it : rewriteFns) rewriteFunctions.push_back(std::move(it.second)); } @@ -717,12 +704,11 @@ ArrayRef currentPatternBenefits, ArrayRef patterns, ArrayRef constraintFunctions, - ArrayRef createFunctions, ArrayRef rewriteFunctions) : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), code(code), currentPatternBenefits(currentPatternBenefits), patterns(patterns), constraintFunctions(constraintFunctions), - createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {} + rewriteFunctions(rewriteFunctions) {} /// Start executing the code at the current bytecode index. `matches` is an /// optional field provided when this function is executed in a matching @@ -740,7 +726,6 @@ void executeCheckOperandCount(); void executeCheckOperationName(); void executeCheckResultCount(); - void executeCreateNative(PatternRewriter &rewriter); void executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); void executeEraseOp(PatternRewriter &rewriter); @@ -866,9 +851,17 @@ ArrayRef currentPatternBenefits; ArrayRef patterns; ArrayRef constraintFunctions; - ArrayRef createFunctions; ArrayRef rewriteFunctions; }; + +/// This class is an instantiation of the PDLResultList that provides access to +/// the returned results. This API is not on `PDLResultList` to avoid +/// overexposing access to information specific solely to the ByteCode. +class ByteCodeRewriteResultList : public PDLResultList { +public: + /// Return the list of PDL results. 
+ MutableArrayRef getResults() { return results; } +}; } // end anonymous namespace void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { @@ -892,18 +885,29 @@ LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; ArrayAttr constParams = read(); - Operation *root = read(); SmallVector args; readList(args); LLVM_DEBUG({ - llvm::dbgs() << " * Root: " << *root << "\n * Arguments: "; + llvm::dbgs() << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; }); - - // Invoke the native rewrite function. - rewriteFn(root, args, constParams, rewriter); + ByteCodeRewriteResultList results; + rewriteFn(args, constParams, rewriter, results); + + // Verify that the rewrite function returned the expected number of results. +#ifndef NDEBUG + ByteCodeField expectedNumberOfResults = read(); + assert(results.getResults().size() == expectedNumberOfResults && + "native PDL rewrite function returned unexpected number of results"); +#endif + + // Store the results in the bytecode memory.
+ for (PDLValue &result : results.getResults()) { + LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); + memory[read()] = result.getAsOpaquePointer(); + } } void ByteCodeExecutor::executeAreEqual() { @@ -950,26 +954,6 @@ selectJump(op->getNumResults() == expectedCount); } -void ByteCodeExecutor::executeCreateNative(PatternRewriter &rewriter) { - LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); - const PDLCreateFunction &createFn = createFunctions[read()]; - ByteCodeField resultIndex = read(); - ArrayAttr constParams = read(); - SmallVector args; - readList(args); - - LLVM_DEBUG({ - llvm::dbgs() << " * Arguments: "; - llvm::interleaveComma(args, llvm::dbgs()); - llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; - }); - - PDLValue result = createFn(args, constParams, rewriter); - memory[resultIndex] = result.getAsOpaquePointer(); - - LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); -} - void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc) { LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); @@ -1246,9 +1230,6 @@ case CheckResultCount: executeCheckResultCount(); break; - case CreateNative: - executeCreateNative(rewriter); - break; case CreateOperation: executeCreateOperation(rewriter, *mainRewriteLoc); break; @@ -1338,8 +1319,7 @@ // The matcher function always starts at code address 0. ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, matcherByteCode, state.currentPatternBenefits, - patterns, constraintFunctions, createFunctions, - rewriteFunctions); + patterns, constraintFunctions, rewriteFunctions); executor.execute(rewriter, &matches); // Order the found matches by benefit. @@ -1356,9 +1336,9 @@ // memory buffer. 
llvm::copy(match.values, state.memory.begin()); - ByteCodeExecutor executor( - &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, - uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, - constraintFunctions, createFunctions, rewriteFunctions); + ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()], + state.memory, uniquedData, rewriterByteCode, + state.currentPatternBenefits, patterns, + constraintFunctions, rewriteFunctions); executor.execute(rewriter, /*matches=*/nullptr, match.location); } diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp --- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp @@ -70,7 +70,7 @@ // Generate the pdl bytecode. impl->pdlByteCode = std::make_unique( pdlModule, pdlPatterns.takeConstraintFunctions(), - pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions()); + pdlPatterns.takeRewriteFunctions()); } FrozenRewritePatternList::~FrozenRewritePatternList() {} diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -24,7 +24,7 @@ // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[REWRITE_ROOT:.*]]: !pdl.operation) - // CHECK: pdl_interp.apply_rewrite "rewriter" on %[[REWRITE_ROOT]] + // CHECK: pdl_interp.apply_rewrite "rewriter"(%[[REWRITE_ROOT]] // CHECK: pdl_interp.finalize pdl.pattern : benefit(1) { %root = pdl.operation "foo.op"() @@ -72,7 +72,7 @@ %root = pdl.operation(%input0, %input1) %result0 = pdl.result 0 of %root - pdl.apply_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value) + pdl.apply_native_constraint "multi_constraint"[true](%input0, %input1, 
%result0 : !pdl.value, !pdl.value, !pdl.value) pdl.rewrite %root with "rewriter" } } @@ -194,7 +194,7 @@ pdl.pattern : benefit(1) { %resultType = pdl.type - pdl.apply_constraint "typeConstraint"[](%resultType : !pdl.type) + pdl.apply_native_constraint "typeConstraint"[](%resultType : !pdl.type) %root = pdl.operation -> %resultType pdl.rewrite %root with "rewriter" } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir @@ -6,7 +6,7 @@ module @external { // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation, %[[INPUT:.*]]: !pdl.value) - // CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[INPUT]] : !pdl.value) on %[[ROOT]] + // CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value) pdl.pattern : benefit(1) { %input = pdl.operand %root = pdl.operation "foo.op"(%input) @@ -170,17 +170,17 @@ // ----- -// CHECK-LABEL: module @create_native -module @create_native { +// CHECK-LABEL: module @apply_native_rewrite +module @apply_native_rewrite { // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation) - // CHECK: %[[TYPE:.*]] = pdl_interp.create_native "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type + // CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type // CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]] pdl.pattern : benefit(1) { %type = pdl.type %root = pdl.operation "foo.op" -> %type pdl.rewrite %root { - %newType = pdl.create_native "functor"[true](%root : !pdl.operation) : !pdl.type + %newType = pdl.apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type %newOp = pdl.operation "foo.op" -> %newType pdl.replace 
%root with %newOp } diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -1,19 +1,33 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics //===----------------------------------------------------------------------===// -// pdl::ApplyConstraintOp +// pdl::ApplyNativeConstraintOp //===----------------------------------------------------------------------===// pdl.pattern : benefit(1) { %op = pdl.operation "foo.op" // expected-error@below {{expected at least one argument}} - "pdl.apply_constraint"() {name = "foo", params = []} : () -> () + "pdl.apply_native_constraint"() {name = "foo", params = []} : () -> () pdl.rewrite %op with "rewriter" } // ----- +//===----------------------------------------------------------------------===// +// pdl::ApplyNativeRewriteOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + %op = pdl.operation "foo.op" + pdl.rewrite %op { + // expected-error@below {{expected at least one argument}} + "pdl.apply_native_rewrite"() {name = "foo", params = []} : () -> () + } +} + +// ----- + //===----------------------------------------------------------------------===// // pdl::AttributeOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -58,7 +58,7 @@ module @rewriters { func @success(%root : !pdl.operation) { %operand = pdl_interp.get_operand 0 of %root - pdl_interp.apply_rewrite "rewriter"[42](%operand : !pdl.value) on %root + pdl_interp.apply_rewrite "rewriter"[42](%root, %operand : !pdl.operation, !pdl.value) pdl_interp.finalize } } @@ -72,6 +72,35 @@ %input = "test.op_input"() : () -> i32 "test.op"(%input) : (i32) -> () } + +// ----- + +module @patterns { + 
func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.apply_rewrite "creator"(%root : !pdl.operation) : !pdl.operation + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_rewrite_2 +// CHECK: "test.success" +module @ir attributes { test.apply_rewrite_2 } { + "test.op"() : () -> () +} + // ----- //===----------------------------------------------------------------------===// @@ -317,38 +346,6 @@ // Fully tested within the tests for other operations. -//===----------------------------------------------------------------------===// -// pdl_interp::CreateNativeOp -//===----------------------------------------------------------------------===// - -// ----- - -module @patterns { - func @matcher(%root : !pdl.operation) { - pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end - - ^pat: - pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end - - ^end: - pdl_interp.finalize - } - - module @rewriters { - func @success(%root : !pdl.operation) { - %op = pdl_interp.create_native "creator"(%root : !pdl.operation) : !pdl.operation - pdl_interp.erase %root - pdl_interp.finalize - } - } -} - -// CHECK-LABEL: test.create_native_1 -// CHECK: "test.success" -module @ir attributes { test.create_native_1 } { - "test.op"() : () -> () -} - //===----------------------------------------------------------------------===// // pdl_interp::CreateOperationOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ 
b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -26,18 +26,18 @@ } // Custom creator invoked from PDL. -static PDLValue customCreate(ArrayRef args, ArrayAttr constantParams, - PatternRewriter &rewriter) { - return rewriter.createOperation( - OperationState(args[0].cast()->getLoc(), "test.success")); +static void customCreate(ArrayRef args, ArrayAttr constantParams, + PatternRewriter &rewriter, PDLResultList &results) { + results.push_back(rewriter.createOperation( + OperationState(args[0].cast()->getLoc(), "test.success"))); } /// Custom rewriter invoked from PDL. -static void customRewriter(Operation *root, ArrayRef args, - ArrayAttr constantParams, - PatternRewriter &rewriter) { +static void customRewriter(ArrayRef args, ArrayAttr constantParams, + PatternRewriter &rewriter, PDLResultList &results) { + Operation *root = args[0].cast(); OperationState successOpState(root->getLoc(), "test.success"); - successOpState.addOperands(args[0].cast()); + successOpState.addOperands(args[1].cast()); successOpState.addAttribute("constantParams", constantParams); rewriter.createOperation(successOpState); rewriter.eraseOp(root); @@ -63,7 +63,7 @@ customMultiEntityConstraint); pdlPattern.registerConstraintFunction("single_entity_constraint", customSingleEntityConstraint); - pdlPattern.registerCreateFunction("creator", customCreate); + pdlPattern.registerRewriteFunction("creator", customCreate); pdlPattern.registerRewriteFunction("rewriter", customRewriter); OwningRewritePatternList patternList(std::move(pdlPattern));