diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -553,7 +553,8 @@ // Canonicalization patterns have overlap with the considerations during // folding in case additional shape information is inferred at some point that // does not result in folding. - patterns.add(context); + patterns.add>(context); } // Return true if there is exactly one attribute not representing a scalar diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -648,7 +648,7 @@ // CHECK: shape.cstr_broadcastable // CHECK-NEXT: consume.witness // CHECK-NEXT: return - %cs0 = shape.const_shape [8, 1] : !shape.shape + %cs0 = shape.const_shape [8, 1] : !shape.shape %cs1 = shape.const_shape [1, 8] : !shape.shape %cs2 = shape.const_shape [1, -1] : !shape.shape %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape @@ -665,7 +665,7 @@ // CHECK-NEXT: return %cs0 = shape.const_shape [8, 1] : !shape.shape %cs1 = shape.const_shape [1, -1] : !shape.shape - %cs2 = shape.const_shape [1, -1] : !shape.shape + %cs2 = shape.const_shape [8, -1] : !shape.shape %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape "consume.witness"(%0) : (!shape.witness) -> () return @@ -1097,6 +1097,19 @@ // ----- +// CHECK-LABEL: @cstr_broadcastable_on_duplicate_shapes +// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape) +func @cstr_broadcastable_on_duplicate_shapes(%a : !shape.shape, + %b : !shape.shape) -> !shape.witness { + // CHECK: %[[RES:.*]] = shape.cstr_broadcastable %[[A]], %[[B]] : + // CHECK: return %[[RES]] + %0 = shape.cstr_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape, + !shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape + return %0 : !shape.witness +} + +// ----- + // CHECK-LABEL: @broadcast_on_same_shape // CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) func @broadcast_on_same_shape(%shape : !shape.shape) -> !shape.shape {