diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -136,9 +136,11 @@
 
 The pattern is position-based: the symbol names used for capturing here do not
 need to match with the op definition as shown in the above example. As another
-example, the pattern can be written as ` def : Pat<(AOp $a, F32Attr:$b), ...>;`
+example, the pattern can be written as `def : Pat<(AOp $a, F32Attr:$b), ...>;`
 and use `$a` and `$b` to refer to the captured input and attribute. But using
-the ODS name directly in the pattern is also allowed.
+the ODS name directly in the pattern is also allowed. Operands in the source
+pattern can be bound to the same name, which automatically enforces equality
+between those operands.
 
 Also note that we only need to add `TypeConstraint` or `AttributeConstraint`
 when we need to further limit the match criteria. If all valid cases to the op
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2257,7 +2257,9 @@
 // In the source pattern, `argN` can be used to specify matchers (e.g., using
 // type/attribute type constraints, etc.) and bound to a name for later use.
 // We can also bound names to op instances to reference them later in
-// multi-entity constraints.
+// multi-entity constraints. Operands in the source pattern can be
+// bound to the same name, which automatically enforces equality between them.
+//
 //
 // In the result pattern, `argN` can be used to refer to a previously bound
 // name, with potential transformations (e.g., using tAttr, etc.). `argN` can
@@ -2267,16 +2269,19 @@
 // For example,
 //
 // ```
-// def : Pattern<(OneResultOp1:$op1 $arg0, $arg1),
+// def : Pattern<(OneResultOp1:$op1 $arg0, $arg1, $arg0),
 //               [(OneResultOp2:$op2 $arg0, $arg1),
 //                (OneResultOp3 $op2 (OneResultOp4))],
 //               [(HasStaticShapePred $op1)]>;
 // ```
 //
-// `$argN` is bound to the `OneResultOp1`'s N-th argument and used later to
-// build `OneResultOp2`. `$op1` is bound to `OneResultOp1` and used to
-// check whether the result's shape is static. `$op2` is bound to
-// `OneResultOp2` and used to build `OneResultOp3`.
+// The first `$arg0` and `$arg1` are bound to `OneResultOp1`'s first and
+// second arguments and used later to build `OneResultOp2`. The second `$arg0`
+// is bound to `OneResultOp1`'s third argument, so equality is automatically
+// enforced between the first and third arguments of `OneResultOp1`.
+// `$op1` is bound to `OneResultOp1` and used to check whether the result's
+// shape is static. `$op2` is bound to `OneResultOp2` and used to
+// build `OneResultOp3`.
 //
 // ## Multi-result op
 //
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -1,7 +1,5 @@
 include "mlir/Dialect/Shape/IR/ShapeOps.td"
 
-def EqualBinaryOperands : Constraint>;
-
 def AllInputShapesEq : Constraint;
+def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $x, $x),
+  (Shape_ConstWitnessOp ConstBoolAttrTrue)>;
 
 def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes),
   (Shape_ConstWitnessOp ConstBoolAttrTrue),