diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -60,7 +60,7 @@ let builders = [ OpBuilder<(ins "StringRef":$name, CArg<"ValueRange", "{}">:$args, - CArg<"ArrayRef", "{}">:$params), [{ + CArg<"ArrayRef", "{}">:$params), [{ build($_builder, $_state, $_builder.getStringAttr(name), args, params.empty() ? ArrayAttr() : $_builder.getArrayAttr(params)); }]>, @@ -196,9 +196,9 @@ let description = [{ `pdl.operand` operations capture external operand edges into an operation node that originate from operations or block arguments not otherwise - specified within the pattern (e.g. via `pdl.result`). These operations - define individual operands of a given operation. A `pdl.operand` may - partially constrain an operand by specifying an expected value type + specified within the pattern (i.e. via `pdl.result` or `pdl.results`). These + operations define individual operands of a given operation. A `pdl.operand` + may partially constrain an operand by specifying an expected value type (via a `pdl.type` operation). Example: @@ -224,6 +224,44 @@ ]; } +//===----------------------------------------------------------------------===// +// pdl::OperandsOp +//===----------------------------------------------------------------------===// + +def PDL_OperandsOp : PDL_Op<"operands", [HasParent<"pdl::PatternOp">]> { + let summary = "Define a range of input operands in a pattern"; + let description = [{ + `pdl.operands` operations capture external operand range edges into an + operation node that originate from operations or block arguments not + otherwise specified within the pattern (i.e. via `pdl.result` or + `pdl.results`). These operations define groups of input operands into a + given operation. A `pdl.operands` may partially constrain a set of input + operands by specifying expected value types (via `pdl.types` operations). 
+ + Example: + + ```mlir + // Define a range of input operands: + %operands = pdl.operands + + // Define a range of input operands with expected types: + %types = pdl.types : [i32, i64, i32] + %typed_operands = pdl.operands : %types + ``` + }]; + + let arguments = (ins Optional>:$type); + let results = (outs PDL_RangeOf:$val); + let assemblyFormat = "(`:` $type^)? attr-dict"; + + let builders = [ + OpBuilder<(ins), [{ + build($_builder, $_state, RangeType::get($_builder.getType()), + Value()); + }]>, + ]; +} + //===----------------------------------------------------------------------===// // pdl::OperationOp //===----------------------------------------------------------------------===// @@ -245,6 +283,14 @@ a handle to the operation itself. Handles to the results of the operation can be extracted via `pdl.result`. + Example: + + ```mlir + // Define an instance of a `foo.op` operation. + %op = pdl.operation "foo.op"(%arg0, %arg1 : !pdl.value, !pdl.value) + {"attrA" = %attr0} -> (%type, %type : !pdl.type, !pdl.type) + ``` + When used within a matching context, the name of the operation may be omitted. @@ -257,24 +303,78 @@ override the `InferTypeOpInterface` to ensure that the result types can be inferred. - Example: + The operands of the operation are interpreted in the following ways: + + 1) A single !pdl.range: + + In this case, the single range is treated as all of the operands of the + operation. ```mlir - // Define an instance of a `foo.op` operation. - %op = pdl.operation "foo.op"(%arg0, %arg1) {"attrA" = %attr0} -> %type, %type, %type, %type + // Define an instance with single range of operands. + %op = pdl.operation "std.return"(%allArgs : !pdl.range) + ``` + + 2) A variadic number of either !pdl.value or !pdl.range: + + In this case, the inputs are expected to correspond with the operand groups + defined on the operation in ODS. 
+ + ```tablegen + // Given the following operation definition in ODS: + def MyIndirectCallOp { + let arguments = (ins FunctionType:$call, Variadic:$args); + } + ``` + + ```mlir + // We can match the operands as so: + %op = pdl.operation "my.indirect_call"(%call, %args : !pdl.value, !pdl.range) + ``` + + The results of the operation are interpreted in the following ways: + + 1) A single !pdl.range: + + In this case, the single range is treated as all of the result types of the + operation. + + ```mlir + // Define an instance with single range of types. + %allResultTypes = pdl.types + %op = pdl.operation "unrealized_conversion_cast" -> (%allResultTypes : !pdl.range<type>) + ``` + + 2) A variadic number of either !pdl.type or !pdl.range: + + In this case, the inputs are expected to correspond with the result groups + defined on the operation in ODS. + + ```tablegen + // Given the following operation definition in ODS: + def MyOp { + let results = (outs SomeType:$result, Variadic:$otherResults); + } + ``` + + ```mlir + // We can match the results as so: + %result = pdl.type + %otherResults = pdl.types + %op = pdl.operation "foo.op" -> (%result, %otherResults : !pdl.type, !pdl.range) ``` }]; let arguments = (ins OptionalAttr:$name, - Variadic:$operands, + Variadic>:$operands, Variadic:$attributes, StrArrayAttr:$attributeNames, - Variadic:$types); + Variadic>:$types); let results = (outs PDL_Operation:$op); let assemblyFormat = [{ - ($name^)? (`(` $operands^ `)`)? + ($name^)? (`(` $operands^ `:` type($operands) `)`)? custom($attributes, $attributeNames) - (`->` $types^)? attr-dict + (`->` `(` $types^ `:` type($types) `)`)?
attr-dict }]; let builders = [ @@ -378,7 +478,10 @@ ```mlir // Replace root node with 2 values: - pdl.replace %root with (%val0, %val1) + pdl.replace %root with (%val0, %val1 : !pdl.value, !pdl.value) + + // Replace root node with a range of values: + pdl.replace %root with (%vals : !pdl.range) // Replace root with another operation: pdl.replace %root with %otherOp @@ -386,9 +489,10 @@ }]; let arguments = (ins PDL_Operation:$operation, Optional:$replOperation, - Variadic:$replValues); + Variadic>:$replValues); let assemblyFormat = [{ - $operation `with` (`(` $replValues^ `)`)? ($replOperation^)? attr-dict + $operation `with` (`(` $replValues^ `:` type($replValues) `)`)? + ($replOperation^)? attr-dict }]; } @@ -409,13 +513,13 @@ ```mlir // Extract a result: %operation = pdl.operation ... - %result = pdl.result 1 of %operation + %pdl_result = pdl.result 1 of %operation // Imagine the following IR being matched: %result_0, %result_1 = foo.op ... // If the example pattern snippet above were matching against `foo.op` in - // the IR snippted, `%result` would correspond to `%result_1`. + // the IR snippet, `%pdl_result` would correspond to `%result_1`. ``` }]; @@ -425,6 +529,48 @@ let verifier = ?; } +//===----------------------------------------------------------------------===// +// pdl::ResultsOp +//===----------------------------------------------------------------------===// + +def PDL_ResultsOp : PDL_Op<"results"> { + let summary = "Extract a result group from an operation"; + let description = [{ + `pdl.results` operations extract a result group from an operation within a + pattern or rewrite region. If an index is provided, this operation extracts + a result group as defined by the ODS definition of the operation. In this + case the result of this operation may be either a single `pdl.value` or + a `pdl.range`, depending on the constraint of the result in ODS. If + no index is provided, this operation extracts the full result range of the + operation. 
+ + Example: + + ```mlir + // Extract all of the results of an operation: + %operation = pdl.operation ... + %results = pdl.results of %operation + + // Extract the results in the first result group of an operation, which is + // variadic: + %operation = pdl.operation ... + %results = pdl.results 0 of %operation -> !pdl.range + + // Extract the results in the second result group of an operation, which is + // not variadic: + %operation = pdl.operation ... + %results = pdl.results 1 of %operation -> !pdl.value + ``` + }]; + + let arguments = (ins PDL_Operation:$parent, OptionalAttr:$index); + let results = (outs PDL_InstOrRangeOf:$val); + let assemblyFormat = [{ + ($index^)? `of` $parent custom(ref($index), type($val)) + attr-dict + }]; +} + //===----------------------------------------------------------------------===// // pdl::RewriteOp //===----------------------------------------------------------------------===// @@ -489,7 +635,7 @@ def PDL_TypeOp : PDL_Op<"type"> { let summary = "Define a type handle within a pattern"; let description = [{ - `pdl.type` operations capture result type constraints of an `Attributes`, + `pdl.type` operations capture result type constraints of `Attributes`, `Values`, and `Operations`. Instances of this operation define, and partially constrain, results types of a given entity. A `pdl.type` may partially constrain the result by specifying a constant `Type`. @@ -498,23 +644,44 @@ ```mlir // Define a type: - %attr = pdl.type + %type = pdl.type // Define a type with a constant value: - %attr = pdl.type : i32 + %type = pdl.type : i32 ``` }]; let arguments = (ins OptionalAttr:$type); let results = (outs PDL_Type:$result); let assemblyFormat = "attr-dict (`:` $type^)?"; +} - let builders = [ - OpBuilder<(ins CArg<"Type", "Type()">:$ty), [{ - build($_builder, $_state, $_builder.getType(), - ty ? 
TypeAttr::get(ty) : TypeAttr()); - }]>, - ]; +//===----------------------------------------------------------------------===// +// pdl::TypesOp +//===----------------------------------------------------------------------===// + +def PDL_TypesOp : PDL_Op<"types"> { + let summary = "Define a range of type handles within a pattern"; + let description = [{ + `pdl.types` operations capture result type constraints of `Value`s and + `Operation`s. Instances of this operation define the result types of a given + entity. A `pdl.types` may partially constrain the results by specifying + an array of `Type`s. + + Example: + + ```mlir + // Define a range of types: + %types = pdl.types + + // Define a range of types with a range of constant values: + %types = pdl.types : [i32, i64, i32] + ``` + }]; + + let arguments = (ins OptionalAttr:$types); + let results = (outs PDL_RangeOf:$result); + let assemblyFormat = "attr-dict (`:` $types^)?"; } #endif // MLIR_DIALECT_PDL_IR_PDLOPS diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td @@ -101,4 +101,18 @@ CPred<"$_self.isa<::mlir::pdl::PDLType>()">, "pdl type", "::mlir::pdl::PDLType">; +// A range of positional values of one of the provided types. +class PDL_RangeOf : + ContainerType, PDL_Range.predicate, + "$_self.cast<::mlir::pdl::RangeType>().getElementType()", + "range", "::mlir::pdl::RangeType">, + BuildableType<"::mlir::pdl::RangeType::get(" # positionalType.builderCall # + ")">; + +// Either a positional value or a range of positional values for a given type.
+class PDL_InstOrRangeOf : + AnyTypeOf<[positionalType, PDL_RangeOf], + "single element or range of " # positionalType.summary, + "::mlir::pdl::PDLType">; + #endif // MLIR_DIALECT_PDL_IR_PDLTYPES diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -35,13 +35,19 @@ /// Returns true if the given operation is used by a "binding" pdl operation /// within the main matcher body of a `pdl.pattern`. static bool hasBindingUseInMatcher(Operation *op, Block *matcherBlock) { - for (Operation *user : op->getUsers()) { + for (OpOperand &use : op->getUses()) { + Operation *user = use.getOwner(); if (user->getBlock() != matcherBlock) continue; - if (isa(user)) + if (isa(user)) + return true; + // Only the first operand of RewriteOp may be bound to, i.e. the root + // operation of the pattern. + if (isa(user) && use.getOperandNumber() == 0) return true; // A result by itself is not binding, it must also be bound. - if (isa(user) && hasBindingUseInMatcher(user, matcherBlock)) + if (isa(user) && + hasBindingUseInMatcher(user, matcherBlock)) return true; } return false; @@ -107,6 +113,14 @@ return verifyHasBindingUseInMatcher(op); } +//===----------------------------------------------------------------------===// +// pdl::OperandsOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(OperandsOp op) { + return verifyHasBindingUseInMatcher(op); +} + //===----------------------------------------------------------------------===// // pdl::OperationOp //===----------------------------------------------------------------------===// @@ -177,18 +191,18 @@ if (isa(resultTypeOp)) continue; - // If the type is already constrained, there is nothing to do. 
- TypeOp typeOp = cast(resultTypeOp); - if (typeOp.type()) - continue; - // If the type operation was defined in the matcher and constrains the // result of an input operation, it can be used. auto constrainsInputOp = [rewriterBlock](Operation *user) { return user->getBlock() != rewriterBlock && isa(user); }; - if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp)) - continue; + if (TypeOp typeOp = dyn_cast(resultTypeOp)) { + if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInputOp)) + continue; + } else if (TypesOp typeOp = dyn_cast(resultTypeOp)) { + if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInputOp)) + continue; + } return op .emitOpError("must have inferable or constrained result types when " @@ -296,6 +310,36 @@ return success(); } +//===----------------------------------------------------------------------===// +// pdl::ResultsOp +//===----------------------------------------------------------------------===// + +static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index, + Type &resultType) { + if (!index) { + resultType = RangeType::get(p.getBuilder().getType()); + return success(); + } + if (p.parseArrow() || p.parseType(resultType)) + return failure(); + return success(); +} + +static void printResultsValueType(OpAsmPrinter &p, ResultsOp op, + IntegerAttr index, Type resultType) { + if (index) + p << " -> " << resultType; +} + +static LogicalResult verify(ResultsOp op) { + if (!op.index() && op.getType().isa()) { + return op.emitOpError() << "expected `pdl.range` result type when " + "no index is specified, but got: " + << op.getType(); + } + return success(); +} + //===----------------------------------------------------------------------===// // pdl::RewriteOp //===----------------------------------------------------------------------===// @@ -340,6 +384,14 @@ op, "`pdl.attribute`, `pdl.operand`, or `pdl.operation`"); } 
+//===----------------------------------------------------------------------===// +// pdl::TypesOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(TypesOp op) { + return verifyHasBindingUseInMatcher(op, "`pdl.operands`, or `pdl.operation`"); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -27,7 +27,7 @@ // CHECK: pdl_interp.apply_rewrite "rewriter"(%[[REWRITE_ROOT]] // CHECK: pdl_interp.finalize pdl.pattern : benefit(1) { - %root = pdl.operation "foo.op"() + %root = pdl.operation "foo.op" pdl.rewrite %root with "rewriter" } } @@ -69,7 +69,7 @@ pdl.pattern : benefit(1) { %input0 = pdl.operand %input1 = pdl.operand - %root = pdl.operation(%input0, %input1) + %root = pdl.operation(%input0, %input1 : !pdl.value, !pdl.value) %result0 = pdl.result 0 of %root pdl.apply_native_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value) @@ -96,7 +96,7 @@ pdl.pattern : benefit(1) { %type = pdl.type : i64 %input = pdl.operand : %type - %root = pdl.operation(%input, %input) + %root = pdl.operation(%input, %input : !pdl.value, !pdl.value) pdl.rewrite %root with "rewriter" } } @@ -120,7 +120,7 @@ pdl.pattern : benefit(1) { %type1 = pdl.type : i32 %type2 = pdl.type - %root = pdl.operation -> %type1, %type2 + %root = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type) pdl.rewrite %root with "rewriter" } } @@ -149,11 +149,11 @@ pdl.pattern : benefit(1) { %type1 = pdl.type : i32 %type2 = pdl.type - %inputOp = pdl.operation -> 
%type1, %type2 + %inputOp = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type) %result1 = pdl.result 0 of %inputOp %result2 = pdl.result 1 of %inputOp - %root = pdl.operation(%result1, %result2) + %root = pdl.operation(%result1, %result2 : !pdl.value, !pdl.value) pdl.rewrite %root with "rewriter" } } @@ -168,12 +168,12 @@ // CHECK: pdl_interp.switch_type %[[RESULT_TYPE]] to [i32, i64] pdl.pattern : benefit(1) { %type = pdl.type : i32 - %root = pdl.operation -> %type + %root = pdl.operation -> (%type : !pdl.type) pdl.rewrite %root with "rewriter" } pdl.pattern : benefit(1) { %type = pdl.type : i64 - %root = pdl.operation -> %type + %root = pdl.operation -> (%type : !pdl.type) pdl.rewrite %root with "rewriter" } } @@ -195,13 +195,13 @@ pdl.pattern : benefit(1) { %resultType = pdl.type pdl.apply_native_constraint "typeConstraint"[](%resultType : !pdl.type) - %root = pdl.operation -> %resultType + %root = pdl.operation -> (%resultType : !pdl.type) pdl.rewrite %root with "rewriter" } pdl.pattern : benefit(1) { %resultType = pdl.type - %apply = pdl.operation -> %resultType + %apply = pdl.operation -> (%resultType : !pdl.type) pdl.rewrite %apply with "rewriter" } } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir @@ -9,7 +9,7 @@ // CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value) pdl.pattern : benefit(1) { %input = pdl.operand - %root = pdl.operation "foo.op"(%input) + %root = pdl.operation "foo.op"(%input : !pdl.value) pdl.rewrite %root with "rewriter"[true](%input : !pdl.value) } } @@ -60,12 +60,12 @@ // CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]]) pdl.pattern : benefit(1) { %operand = pdl.operand - %root = pdl.operation "foo.op"(%operand) + %root = 
pdl.operation "foo.op"(%operand : !pdl.value) pdl.rewrite %root { %type = pdl.type : i32 - %newOp = pdl.operation "foo.op"(%operand) -> %type + %newOp = pdl.operation "foo.op"(%operand : !pdl.value) -> (%type : !pdl.type) %result = pdl.result 0 of %newOp - %newOp1 = pdl.operation "foo.op2"(%result) + %newOp1 = pdl.operation "foo.op2"(%result : !pdl.value) pdl.erase %root } } @@ -82,12 +82,12 @@ // CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]]) pdl.pattern : benefit(1) { %operand = pdl.operand - %root = pdl.operation "foo.op"(%operand) + %root = pdl.operation "foo.op"(%operand : !pdl.value) pdl.rewrite %root { %type = pdl.type : i32 - %newOp = pdl.operation "foo.op"(%operand) -> %type + %newOp = pdl.operation "foo.op"(%operand : !pdl.value) -> (%type : !pdl.type) %result = pdl.result 0 of %newOp - %newOp1 = pdl.operation "foo.op2"(%result) + %newOp1 = pdl.operation "foo.op2"(%result : !pdl.value) pdl.erase %root } } @@ -103,10 +103,10 @@ pdl.pattern : benefit(1) { %rootType = pdl.type %rootType1 = pdl.type - %root = pdl.operation "foo.op" -> %rootType, %rootType1 + %root = pdl.operation "foo.op" -> (%rootType, %rootType1 : !pdl.type, !pdl.type) pdl.rewrite %root { %newType1 = pdl.type - %newOp = pdl.operation "foo.op" -> %rootType, %newType1 + %newOp = pdl.operation "foo.op" -> (%rootType, %newType1 : !pdl.type, !pdl.type) pdl.replace %root with %newOp } } @@ -123,9 +123,9 @@ // CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]]) pdl.pattern : benefit(1) { %type = pdl.type : i32 - %root = pdl.operation "foo.op" -> %type + %root = pdl.operation "foo.op" -> (%type : !pdl.type) pdl.rewrite %root { - %newOp = pdl.operation "foo.op" -> %type + %newOp = pdl.operation "foo.op" -> (%type : !pdl.type) pdl.replace %root with %newOp } } @@ -142,11 +142,11 @@ // CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]]) pdl.pattern : benefit(1) { %type = pdl.type : i32 - %root = pdl.operation "foo.op" -> %type + %root = pdl.operation "foo.op" -> (%type : 
!pdl.type) pdl.rewrite %root { - %newOp = pdl.operation "foo.op" -> %type + %newOp = pdl.operation "foo.op" -> (%type : !pdl.type) %newResult = pdl.result 0 of %newOp - pdl.replace %root with (%newResult) + pdl.replace %root with (%newResult : !pdl.value) } } } @@ -178,10 +178,10 @@ // CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]] pdl.pattern : benefit(1) { %type = pdl.type - %root = pdl.operation "foo.op" -> %type + %root = pdl.operation "foo.op" -> (%type : !pdl.type) pdl.rewrite %root { %newType = pdl.apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type - %newOp = pdl.operation "foo.op" -> %newType + %newOp = pdl.operation "foo.op" -> (%newType : !pdl.type) pdl.replace %root with %newOp } } diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -38,7 +38,7 @@ // expected-error@below {{expected only one of [`type`, `value`] to be set}} %attr = pdl.attribute : %type 10 - %op = pdl.operation "foo.op" {"attr" = %attr} -> %type + %op = pdl.operation "foo.op" {"attr" = %attr} -> (%type : !pdl.type) pdl.rewrite %op with "rewriter" } @@ -90,6 +90,20 @@ // ----- +//===----------------------------------------------------------------------===// +// pdl::OperandsOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + // expected-error@below {{expected a bindable (i.e. 
`pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + %unused = pdl.operands + + %op = pdl.operation "foo.op" + pdl.rewrite %op with "rewriter" +} + +// ----- + //===----------------------------------------------------------------------===// // pdl::OperationOp //===----------------------------------------------------------------------===// @@ -116,13 +130,13 @@ // ----- pdl.pattern : benefit(1) { - %op = pdl.operation "foo.op"() + %op = pdl.operation "foo.op" pdl.rewrite %op { %type = pdl.type // expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}} // expected-note@below {{result type #0 was not constrained}} - %newOp = pdl.operation "foo.op" -> %type + %newOp = pdl.operation "foo.op" -> (%type : !pdl.type) } } @@ -163,9 +177,9 @@ pdl.pattern : benefit(1) { %type = pdl.type : i32 - %root = pdl.operation "foo.op" -> %type + %root = pdl.operation "foo.op" -> (%type : !pdl.type) pdl.rewrite %root { - %newOp = pdl.operation "foo.op" -> %type + %newOp = pdl.operation "foo.op" -> (%type : !pdl.type) %newResult = pdl.result 0 of %newOp // expected-error@below {{expected no replacement values to be provided when the replacement operation is present}} @@ -177,6 +191,19 @@ // ----- +//===----------------------------------------------------------------------===// +// pdl::ResultsOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + %root = pdl.operation "foo.op" + // expected-error@below {{expected `pdl.range` result type when no index is specified, but got: '!pdl.value'}} + %results = "pdl.results"(%root) : (!pdl.operation) -> !pdl.value + pdl.rewrite %root with "rewriter" +} + +// ----- + //===----------------------------------------------------------------------===// // pdl::RewriteOp //===----------------------------------------------------------------------===// @@ -237,3 +264,17 @@ %op = pdl.operation "foo.op" pdl.rewrite %op 
with "rewriter" } + +// ----- + +//===----------------------------------------------------------------------===// +// pdl::TypesOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + // expected-error@below {{expected a bindable (i.e. `pdl.operands`, or `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + %unused = pdl.types + + %op = pdl.operation "foo.op" + pdl.rewrite %op with "rewriter" +} diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -8,12 +8,12 @@ // Operation with attributes and results. %attribute = pdl.attribute %type = pdl.type - %op0 = pdl.operation {"attr" = %attribute} -> %type + %op0 = pdl.operation {"attr" = %attribute} -> (%type : !pdl.type) %op0_result = pdl.result 0 of %op0 // Operation with input. %input = pdl.operand - %root = pdl.operation(%op0_result, %input) + %root = pdl.operation(%op0_result, %input : !pdl.value, !pdl.value) pdl.rewrite %root with "rewriter" } @@ -21,7 +21,7 @@ pdl.pattern @rewrite_with_args : benefit(1) { %input = pdl.operand - %root = pdl.operation(%input) + %root = pdl.operation(%input : !pdl.value) pdl.rewrite %root with "rewriter"(%input : !pdl.value) } @@ -36,7 +36,7 @@ pdl.pattern @rewrite_with_args_and_params : benefit(1) { %input = pdl.operand - %root = pdl.operation(%input) + %root = pdl.operation(%input : !pdl.value) pdl.rewrite %root with "rewriter"["I am param"](%input : !pdl.value) } @@ -47,10 +47,10 @@ pdl.pattern @infer_type_from_operation_replace : benefit(1) { %type1 = pdl.type : i32 %type2 = pdl.type - %root = pdl.operation -> %type1, %type2 + %root = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type) pdl.rewrite %root { %type3 = pdl.type - %newOp = pdl.operation "foo.op" -> %type1, %type3 + %newOp = pdl.operation "foo.op" -> (%type1, %type3 : !pdl.type, !pdl.type) pdl.replace %root with %newOp } } @@ -58,12 
+58,25 @@ // ----- // Check that the result type of an operation within a rewrite can be inferred -// from a pdl.replace. +// from types used within the match block. pdl.pattern @infer_type_from_type_used_in_match : benefit(1) { %type1 = pdl.type : i32 %type2 = pdl.type - %root = pdl.operation -> %type1, %type2 + %root = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type) + pdl.rewrite %root { + %newOp = pdl.operation "foo.op" -> (%type1, %type2 : !pdl.type, !pdl.type) + } +} + +// ----- + +// Check that the result type of an operation within a rewrite can be inferred +// from a range of types (`pdl.types`) used within the match block. +pdl.pattern @infer_type_from_type_used_in_match : benefit(1) { + %types = pdl.types + %root = pdl.operation -> (%types : !pdl.range) pdl.rewrite %root { - %newOp = pdl.operation "foo.op" -> %type1, %type2 + %otherTypes = pdl.types : [i32, i64] + %newOp = pdl.operation "foo.op" -> (%types, %otherTypes : !pdl.range, !pdl.range) } }