diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -695,7 +695,8 @@
   // folding in case additional shape information is inferred at some point that
   // does not result in folding.
   patterns.add<CstrBroadcastableEqOps,
-               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>>(context);
+               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
+               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
 }
 
 // Return true if there is exactly one attribute not representing a scalar
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -638,6 +638,20 @@
   return
 }
 
+// -----
+// Empty shape arguments can be removed from broadcastable ops.
+// CHECK-LABEL: func @f
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
+func @f(%arg0 : tensor<?xindex>, %arg1 : tensor<?xindex>) {
+  // CHECK-NOT: const_shape
+  // CHECK: cstr_broadcastable %[[ARG0]], %[[ARG1]] : tensor<?xindex>, tensor<?xindex>
+  %0 = shape.const_shape [] : !shape.shape
+  %1 = shape.cstr_broadcastable %arg0, %arg1, %0
+      : tensor<?xindex>, tensor<?xindex>, !shape.shape
+  "consume.witness"(%1) : (!shape.witness) -> ()
+  return
+}
+
 // -----
 // Broadcastable with non-broadcastable constant shapes is always false
 // CHECK-LABEL: func @static_non_broadcastable