diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -834,11 +834,109 @@
                                     /*asProducer=*/false);
   }
 
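+  // If the reshape cannot be fused into `producer` directly, move it to the
+  // producer's operands instead and recreate the producer on the expanded
+  // type. This only applies when the producer is elementwise: identical
+  // operand/result shapes, identity indexing maps, and all-parallel
+  // iterators.
+  //
+  // For example (shapes taken from the test below), an expanding reshape of
+  // an elementwise generic op
+  //
+  //   %0 = linalg.generic {...} %arg0, %cst {...}
+  //     : tensor<264x4xf32>, tensor<264x4xf32> -> tensor<264x4xf32>
+  //   %1 = linalg.tensor_reshape %0 [...]
+  //     : tensor<264x4xf32> into tensor<8x33x4xf32>
+  //
+  // becomes a generic op on the expanded type, with the reshape moved onto
+  // the non-constant operand and the constant reshaped in place:
+  //
+  //   %2 = linalg.tensor_reshape %arg0 [...]
+  //     : tensor<264x4xf32> into tensor<8x33x4xf32>
+  //   %cst = constant dense<2.000000e+00> : tensor<8x33x4xf32>
+  //   %1 = linalg.generic {...} %2, %cst {...}
+  //     : tensor<8x33x4xf32>, tensor<8x33x4xf32> -> tensor<8x33x4xf32>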
+  static LinalgOp propagateReshapeOpToOperands(LinalgOp producer,
+                                               TensorReshapeOp consumer,
+                                               unsigned consumerIdx,
+                                               PatternRewriter &rewriter) {
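+    // The body of an indexed_generic op refers to the loop induction
+    // variables, which would no longer line up after reshaping; bail out.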
+    if (isa<IndexedGenericOp>(producer.getOperation()))
+      return nullptr;
+
+    // All operand and result shapes of the producer must be identical.
+    auto types = producer.getInputOutputShapedTypes();
+    if (!types.empty() &&
+        !std::equal(types.begin() + 1, types.end(), types.begin())) {
+      return nullptr;
+    }
+
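+    // All indexing maps must be identities, i.e. the producer must be a
+    // pure elementwise op.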
+    for (auto indexingMap : producer.getIndexingMaps()) {
+      if (!indexingMap.isIdentity())
+        return nullptr;
+    }
+
+    // All loops are parallel loops.
+    for (auto iteratorType : producer.iterator_types()) {
+      if (iteratorType.cast<StringAttr>().getValue() !=
+          getParallelIteratorTypeName())
+        return nullptr;
+    }
+
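+    // Rebuild each operand at the expanded type: constants are reshaped in
+    // place through their DenseElementsAttr, all other operands through a
+    // new tensor_reshape using the consumer's reassociation maps.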
+    Location loc = producer.getLoc();
+    auto dstShape = consumer.getResultType().cast<ShapedType>().getShape();
+    SmallVector<Value, 4> args;
+    for (auto arg : producer.getOperation()->getOperands()) {
+      auto type = RankedTensorType::get(
+          dstShape, arg.getType().cast<ShapedType>().getElementType());
+      if (auto cstOp = arg.getDefiningOp<ConstantOp>()) {
+        args.push_back(rewriter.create<ConstantOp>(
+            loc, cstOp.value().cast<DenseElementsAttr>().reshape(type)));
+      } else {
+        args.push_back(rewriter.create<linalg::TensorReshapeOp>(
+            loc, type, arg, consumer.reassociation()));
+      }
+    }
+
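+    // Results of the new generic op likewise take the expanded shape.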
+    SmallVector<Type, 4> resultTypes;
+    for (auto t : producer.getOutputTensorTypes()) {
+      Type type = RankedTensorType::get(dstShape,
+                                        t.cast<ShapedType>().getElementType());
+      resultTypes.push_back(type);
+    }
+
+    int rank = dstShape.size();
+    int numArgsIn = producer.getNumInputs();
+    int numArgsOut = producer.getNumOutputs();
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, resultTypes, args, numArgsIn, numArgsOut,
+        SmallVector<AffineMap, 3>(args.size() + resultTypes.size(),
+                                  rewriter.getMultiDimIdentityMap(rank)),
+        SmallVector<StringRef, 3>(rank, getParallelIteratorTypeName()));
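+    // The producer is elementwise, so its body can be cloned unchanged over
+    // the reshaped operands.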
+    Region &region = genericOp.getRegion();
+    rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region,
+                               region.begin());
+    return cast<LinalgOp>(genericOp.getOperation());
+  }
+
   static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
                        unsigned consumerIdx, PatternRewriter &rewriter,
                        OperationFolder *folder = nullptr) {
-    if (!isFusible(producer, consumer, consumerIdx))
-      return nullptr;
+    if (!isFusible(producer, consumer, consumerIdx)) {
+      // If the reshape cannot be fused into the producer directly, try to
+      // propagate it to the producer's operands so the reshapes become
+      // producers themselves and can be fused on a later application.
+      return propagateReshapeOpToOperands(producer, consumer, consumerIdx,
+                                          rewriter);
+    }
 
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -222,6 +222,36 @@
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d2)>
+
+func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>)
+                                            -> tensor<8x33x4xf32> {
+  %cst = constant dense<2.000000e+00> : tensor<264x4xf32>
+  %0 = linalg.generic
+    {args_in = 2 : i64, args_out = 1 : i64,
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel"]}
+    %arg0, %cst {
+    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+      %2 = mulf %arg1, %arg2 : f32
+      linalg.yield %2 : f32
+    }: tensor<264x4xf32>, tensor<264x4xf32> -> tensor<264x4xf32>
+  %1 = linalg.tensor_reshape %0 [#map1, #map2] :
+    tensor<264x4xf32> into tensor<8x33x4xf32>
+  return %1 : tensor<8x33x4xf32>
+}
+
+// The reshape is propagated to the operands of the generic op: the reshape of
+// `%arg0` is then folded into its indexing maps, and the reshaped splat
+// constant is fused into the region as a scalar f32.
+// CHECK-LABEL: func @generic_op_reshape_consumer_expanding
+//   CHECK-NOT:   linalg.tensor_reshape
+//       CHECK:   %[[CST:.*]] = constant {{.*}} : f32
+//       CHECK:   linalg.generic
+//   CHECK-NOT:   linalg.tensor_reshape
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2) -> (d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>