diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -486,7 +486,7 @@ let hasCanonicalizer = 1; } -def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> { +def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> { let summary = "Determines if all input shapes are equal."; let description = [{ Given 1 or more input shapes, determine if all shapes are the exact same. @@ -503,6 +503,8 @@ let results = (outs Shape_WitnessType:$result); let assemblyFormat = "$inputs attr-dict"; + + let hasCanonicalizer = 1; } // TODO(tpopp): Support witness attributes and then make this ConstantLike. diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -293,6 +293,16 @@ patterns.insert(context); } +//===----------------------------------------------------------------------===// +// CstrEqOp +//===----------------------------------------------------------------------===// + +void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, + MLIRContext *context) { + // Fold to a true witness when all input shapes are statically known to be equal + patterns.insert(context); +} + //===----------------------------------------------------------------------===// // ConstSizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td --- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td +++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td @@ -7,7 +7,19 @@ }) }]>>; +def AllInputShapesEq : ConstraintgetOperands(), [&](mlir::Value val) { + auto baseCase = $0.getOwner()->getOperand(0).getDefiningOp(); + auto other = val.getDefiningOp(); + return baseCase && other && baseCase.shape() == other.shape(); + }) + }]>>; + // 
Canonicalization patterns. def ConstantAssumingAll : Pat<(Shape_AssumingAllOp:$op $input), (Shape_TrueWitnessOp), [(AllInputsTrueWitnesses $op)] >; + +def ConstCstrEq : Pat<(Shape_CstrEqOp:$op $shapes), + (Shape_TrueWitnessOp), + [(AllInputShapesEq $op)] >; diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -107,6 +107,145 @@ return %ret : !shape.shape } +// ----- +// cstr_eq with equal const_shapes can be folded +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: shape.true_witness + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %cs0 = shape.const_shape [0, 1] + %cs1 = shape.const_shape [0, 1] + %cs2 = shape.const_shape [0, 1] + %0 = shape.cstr_eq %cs0, %cs1, %cs2 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// cstr_eq with unequal const_shapes cannot be folded +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: shape.const_shape + // CHECK-NEXT: shape.const_shape + // CHECK-NEXT: shape.cstr_eq + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %cs0 = shape.const_shape [0, 1] + %cs1 = shape.const_shape [3, 1] + %0 = shape.cstr_eq %cs0, %cs1 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// cstr_eq without const_shapes cannot be folded +// CHECK-LABEL: func @f +func @f(%arg0: !shape.shape, %arg1: !shape.shape) { + // CHECK-NEXT: shape.cstr_eq + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.cstr_eq %arg0, %arg1 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// assuming_all with known true witnesses can be folded +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: shape.true_witness + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.true_witness + %1 = shape.true_witness + %2 = shape.true_witness + %3 = shape.assuming_all %0, %1, %2 + "consume.witness"(%3) : (!shape.witness) -> 
() + return +} + +// ----- +// assuming_all should not be removed if not all witnesses are statically true. +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: %[[TRUE:.*]] = shape.true_witness + // CHECK-NEXT: %[[UNKNOWN:.*]] = "test.source" + // CHECK-NEXT: shape.assuming_all %[[TRUE]], %[[UNKNOWN]] + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.true_witness + %1 = "test.source"() : () -> !shape.witness + %2 = shape.assuming_all %0, %1 + "consume.witness"(%2) : (!shape.witness) -> () + return +} + +// ----- +// any can be replaced with a constant input if it has one. +// CHECK-LABEL: func @f +func @f(%arg0 : !shape.shape) -> !shape.shape { + // CHECK-NEXT: %[[CS:.*]] = shape.const_shape + // CHECK-NEXT: return %[[CS]] + %0 = shape.const_shape [2, 3, 4] + %1 = shape.any %0, %arg0 + return %1 : !shape.shape +} + + +// ----- +// any is not yet replaced without a constant input. +// CHECK-LABEL: func @f +func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape { + // CHECK-NEXT: %[[RET:.*]] = shape.any + // CHECK-NEXT: return %[[RET]] + %1 = shape.any %arg0, %arg1 + return %1 : !shape.shape +} + +// ----- +// Broadcastable with broadcastable constant shapes can be removed. +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: shape.true_witness + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %cs0 = shape.const_shape [1, 1] + %cs1 = shape.const_shape [1, 5] + %0 = shape.cstr_broadcastable %cs0, %cs1 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// Broadcastable with non-broadcastable constant shapes cannot be removed.
+// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: shape.const_shape + // CHECK-NEXT: shape.const_shape + // CHECK-NEXT: shape.cstr_broadcastable + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %cs0 = shape.const_shape [1, 3] + %cs1 = shape.const_shape [1, 5] + %0 = shape.cstr_broadcastable %cs0, %cs1 + "consume.witness"(%0) : (!shape.witness) -> () + return +} + +// ----- +// Broadcastable with a non-constant shape operand cannot be removed. +// CHECK-LABEL: func @f +func @f(%arg0 : !shape.shape) { + // CHECK-NEXT: shape.const_shape + // CHECK-NEXT: shape.cstr_broadcastable + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %cs0 = shape.const_shape [1, 3] + %0 = shape.cstr_broadcastable %arg0, %cs0 + "consume.witness"(%0) : (!shape.witness) -> () + return +} // ----- // assuming_all with known true witnesses can be folded // CHECK-LABEL: func @f