diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -225,8 +225,8 @@
       return getResult().getType().cast<RankedTensorType>();
     }
 
-    // Infer the shape of the result tensor given the static shapes
-    // and element type of the result tensor.
+    // Infer the type of the result tensor given the type of the source tensor
+    // and the static low/high paddings.
     static RankedTensorType inferResultType(RankedTensorType sourceType,
                                 ArrayRef<int64_t> staticLow,
                                 ArrayRef<int64_t> staticHigh);
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -47,7 +47,7 @@
 
 /// Convert reassociation indices to affine expressions.
 SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
-    OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices);
+    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
 
 /// Constructs affine maps out of Array<Array<AffineExpr>>.
 SmallVector<AffineMap, 4>
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1147,16 +1147,16 @@
 }
 SmallVector<ReassociationExprs, 4>
 TensorCollapseShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+  MLIRContext *ctx = getContext();
+  return convertReassociationIndicesToExprs(ctx, getReassociationIndices());
 }
 SmallVector<AffineMap, 4> TensorExpandShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
 SmallVector<ReassociationExprs, 4>
 TensorExpandShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+  MLIRContext *ctx = getContext();
+  return convertReassociationIndicesToExprs(ctx, getReassociationIndices());
 }
 
 /// For reshape op compute the shape at dimension `dimIndex` of the output in
@@ -1317,7 +1317,7 @@
   auto resultType = computeTensorReshapeCollapsedType(
       src.getType().cast<RankedTensorType>(),
       getSymbolLessAffineMaps(
-          convertReassociationIndicesToExprs(b, reassociation)));
+          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
   build(b, result, resultType, src, attrs);
   result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
@@ -1330,7 +1330,7 @@
   auto resultType = computeTensorReshapeCollapsedType(
       src.getType().cast<RankedTensorType>(),
       getSymbolLessAffineMaps(
-          convertReassociationIndicesToExprs(b, reassociation)));
+          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
   build(b, result, resultType, src, attrs);
   result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1316,16 +1316,16 @@
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+  MLIRContext *ctx = getContext();
+  return convertReassociationIndicesToExprs(ctx, getReassociationIndices());
 }
 
 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+  MLIRContext *ctx = getContext();
+  return convertReassociationIndicesToExprs(ctx, getReassociationIndices());
 }
 
 static void print(OpAsmPrinter &p, ExpandShapeOp op) {
@@ -1427,8 +1427,8 @@
                           ArrayRef<NamedAttribute> attrs) {
   auto memRefType = src.getType().cast<MemRefType>();
   auto resultType = computeReshapeCollapsedType(
-      memRefType, getSymbolLessAffineMaps(
-                      convertReassociationIndicesToExprs(b, reassociation)));
+      memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+                      b.getContext(), reassociation)));
   build(b, result, resultType, src, attrs);
   result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
@@ -1439,8 +1439,8 @@
                             ArrayRef<NamedAttribute> attrs) {
   auto memRefType = src.getType().cast<MemRefType>();
   auto resultType = computeReshapeCollapsedType(
-      memRefType, getSymbolLessAffineMaps(
-                      convertReassociationIndicesToExprs(b, reassociation)));
+      memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+                      b.getContext(), reassociation)));
   build(b, result, resultType, src, attrs);
   result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
@@ -1475,10 +1475,51 @@
   return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
 }
 
+namespace {
+/// Folds a `memref.cast` that produces the source of a `memref.collapse_shape`
+/// into the collapse. If collapsing the cast's source yields the original
+/// result type, the operand is updated in place; otherwise a new collapse is
+/// created on the cast's source and its result is cast back to the original
+/// result type so that existing users are unaffected.
+struct CollapseShapeOpMemRefCastFolder
+    : public OpRewritePattern<CollapseShapeOp> {
+  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CollapseShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto cast = op.src().getDefiningOp<CastOp>();
+    if (!cast)
+      return failure();
+
+    // Only absorb casts that CastOp reports as foldable into their consumer.
+    if (!CastOp::canFoldIntoConsumerOp(cast))
+      return failure();
+
+    // Type obtained by collapsing the cast's source directly.
+    Type newResultType = computeReshapeCollapsedType(
+        cast.source().getType().cast<MemRefType>(),
+        op.getReassociationMaps());
+
+    if (newResultType == op.getResultType()) {
+      rewriter.updateRootInPlace(
+          op, [&]() { op.srcMutable().assign(cast.source()); });
+    } else {
+      // The cast's source collapses to a different type: collapse it, then
+      // cast back so that users keep seeing the original result type.
+      Value newOp = rewriter.create<CollapseShapeOp>(
+          op->getLoc(), cast.source(), op.getReassociationIndices());
+      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
+    }
+    return success();
+  }
+};
+} // namespace
+
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results.add<CollapseReshapeOps<CollapseShapeOp>,
-              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context);
+              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
+              CollapseShapeOpMemRefCastFolder>(context);
 }
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
   if (succeeded(foldMemRefCast(*this)))
@@ -1486,8 +1527,8 @@
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
 }
 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
-  if (succeeded(foldMemRefCast(*this)))
-    return getResult();
+  // memref.cast producers are folded by CollapseShapeOpMemRefCastFolder
+  // instead, because absorbing a cast may require a new result type.
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
 }
 
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -183,13 +183,13 @@
 
 SmallVector<SmallVector<AffineExpr, 2>, 2>
 mlir::convertReassociationIndicesToExprs(
-    OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices) {
+    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
   SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
   for (const auto &indices : reassociationIndices) {
     SmallVector<AffineExpr, 2> reassociationMap;
     reassociationMap.reserve(indices.size());
     for (int64_t index : indices)
-      reassociationMap.push_back(b.getAffineDimExpr(index));
+      reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
     reassociationMaps.push_back(std::move(reassociationMap));
   }
   return reassociationMaps;
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -511,3 +511,33 @@
 }
 // CHECK-LABEL: @fold_memref_reshape_dynamic
 //   CHECK-NOT:   linalg.{{.*}}_shape
+
+// -----
+
+// CHECK-LABEL:   func @collapse_after_memref_cast_type_change(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
+// CHECK-SAME:         {{\[\[}}0], [1, 2, 3]] : memref<?x512x1x1xf32> into memref<?x512xf32>
+// CHECK:           %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] :
+// CHECK-SAME:         memref<?x512xf32> to memref<?x?xf32>
+// CHECK:           return %[[DYNAMIC]] : memref<?x?xf32>
+// CHECK:         }
+func @collapse_after_memref_cast_type_change(%arg0 : memref<?x512x1x1xf32>) -> memref<?x?xf32> {
+  %dynamic = memref.cast %arg0: memref<?x512x1x1xf32> to memref<?x?x?x?xf32>
+  %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
+  return %collapsed : memref<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func @collapse_after_memref_cast(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x?xf32>) -> memref<?x?xf32> {
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
+// CHECK-SAME:         {{\[\[}}0], [1, 2, 3]] : memref<?x512x1x?xf32> into memref<?x?xf32>
+// CHECK:           return %[[COLLAPSED]] : memref<?x?xf32>
+func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf32> {
+  %dynamic = memref.cast %arg0: memref<?x512x1x?xf32> to memref<?x?x?x?xf32>
+  %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
+  return %collapsed : memref<?x?xf32>
+}
+