diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -336,9 +336,35 @@
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
 
+namespace {
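+/// Collapse an `assuming_all` op whose operands are all produced by
+/// `cstr_broadcastable` ops into a single `cstr_broadcastable` op.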
+struct AssumingAllToCstrBroadcastableCanonicalization
+    : public OpRewritePattern<AssumingAllOp> {
+  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingAllOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value, 8> shapes;
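+    // Bail out unless every operand is produced by a `cstr_broadcastable` op.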
+    for (Value v : op.inputs()) {
+      auto cstrBcastableOp = v.getDefiningOp<CstrBroadcastableOp>();
+      if (!cstrBcastableOp)
+        return failure();
+      auto range = cstrBcastableOp.shapes();
+      shapes.append(range.begin(), range.end());
+    }
+    rewriter.replaceOpWithNewOp<CstrBroadcastableOp>(op, shapes);
+    return success();
+  }
+};
+} // namespace
+
 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.add<AssumingAllOneOp>(context);
+  patterns
+      .add<AssumingAllOneOp, AssumingAllToCstrBroadcastableCanonicalization>(
+          context);
 }
 
 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -433,6 +433,20 @@
   return
 }
 
+// -----
+// An `assuming_all` op with only `cstr_broadcastable` operands can be collapsed.
+// CHECK-LABEL: func @assuming_all_to_cstr_broadcastable
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>)
+func @assuming_all_to_cstr_broadcastable(%a : !shape.shape,
+    %b : tensor<?xindex>, %c : tensor<3xindex>) -> !shape.witness {
+  // CHECK: %[[RESULT:.*]] = shape.cstr_broadcastable %[[A]], %[[B]], %[[C]]
+  // CHECK: return %[[RESULT]]
+  %0 = shape.cstr_broadcastable %a, %b : !shape.shape, tensor<?xindex>
+  %1 = shape.cstr_broadcastable %b, %c : tensor<?xindex>, tensor<3xindex>
+  %2 = shape.assuming_all %0, %1
+  return %2 : !shape.witness
+}
+
 // -----
 // assuming_all with known passing witnesses can be folded
 // CHECK-LABEL: func @f