diff --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md --- a/mlir/docs/PDLL.md +++ b/mlir/docs/PDLL.md @@ -1007,15 +1007,13 @@ ```c++ // TODO: Cleanup when we allow more accessible wrappers around PDL functions. -static LogicalResult hasOneUseImpl(PDLValue pdlValue, ArrayAttr constantParams, - PatternRewriter &rewriter) { +static LogicalResult hasOneUseImpl(PDLValue pdlValue, PatternRewriter &rewriter) { Value value = pdlValue.cast(); return success(value.hasOneUse()); } -static LogicalResult hasSameElementTypeImpl( - ArrayRef pdlValues, ArrayAttr constantParams, - PatternRewriter &rewriter) { +static LogicalResult hasSameElementTypeImpl(ArrayRef pdlValues, + PatternRewriter &rewriter) { Value value1 = pdlValues[0].cast(); Value value2 = pdlValues[1].cast(); @@ -1310,8 +1308,8 @@ ```c++ // TODO: Cleanup when we allow more accessible wrappers around PDL functions. -static void buildOpImpl(ArrayRef args, ArrayAttr constantParams, - PatternRewriter &rewriter, PDLResultList &results) { +static void buildOpImpl(ArrayRef args, PatternRewriter &rewriter, + PDLResultList &results) { Value value = args[0].cast(); // insert special rewrite logic here. diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -35,33 +35,18 @@ let description = [{ `pdl.apply_native_constraint` operations apply a native C++ constraint, that has been registered externally with the consumer of PDL, to a given set of - entities. The constraint is permitted to accept any number of constant - valued parameters. + entities. Example: ```mlir - // Apply `myConstraint` to the entities defined by `input`, `attr`, and - // `op`. `42`, `"abc"`, and `i32` are constant parameters passed to the - // constraint. 
- pdl.apply_native_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) + // Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`. + pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) ``` }]; - let arguments = (ins StrAttr:$name, - Variadic:$args, - OptionalAttr:$constParams); - let assemblyFormat = [{ - $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict - }]; - - let builders = [ - OpBuilder<(ins "StringRef":$name, CArg<"ValueRange", "{}">:$args, - CArg<"ArrayRef", "{}">:$params), [{ - build($_builder, $_state, $_builder.getStringAttr(name), args, - params.empty() ? ArrayAttr() : $_builder.getArrayAttr(params)); - }]>, - ]; + let arguments = (ins StrAttr:$name, Variadic:$args); + let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict"; let hasVerifier = 1; } @@ -76,26 +61,22 @@ `pdl.apply_native_rewrite` operations apply a native C++ function, that has been registered externally with the consumer of PDL, to perform a rewrite and optionally return a number of values. The native function may accept any - number of arguments and constant attribute parameters. This operation is - used within a pdl.rewrite region to enable the interleaving of native - rewrite methods with other pdl constructs. + number of arguments. This operation is used within a pdl.rewrite region to enable + the interleaving of native rewrite methods with other pdl constructs. Example: ```mlir // Apply a native rewrite method that returns an attribute. 
- %ret = pdl.apply_native_rewrite "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute + %ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %arg1) : !pdl.attribute ``` ```c++ // The native rewrite as defined in C++: - static void myNativeFunc(ArrayRef args, ArrayAttr constantParams, - PatternRewriter &rewriter, + static void myNativeFunc(ArrayRef args, PatternRewriter &rewriter, PDLResultList &results) { Value arg0 = args[0].cast(); Value arg1 = args[1].cast(); - IntegerAttr param0 = constantParams[0].cast(); - StringAttr param1 = constantParams[1].cast(); - // Just push back the first param attribute. - results.push_back(param0); + // Push back an attribute as the single result of the rewrite. + results.push_back(rewriter.getUnitAttr()); @@ -107,13 +88,10 @@ ``` }]; - let arguments = (ins StrAttr:$name, - Variadic:$args, - OptionalAttr:$constParams); + let arguments = (ins StrAttr:$name, Variadic:$args); let results = (outs Variadic:$results); let assemblyFormat = [{ - $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? - (`:` type($results)^)? attr-dict + $name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict }]; let hasVerifier = 1; } @@ -588,16 +566,15 @@ rewrite is specified either via a string name (`name`) to a native rewrite function, or via the region body. The rewrite region, if specified, must contain a single block. If the rewrite is external it functions - similarly to `pdl.apply_native_rewrite`, and takes a set of constant - parameters and a set of additional positional values defined within the - matcher as arguments. If the rewrite is external, the root operation is - passed to the native function as the leading arguments. The root operation, - if provided, specifies the starting point in the pattern for the subgraph - isomorphism search. Pattern matching will proceed from this node downward - (towards the defining operation) or upward (towards the users) until all - the operations in the pattern have been matched. 
If the root is omitted, - the pdl_interp lowering will automatically select the best root of the - pdl.rewrite among all the operations in the pattern. + similarly to `pdl.apply_native_rewrite`, and takes a set of additional + positional values defined within the matcher as arguments. If the rewrite is + external, the root operation is passed to the native function as the leading + argument. The root operation, if provided, specifies the starting point in + the pattern for the subgraph isomorphism search. Pattern matching will proceed + from this node downward (towards the defining operation) or upward + (towards the users) until all the operations in the pattern have been matched. + If the root is omitted, the pdl_interp lowering will automatically select + the best root of the pdl.rewrite among all the operations in the pattern. Example: @@ -623,12 +600,10 @@ let arguments = (ins Optional:$root, OptionalAttr:$name, - Variadic:$externalArgs, - OptionalAttr:$externalConstParams); + Variadic:$externalArgs); let regions = (region AnyRegion:$body); let assemblyFormat = [{ - ($root^)? (`with` $name^ ($externalConstParams^)? - (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? + ($root^)? (`with` $name^ (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? ($body^)? attr-dict-with-keyword }]; diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -89,25 +89,21 @@ let description = [{ `pdl_interp.apply_constraint` operations apply a generic constraint, that has been registered with the interpreter, with a given set of positional - values. The constraint may have any number of constant parameters. On - success, this operation branches to the true destination, otherwise the - false destination is taken. + values. 
On success, this operation branches to the true destination, + otherwise the false destination is taken. Example: ```mlir // Apply `myConstraint` to the entities defined by `input`, `attr`, and // `op`. - pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest + pdl_interp.apply_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest ``` }]; - let arguments = (ins StrAttr:$name, - Variadic:$args, - OptionalAttr:$constParams); + let arguments = (ins StrAttr:$name, Variadic:$args); let assemblyFormat = [{ - $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict `->` - successors + $name `(` $args `:` type($args) `)` attr-dict `->` successors }]; } @@ -120,9 +116,8 @@ let description = [{ `pdl_interp.apply_rewrite` operations invoke an external rewriter that has been registered with the interpreter to perform the rewrite after a - successful match. The rewrite is passed a set of positional arguments, - and a set of constant parameters. The rewrite function may return any - number of results. + successful match. The rewrite is passed a set of positional arguments. The + rewrite function may return any number of results. Example: @@ -136,19 +131,12 @@ // Rewriter operating on the root operation along with additional arguments // from the matcher. pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation, %value : !pdl.value) - - // Rewriter operating on the root operation along with additional arguments - // and constant parameters. - pdl_interp.apply_rewrite "rewriter"[42](%root : !pdl.operation, %value : !pdl.value) ``` }]; - let arguments = (ins StrAttr:$name, - Variadic:$args, - OptionalAttr:$constParams); + let arguments = (ins StrAttr:$name, Variadic:$args); let results = (outs Variadic:$results); let assemblyFormat = [{ - $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? 
- (`:` type($results)^)? attr-dict + $name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict }]; } diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -584,24 +584,16 @@ // PDLPatternModule /// A generic PDL pattern constraint function. This function applies a -/// constraint to a given set of opaque PDLValue entities. The second parameter -/// is a set of constant value parameters specified in Attribute form. Returns -/// success if the constraint successfully held, failure otherwise. -using PDLConstraintFunction = std::function, ArrayAttr, PatternRewriter &)>; -/// A native PDL rewrite function. This function performs a rewrite on the -/// given set of values and constant parameters. Any results from this rewrite -/// that should be passed back to PDL should be added to the provided result -/// list. This method is only invoked when the corresponding match was -/// successful. -using PDLRewriteFunction = std::function, ArrayAttr, PatternRewriter &, PDLResultList &)>; -/// A generic PDL pattern constraint function. This function applies a -/// constraint to a given opaque PDLValue entity. The second parameter is a set -/// of constant value parameters specified in Attribute form. Returns success if +/// constraint to a given set of opaque PDLValue entities. Returns success if /// the constraint successfully held, failure otherwise. -using PDLSingleEntityConstraintFunction = - std::function; +using PDLConstraintFunction = + std::function, PatternRewriter &)>; +/// A native PDL rewrite function. This function performs a rewrite on the +/// given set of values. Any results from this rewrite that should be passed +/// back to PDL should be added to the provided result list. This method is only +/// invoked when the corresponding match was successful. 
+using PDLRewriteFunction = + std::function, PatternRewriter &, PDLResultList &)>; /// This class contains all of the necessary data for a set of PDL patterns, or /// pattern rewrites specified in the form of the PDL dialect. This PDL module @@ -630,15 +622,14 @@ /// Register a single entity constraint function. template std::enable_if_t, - ArrayAttr, PatternRewriter &>::value> + PatternRewriter &>::value> registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) { registerConstraintFunction( name, [constraintFn = std::forward(constraintFn)]( - ArrayRef values, ArrayAttr constantParams, - PatternRewriter &rewriter) { + ArrayRef values, PatternRewriter &rewriter) { assert(values.size() == 1 && "expected values to have a single entity"); - return constraintFn(values[0], constantParams, rewriter); + return constraintFn(values[0], rewriter); }); } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -431,9 +431,8 @@ } case Predicates::ConstraintQuestion: { auto *cstQuestion = cast(question); - builder.create( - loc, cstQuestion->getName(), args, cstQuestion->getParams(), success, - failure); + builder.create(loc, cstQuestion->getName(), + args, success, failure); break; } default: @@ -644,8 +643,7 @@ auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); args.append(mappedArgs.begin(), mappedArgs.end()); builder.create( - rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, - rewriter.externalConstParamsAttr()); + rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args); } else { // Otherwise this is a dag rewriter defined using PDL operations. 
for (Operation &rewriteOp : *rewriter.getBody()) { @@ -678,8 +676,8 @@ arguments.push_back(mapRewriteValue(argument)); auto interpOp = builder.create( rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), - arguments, rewriteOp.constParamsAttr()); - for (auto it : llvm::zip(rewriteOp.results(), interpOp.getResults())) + arguments); + for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) rewriteValues[std::get<0>(it)] = std::get<1>(it); } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -445,10 +445,9 @@ /// Apply a parameterized constraint to multiple position values. struct ConstraintQuestion - : public PredicateBase< - ConstraintQuestion, Qualifier, - std::tuple, Attribute>, - Predicates::ConstraintQuestion> { + : public PredicateBase>, + Predicates::ConstraintQuestion> { using Base::Base; /// Return the name of the constraint. @@ -457,17 +456,11 @@ /// Return the arguments of the constraint. ArrayRef getArgs() const { return std::get<1>(key); } - /// Return the constant parameters of the constraint. - ArrayAttr getParams() const { - return std::get<2>(key).dyn_cast_or_null(); - } - /// Construct an instance with the given storage allocator. static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, KeyTy key) { return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), - alloc.copyInto(std::get<1>(key)), - std::get<2>(key)}); + alloc.copyInto(std::get<1>(key))}); } }; @@ -667,11 +660,9 @@ } /// Create a predicate that applies a generic constraint. 
- Predicate getConstraint(StringRef name, ArrayRef pos, - Attribute params) { - return { - ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, params)), - TrueAnswer::get(uniquer)}; + Predicate getConstraint(StringRef name, ArrayRef pos) { + return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)), + TrueAnswer::get(uniquer)}; } /// Create a predicate comparing a value with null. diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -263,7 +263,6 @@ PredicateBuilder &builder, DenseMap &inputs) { OperandRange arguments = op.args(); - ArrayAttr parameters = op.constParamsAttr(); std::vector allPositions; allPositions.reserve(arguments.size()); @@ -274,7 +273,7 @@ Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), comparePosDepth); PredicateBuilder::Predicate pred = - builder.getConstraint(op.name(), allPositions, parameters); + builder.getConstraint(op.name(), allPositions); predList.emplace_back(pos, pred); } diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -425,10 +425,6 @@ return emitOpError() << "expected no external arguments when the " "rewrite is specified inline"; } - if (externalConstParams()) { - return emitOpError() << "expected no external constant parameters when " - "the rewrite is specified inline"; - } return success(); } diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -757,8 +757,7 @@ ByteCodeWriter &writer) { assert(constraintToMemIndex.count(op.getName()) && "expected index for constraint function"); - writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()], - op.getConstParamsAttr()); + 
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); writer.appendPDLValueList(op.getArgs()); writer.append(op.getSuccessors()); } @@ -766,8 +765,7 @@ ByteCodeWriter &writer) { assert(externalRewriterToMemIndex.count(op.getName()) && "expected index for rewrite function"); - writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()], - op.getConstParamsAttr()); + writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]); writer.appendPDLValueList(op.getArgs()); ResultRange results = op.getResults(); @@ -1333,37 +1331,33 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; - ArrayAttr constParams = read(); SmallVector args; readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); - llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; }); // Invoke the constraint and jump to the proper destination. - selectJump(succeeded(constraintFn(args, constParams, rewriter))); + selectJump(succeeded(constraintFn(args, rewriter))); } void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; - ArrayAttr constParams = read(); SmallVector args; readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); - llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; }); // Execute the rewrite function. 
ByteCodeField numResults = read(); ByteCodeRewriteResultList results(numResults); - rewriteFn(args, constParams, rewriter, results); + rewriteFn(args, rewriter, results); assert(results.getResults().size() == numResults && "native PDL rewrite function returned unexpected number of results"); diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp --- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp @@ -190,7 +190,7 @@ // what we need as a frontend. os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name << "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, " - "::mlir::ArrayAttr constParams, ::mlir::PatternRewriter &rewriter" + "::mlir::PatternRewriter &rewriter" << (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n"; const char *argumentInitStr = R"( diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -200,9 +200,9 @@ static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc) { if (isa(builder.getInsertionBlock()->getParentOp())) { - pdl::RewriteOp rewrite = builder.create( - loc, rootExpr, /*name=*/StringAttr(), - /*externalArgs=*/ValueRange(), /*externalConstParams=*/ArrayAttr()); + pdl::RewriteOp rewrite = + builder.create(loc, rootExpr, /*name=*/StringAttr(), + /*externalArgs=*/ValueRange()); builder.createBlock(&rewrite.body()); } } @@ -564,14 +564,8 @@ } else { resultTypes.push_back(genType(declResultType)); } - - // FIXME: We currently do not have a modeling for the "constant params" - // support PDL provides. We should either figure out a modeling for this, or - // refactor the support within PDL to be something a bit more reasonable for - // what we need as a frontend. 
- Operation *pdlOp = builder.create(loc, resultTypes, - decl->getName().getName(), inputs, - /*params=*/ArrayAttr()); + Operation *pdlOp = builder.create( + loc, resultTypes, decl->getName().getName(), inputs); return pdlOp->getResults(); } diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -59,14 +59,12 @@ def __init__(self, name: Union[str, StringAttr], args: Sequence[Union[OpView, Operation, Value]] = [], - params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, *, loc=None, ip=None): name = _get_str_attr(name) args = _get_values(args) - params = params if params is None else _get_array_attr(params) - super().__init__(name, args, params, loc=loc, ip=ip) + super().__init__(name, args, loc=loc, ip=ip) class ApplyNativeRewriteOp: @@ -76,14 +74,12 @@ results: Sequence[Type], name: Union[str, StringAttr], args: Sequence[Union[OpView, Operation, Value]] = [], - params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, *, loc=None, ip=None): name = _get_str_attr(name) args = _get_values(args) - params = params if params is None else _get_array_attr(params) - super().__init__(results, name, args, params, loc=loc, ip=ip) + super().__init__(results, name, args, loc=loc, ip=ip) class AttributeOp: @@ -236,15 +232,13 @@ root: Optional[Union[OpView, Operation, Value]] = None, name: Optional[Union[StringAttr, str]] = None, args: Sequence[Union[OpView, Operation, Value]] = [], - params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, *, loc=None, ip=None): root = root if root is None else _get_value(root) name = name if name is None else _get_str_attr(name) args = _get_values(args) - params = params if params is None else _get_array_attr(params) - super().__init__(root, name, args, params, loc=loc, ip=ip) + super().__init__(root, name, args, loc=loc, ip=ip) def add_body(self): """Add body (block) to the 
rewrite.""" diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -64,7 +64,7 @@ // CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]] // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]] // CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]] - // CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]], %[[RESULT]] + // CHECK: pdl_interp.apply_constraint "multi_constraint"(%[[INPUT]], %[[INPUT1]], %[[RESULT]] pdl.pattern : benefit(1) { %input0 = operand @@ -72,7 +72,7 @@ %root = operation(%input0, %input1 : !pdl.value, !pdl.value) %result0 = result 0 of %root - pdl.apply_native_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value) + pdl.apply_native_constraint "multi_constraint"(%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value) rewrite %root with "rewriter" } } @@ -393,11 +393,11 @@ // CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]] // CHECK-NEXT: pdl_interp.is_not_null %[[RESULT]] // CHECK: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]] - // CHECK: pdl_interp.apply_constraint "typeConstraint" [](%[[RESULT_TYPE]] + // CHECK: pdl_interp.apply_constraint "typeConstraint"(%[[RESULT_TYPE]] pdl.pattern : benefit(1) { %resultType = type - pdl.apply_native_constraint "typeConstraint"[](%resultType : !pdl.type) + pdl.apply_native_constraint "typeConstraint"(%resultType : !pdl.type) %root = operation -> (%resultType : !pdl.type) rewrite %root with "rewriter" } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir --- 
a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir @@ -6,11 +6,11 @@ module @external { // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation, %[[INPUT:.*]]: !pdl.value) - // CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value) + // CHECK: pdl_interp.apply_rewrite "rewriter"(%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value) pdl.pattern : benefit(1) { %input = operand %root = operation "foo.op"(%input : !pdl.value) - rewrite %root with "rewriter"[true](%input : !pdl.value) + rewrite %root with "rewriter"(%input : !pdl.value) } } @@ -191,13 +191,13 @@ module @apply_native_rewrite { // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation) - // CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type + // CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor"(%[[ROOT]] : !pdl.operation) : !pdl.type // CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPE]] : !pdl.type) pdl.pattern : benefit(1) { %type = type %root = operation "foo.op" -> (%type : !pdl.type) rewrite %root { - %newType = apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type + %newType = apply_native_rewrite "functor"(%root : !pdl.operation) : !pdl.type %newOp = operation "foo.op" -> (%newType : !pdl.type) } } diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -8,7 +8,7 @@ %op = operation "foo.op" // expected-error@below {{expected at least one argument}} - "pdl.apply_native_constraint"() {name = "foo", params = []} : () -> () + "pdl.apply_native_constraint"() {name = "foo"} : () -> () rewrite %op with "rewriter" } @@ -22,7 +22,7 @@ %op = operation "foo.op" rewrite %op { // expected-error@below 
{{expected at least one argument}} - "pdl.apply_native_rewrite"() {name = "foo", params = []} : () -> () + "pdl.apply_native_rewrite"() {name = "foo"} : () -> () } } @@ -264,19 +264,6 @@ // ----- -pdl.pattern : benefit(1) { - %op = operation "foo.op" - - // expected-error@below {{expected no external constant parameters when the rewrite is specified inline}} - "pdl.rewrite"(%op) ({ - ^bb1: - }) { - operand_segment_sizes = dense<[1,0]> : vector<2xi32>, - externalConstParams = []} : (!pdl.operation) -> () -} - -// ----- - pdl.pattern : benefit(1) { %op = operation "foo.op" diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -27,21 +27,6 @@ // ----- -pdl.pattern @rewrite_with_params : benefit(1) { - %root = operation - rewrite %root with "rewriter"["I am param"] -} - -// ----- - -pdl.pattern @rewrite_with_args_and_params : benefit(1) { - %input = operand - %root = operation(%input : !pdl.value) - rewrite %root with "rewriter"["I am param"](%input : !pdl.value) -} - -// ----- - pdl.pattern @rewrite_multi_root_optimal : benefit(2) { %input1 = operand %input2 = operand @@ -52,7 +37,7 @@ %op2 = operation(%input2 : !pdl.value) -> (%type : !pdl.type) %val2 = result 0 of %op2 %root2 = operation(%val1, %val2 : !pdl.value, !pdl.value) - rewrite with "rewriter"["I am param"](%root1, %root2 : !pdl.operation, !pdl.operation) + rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation) } // ----- @@ -67,7 +52,7 @@ %op2 = operation(%input2 : !pdl.value) -> (%type : !pdl.type) %val2 = result 0 of %op2 %root2 = operation(%val1, %val2 : !pdl.value, !pdl.value) - rewrite %root1 with "rewriter"["I am param"](%root2 : !pdl.operation) + rewrite %root1 with "rewriter"(%root2 : !pdl.operation) } // ----- diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ 
-90,7 +90,7 @@ module @rewriters { pdl_interp.func @success(%root : !pdl.operation) { %operand = pdl_interp.get_operand 0 of %root - pdl_interp.apply_rewrite "rewriter"[42](%root, %operand : !pdl.operation, !pdl.value) + pdl_interp.apply_rewrite "rewriter"(%root, %operand : !pdl.operation, !pdl.value) pdl_interp.finalize } } @@ -99,7 +99,7 @@ // CHECK-LABEL: test.apply_rewrite_1 // CHECK: %[[INPUT:.*]] = "test.op_input" // CHECK-NOT: "test.op" -// CHECK: "test.success"(%[[INPUT]]) {constantParams = [42]} +// CHECK: "test.success"(%[[INPUT]]) module @ir attributes { test.apply_rewrite_1 } { %input = "test.op_input"() : () -> i32 "test.op"(%input) : (i32) -> () diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -15,19 +15,16 @@ /// Custom constraint invoked from PDL. static LogicalResult customSingleEntityConstraint(PDLValue value, - ArrayAttr constantParams, PatternRewriter &rewriter) { Operation *rootOp = value.cast(); return success(rootOp->getName().getStringRef() == "test.op"); } static LogicalResult customMultiEntityConstraint(ArrayRef values, - ArrayAttr constantParams, PatternRewriter &rewriter) { - return customSingleEntityConstraint(values[1], constantParams, rewriter); + return customSingleEntityConstraint(values[1], rewriter); } static LogicalResult customMultiEntityVariadicConstraint(ArrayRef values, - ArrayAttr constantParams, PatternRewriter &rewriter) { if (llvm::any_of(values, [](const PDLValue &value) { return !value; })) return failure(); @@ -39,32 +36,29 @@ } // Custom creator invoked from PDL. 
-static void customCreate(ArrayRef args, ArrayAttr constantParams, - PatternRewriter &rewriter, PDLResultList &results) { +static void customCreate(ArrayRef args, PatternRewriter &rewriter, + PDLResultList &results) { results.push_back(rewriter.createOperation( OperationState(args[0].cast()->getLoc(), "test.success"))); } static void customVariadicResultCreate(ArrayRef args, - ArrayAttr constantParams, PatternRewriter &rewriter, PDLResultList &results) { Operation *root = args[0].cast(); results.push_back(root->getOperands()); results.push_back(root->getOperands().getTypes()); } -static void customCreateType(ArrayRef args, ArrayAttr constantParams, - PatternRewriter &rewriter, +static void customCreateType(ArrayRef args, PatternRewriter &rewriter, PDLResultList &results) { results.push_back(rewriter.getF32Type()); } /// Custom rewriter invoked from PDL. -static void customRewriter(ArrayRef args, ArrayAttr constantParams, - PatternRewriter &rewriter, PDLResultList &results) { +static void customRewriter(ArrayRef args, PatternRewriter &rewriter, + PDLResultList &results) { Operation *root = args[0].cast(); OperationState successOpState(root->getLoc(), "test.success"); successOpState.addOperands(args[1].cast()); - successOpState.addAttribute("constantParams", constantParams); rewriter.createOperation(successOpState); rewriter.eraseOp(root); } diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll --- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll +++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll @@ -43,7 +43,7 @@ // Check the generation of native constraints and rewrites. 
-// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, ::mlir::ArrayAttr constParams, +// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, // CHECK-SAME: ::mlir::PatternRewriter &rewriter) { // CHECK: ::mlir::Attribute attr = {}; // CHECK: if (values[0]) @@ -69,7 +69,7 @@ // CHECK-NOT: TestUnusedCst -// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, ::mlir::ArrayAttr constParams, +// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, // CHECK-SAME: ::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results) { // CHECK: ::mlir::Attribute attr = {}; // CHECK: ::mlir::Operation * op = {}; diff --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py --- a/mlir/test/python/dialects/pdl_ops.py +++ b/mlir/test/python/dialects/pdl_ops.py @@ -53,34 +53,6 @@ root = OperationOp(args=[input]) RewriteOp(root, "rewriter", args=[input]) -# CHECK: module { -# CHECK: pdl.pattern @rewrite_with_params : benefit(1) { -# CHECK: %0 = operation -# CHECK: rewrite %0 with "rewriter" ["I am param"] -# CHECK: } -# CHECK: } -@constructAndPrintInModule -def test_rewrite_with_params(): - pattern = PatternOp(1, "rewrite_with_params") - with InsertionPoint(pattern.body): - op = OperationOp() - RewriteOp(op, "rewriter", params=[StringAttr.get("I am param")]) - -# CHECK: module { -# CHECK: pdl.pattern @rewrite_with_args_and_params : benefit(1) { -# CHECK: %0 = operand -# CHECK: %1 = operation(%0 : !pdl.value) -# CHECK: rewrite %1 with "rewriter" ["I am param"](%0 : !pdl.value) -# CHECK: } -# CHECK: } -@constructAndPrintInModule -def test_rewrite_with_args_and_params(): - pattern = PatternOp(1, "rewrite_with_args_and_params") - with InsertionPoint(pattern.body): - input = OperandOp() - root = OperationOp(args=[input]) - RewriteOp(root, "rewriter", params=[StringAttr.get("I am param")], args=[input]) - # CHECK: module { # 
CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) { # CHECK: %0 = operand @@ -92,7 +64,7 @@ # CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type) # CHECK: %7 = result 0 of %6 # CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value) -# CHECK: rewrite with "rewriter" ["I am param"](%5, %8 : !pdl.operation, !pdl.operation) +# CHECK: rewrite with "rewriter"(%5, %8 : !pdl.operation, !pdl.operation) # CHECK: } # CHECK: } @constructAndPrintInModule @@ -108,7 +80,7 @@ op2 = OperationOp(args=[input2], types=[ty]) val2 = ResultOp(op2, 0) root2 = OperationOp(args=[val1, val2]) - RewriteOp(name="rewriter", params=[StringAttr.get("I am param")], args=[root1, root2]) + RewriteOp(name="rewriter", args=[root1, root2]) # CHECK: module { # CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) { @@ -121,7 +93,7 @@ # CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type) # CHECK: %7 = result 0 of %6 # CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value) -# CHECK: rewrite %5 with "rewriter" ["I am param"](%8 : !pdl.operation) +# CHECK: rewrite %5 with "rewriter"(%8 : !pdl.operation) # CHECK: } # CHECK: } @constructAndPrintInModule @@ -137,7 +109,7 @@ op2 = OperationOp(args=[input2], types=[ty]) val2 = ResultOp(op2, 0) root2 = OperationOp(args=[val1, val2]) - RewriteOp(root1, name="rewriter", params=[StringAttr.get("I am param")], args=[root2]) + RewriteOp(root1, name="rewriter", args=[root2]) # CHECK: module { # CHECK: pdl.pattern @rewrite_add_body : benefit(1) { @@ -303,7 +275,7 @@ # CHECK: module { # CHECK: pdl.pattern : benefit(1) { # CHECK: %0 = type -# CHECK: apply_native_constraint "typeConstraint" [](%0 : !pdl.type) +# CHECK: apply_native_constraint "typeConstraint"(%0 : !pdl.type) # CHECK: %1 = operation -> (%0 : !pdl.type) # CHECK: rewrite %1 with "rewrite" # CHECK: } @@ -313,6 +285,6 @@ pattern = PatternOp(1) with InsertionPoint(pattern.body): resultType = TypeOp() - ApplyNativeConstraintOp("typeConstraint", args=[resultType], params=[]) + 
ApplyNativeConstraintOp("typeConstraint", args=[resultType]) root = OperationOp(types=[resultType]) RewriteOp(root, name="rewrite")