diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -277,7 +277,9 @@ operations correspond to input operations, or those that already existing within the MLIR module. Inside of a `pdl.rewrite`, these operations correspond to operations that should be created as part of the replacement - sequence. + sequence. The `commutative` attribute, if present, tells the matcher that + the operation should be treated as commutative with respect to all of its + operands. `pdl.operation`s are composed of a name, and a set of attribute, operand, and result type values, that map to what those that would be on a @@ -365,16 +367,22 @@ %otherResults = pdl.types %op = pdl.operation "foo.op" -> (%result, %otherResults : !pdl.type, !pdl.range<type>) ``` + + ```mlir + // If the operation is commutative: + %op = pdl.operation "foo.op" commutative (%operand : !pdl.value) -> (%result : !pdl.type) + ``` }]; let arguments = (ins OptionalAttr<StrAttr>:$name, + UnitAttr:$commutative, Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operands, Variadic<PDL_Attribute>:$attributes, StrArrayAttr:$attributeNames, Variadic<PDL_InstOrRangeOf<PDL_Type>>:$types); let results = (outs PDL_Operation:$op); let assemblyFormat = [{ - ($name^)? (`(` $operands^ `:` type($operands) `)`)? + ($name^)? (`commutative` $commutative^)? (`(` $operands^ `:` type($operands) `)`)? custom<OperationOpAttributes>($attributes, $attributeNames) (`->` `(` $types^ `:` type($types) `)`)? attr-dict }]; @@ -386,9 +394,10 @@ CArg<"ValueRange", "llvm::None">:$attrValues, CArg<"ValueRange", "llvm::None">:$resultTypes), [{ auto nameAttr = name ? 
$_builder.getStringAttr(*name) : StringAttr(); + auto commutativeAttr = UnitAttr(); build($_builder, $_state, $_builder.getType<OperationType>(), nameAttr, - operandValues, attrValues, $_builder.getStrArrayAttr(attrNames), - resultTypes); + commutativeAttr, operandValues, attrValues, + $_builder.getStrArrayAttr(attrNames), resultTypes); }]>, ]; let extraClassDeclaration = [{ diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -157,3 +157,14 @@ apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute) } } + +// ----- + +// Check that the optional `commutative` unit attribute is accepted on `pdl.operation`. +pdl.pattern @commutative_operation : benefit(1) { + %type = pdl.type + %input0 = pdl.operand + %input1 = pdl.operand + %root = pdl.operation "foo.op" commutative (%input0, %input1 : !pdl.value, !pdl.value) -> (%type : !pdl.type) + pdl.rewrite %root with "rewriter" +}