diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -470,34 +470,26 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
-  if (operands.size() == 1)
+  if (shapes().size() == 1)
     return shapes().front();
 
   // TODO: Support folding with more than 2 input shapes
   if (shapes().size() > 2)
     return nullptr;
 
-  if (!operands[1])
-    return nullptr;
-
-  auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  if (rhsShape.empty())
-    return shapes()[0];
-
-  if (!operands[0])
+  if (!operands[0] || !operands[1])
     return nullptr;
-
   auto lhsShape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  if (lhsShape.empty())
-    return shapes()[1];
-
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   SmallVector<int64_t, 6> resultShape;
+
   // If the shapes are not compatible, we can't fold it.
   // TODO: Fold to an "error".
   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
     return nullptr;
+
   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
 }
@@ -531,6 +523,33 @@
   }
 };
 
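+/// Removes operands that are statically known to be empty shapes. This applies
+/// to ops, like broadcast, for which the empty shape is a neutral element.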
+template <typename OpTy>
+struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    auto isPotentiallyNonEmptyShape = [](Value shape) {
+      if (auto constShape = shape.getDefiningOp<ConstShapeOp>())
+        return constShape.shape().size() != 0;
+      return true;
+    };
+    auto newOperands = llvm::to_vector<8>(
+        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
+
+    // Reduce the op to an equivalent one without the empty shape operands.
+    if (newOperands.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
+                                        op->getAttrs());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 struct BroadcastForwardSingleOperandPattern
     : public OpRewritePattern<BroadcastOp> {
   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
@@ -538,18 +557,61 @@
   LogicalResult matchAndRewrite(BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
     if (op.getNumOperands() == 1) {
-      rewriter.replaceOp(op, op.shapes().front());
+      Value uniqueShapeOperand = op.shapes().front();
+      rewriter.replaceOp(op, uniqueShapeOperand);
       return success();
     }
     return failure();
   }
 };
+
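+/// Folds all constant shape operands of a broadcast into a single constant
+/// operand that holds their partial broadcast, if there are at least two.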
+struct BroadcastFoldConstantOperandsPattern
+    : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t, 8> foldedConstantShape;
+    SmallVector<Value, 8> newShapeOperands;
+    for (Value shape : op.shapes()) {
+      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
+        SmallVector<int64_t, 8> newFoldedConstantShape;
+        if (OpTrait::util::getBroadcastedShape(
+                foldedConstantShape,
+                llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
+                newFoldedConstantShape)) {
+          foldedConstantShape = newFoldedConstantShape;
+          continue;
+        }
+      }
+      newShapeOperands.push_back(shape);
+    }
+
+    // Need at least two constant operands to fold anything.
+    if (op.getNumOperands() - newShapeOperands.size() < 2)
+      return failure();
+
+    auto foldedConstantOperandsTy = RankedTensorType::get(
+        {static_cast<int64_t>(foldedConstantShape.size())},
+        rewriter.getIndexType());
+    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
+        op.getLoc(), foldedConstantOperandsTy,
+        rewriter.getIndexTensorAttr(foldedConstantShape)));
+    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
+                                             newShapeOperands);
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<BroadcastForwardSingleOperandPattern,
-               RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+  patterns.add<BroadcastFoldConstantOperandsPattern,
+               BroadcastForwardSingleOperandPattern,
+               RemoveDuplicateOperandsPattern<BroadcastOp>,
+               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -120,6 +120,36 @@
 
 // -----
 
+// All but one operand is a known empty shape.
+// CHECK-LABEL: @all_but_one_empty
+// CHECK-SAME:  (%[[ARG:.*]]: !shape.shape)
+func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: return %[[ARG]]
+  %0 = shape.const_shape [] : !shape.shape
+  %1 = shape.const_shape [] : tensor<0xindex>
+  %2 = shape.broadcast %0, %arg0, %1, %0 : !shape.shape, !shape.shape,
+      tensor<0xindex>, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+
+// Partial folding.
+// CHECK-LABEL: @partial_folding
+// CHECK-SAME:  (%[[ARG:.*]]: !shape.shape)
+func @partial_folding(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: %[[CST_SHAPE:.*]] = constant dense<[1, 2, 3]> : tensor<3xindex>
+  // CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG]], %[[CST_SHAPE]] : !shape.shape, tensor<3xindex> -> !shape.shape
+  // CHECK: return %[[RESULT]]
+  %0 = shape.const_shape [2, 1] : !shape.shape
+  %1 = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %2 = shape.broadcast %0, %arg0, %1, %0 : !shape.shape, !shape.shape,
+      tensor<3xindex>, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+
 // Incompatible shapes. No folding.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {