diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -138,16 +138,25 @@
 createSparseTensorCodegenPass(bool enableBufferInitialization);
 
 //===----------------------------------------------------------------------===//
-// The SparseTensorRewriting pass.
+// The SparseTensorPreRewriting pass.
 //===----------------------------------------------------------------------===//
 
-void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT,
-                                   bool enableForeach, bool enableConvert);
+void populateSparseTensorPreRewriting(RewritePatternSet &patterns);
 
-std::unique_ptr<Pass> createSparseTensorRewritePass();
-std::unique_ptr<Pass> createSparseTensorRewritePass(bool enableRT,
-                                                    bool enableForeach = true,
-                                                    bool enableConvert = true);
+std::unique_ptr<Pass> createSparseTensorPreRewritePass();
+
+//===----------------------------------------------------------------------===//
+// The SparseTensorPostRewriting pass.
+//===----------------------------------------------------------------------===//
+
+void populateSparseTensorPostRewriting(RewritePatternSet &patterns,
+                                       bool enableRT, bool enableForeach,
+                                       bool enableConvert);
+
+std::unique_ptr<Pass> createSparseTensorPostRewritePass();
+std::unique_ptr<Pass>
+createSparseTensorPostRewritePass(bool enableRT, bool enableForeach = true,
+                                  bool enableConvert = true);
 
 //===----------------------------------------------------------------------===//
 // Other rewriting rules and passes.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -11,13 +11,13 @@
 
 include "mlir/Pass/PassBase.td"
 
-def SparseTensorRewrite : Pass<"sparse-tensor-rewrite", "ModuleOp"> {
+def SparseTensorPreRewrite : Pass<"sparse-tensor-pre-rewrite", "ModuleOp"> {
   let summary = "Applies sparse tensor rewriting rules prior to sparsification";
   let description = [{
     A pass that applies rewriting rules to sparse tensor operations prior
     to running the actual sparsification pass.
   }];
-  let constructor = "mlir::createSparseTensorRewritePass()";
+  let constructor = "mlir::createSparseTensorPreRewritePass()";
   let dependentDialects = [
     "arith::ArithDialect",
     "bufferization::BufferizationDialect",
@@ -26,14 +26,6 @@
     "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
   ];
-  let options = [
-    Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
-           "true", "Enable runtime library for manipulating sparse tensors">,
-    Option<"enableForeach", "enable-foreach", "bool",
-           "true", "Enable rewriting rules for the foreach operator">,
-    Option<"enableConvert", "enable-convert", "bool",
-           "true", "Enable rewriting rules for the convert operator">,
-  ];
 }
 
 def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
@@ -109,6 +101,31 @@
   ];
 }
 
+// Rewriting pass meant to run after the sparsification pass. The options
+// declared below are the ones the pass implementation forwards to
+// populateSparseTensorPostRewriting().
+def SparseTensorPostRewrite : Pass<"sparse-tensor-post-rewrite", "ModuleOp"> {
+  let summary = "Applies sparse tensor rewriting rules after sparsification";
+  let description = [{
+    A pass that applies rewriting rules to sparse tensor operations after
+    running the actual sparsification pass.
+  }];
+  let constructor = "mlir::createSparseTensorPostRewritePass()";
+  let options = [
+    Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
+           "true", "Enable runtime library for manipulating sparse tensors">,
+    Option<"enableForeach", "enable-foreach", "bool",
+           "true", "Enable rewriting rules for the foreach operator">,
+    Option<"enableConvert", "enable-convert", "bool",
+           "true", "Enable rewriting rules for the convert operator">,
+  ];
+  let dependentDialects = [
+    "arith::ArithDialect", "bufferization::BufferizationDialect",
+    "linalg::LinalgDialect", "memref::MemRefDialect",
+    "scf::SCFDialect", "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
   let summary = "Convert sparse tensors and primitives to library calls";
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -57,8 +57,9 @@
           /*analysisOnly=*/options.testBufferizationAnalysisOnly)));
   if (options.testBufferizationAnalysisOnly)
     return;
-  pm.addPass(createSparseTensorRewritePass(options.enableRuntimeLibrary));
+  pm.addPass(createSparseTensorPreRewritePass());
   pm.addPass(createSparsificationPass(options.sparsificationOptions()));
+  pm.addPass(createSparseTensorPostRewritePass(options.enableRuntimeLibrary));
   if (options.enableRuntimeLibrary) {
     pm.addPass(createSparseTensorConversionPass(
         options.sparseTensorConversionOptions()));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -21,8 +21,9 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
-#define GEN_PASS_DEF_SPARSETENSORREWRITE
+#define GEN_PASS_DEF_SPARSETENSORPREREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
+#define GEN_PASS_DEF_SPARSETENSORPOSTREWRITE
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
 #define GEN_PASS_DEF_SPARSETENSORCODEGEN
 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
@@ -38,22 +39,16 @@
 // Passes implementation.
 //===----------------------------------------------------------------------===//
 
-struct SparseTensorRewritePass
-    : public impl::SparseTensorRewriteBase<SparseTensorRewritePass> {
+struct SparseTensorPreRewritePass
+    : public impl::SparseTensorPreRewriteBase<SparseTensorPreRewritePass> {
 
-  SparseTensorRewritePass() = default;
-  SparseTensorRewritePass(const SparseTensorRewritePass &pass) = default;
-  SparseTensorRewritePass(bool enableRT, bool foreach, bool convert) {
-    enableRuntimeLibrary = enableRT;
-    enableForeach = foreach;
-    enableConvert = convert;
-  }
+  SparseTensorPreRewritePass() = default;
+  SparseTensorPreRewritePass(const SparseTensorPreRewritePass &pass) = default;
 
   void runOnOperation() override {
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    populateSparseTensorRewriting(patterns, enableRuntimeLibrary, enableForeach,
-                                  enableConvert);
+    populateSparseTensorPreRewriting(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
@@ -80,6 +75,27 @@
   }
 };
 
+/// Pass that greedily applies the sparse tensor rewriting rules registered
+/// by populateSparseTensorPostRewriting() to the operation it runs on.
+struct SparseTensorPostRewritePass
+    : public impl::SparseTensorPostRewriteBase<SparseTensorPostRewritePass> {
+  SparseTensorPostRewritePass() = default;
+  SparseTensorPostRewritePass(const SparseTensorPostRewritePass &) = default;
+  SparseTensorPostRewritePass(bool enableRT, bool foreach, bool convert) {
+    enableConvert = convert;
+    enableForeach = foreach;
+    enableRuntimeLibrary = enableRT;
+  }
+
+  /// Builds the pattern set from the pass options and applies it greedily.
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateSparseTensorPostRewriting(patterns, enableRuntimeLibrary,
+                                      enableForeach, enableConvert);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct SparseTensorConversionPass
     : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
 
@@ -254,15 +270,8 @@
 // Pass creation methods.
 //===----------------------------------------------------------------------===//
 
-std::unique_ptr<Pass> mlir::createSparseTensorRewritePass() {
-  return std::make_unique<SparseTensorRewritePass>();
-}
-
-std::unique_ptr<Pass> mlir::createSparseTensorRewritePass(bool enableRT,
-                                                          bool enableForeach,
-                                                          bool enableConvert) {
-  return std::make_unique<SparseTensorRewritePass>(enableRT, enableForeach,
-                                                   enableConvert);
+std::unique_ptr<Pass> mlir::createSparseTensorPreRewritePass() {
+  return std::make_unique<SparseTensorPreRewritePass>();
 }
 
 std::unique_ptr<Pass> mlir::createSparsificationPass() {
@@ -274,6 +283,17 @@
   return std::make_unique<SparsificationPass>(options);
 }
 
+// Factory functions for the rewriting pass that runs after sparsification.
+std::unique_ptr<Pass> mlir::createSparseTensorPostRewritePass() {
+  return std::make_unique<SparseTensorPostRewritePass>();
+}
+
+std::unique_ptr<Pass> mlir::createSparseTensorPostRewritePass(
+    bool enableRT, bool enableForeach, bool enableConvert) {
+  return std::make_unique<SparseTensorPostRewritePass>(enableRT, enableForeach,
+                                                       enableConvert);
+}
+
 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
   return std::make_unique<SparseTensorConversionPass>();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1019,11 +1019,15 @@
 //===---------------------------------------------------------------------===//
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
-void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
-                                         bool enableRT, bool enableForeach,
-                                         bool enableConvert) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
-               ReshapeRewriter<tensor::ExpandShapeOp>,
+void mlir::populateSparseTensorPreRewriting(RewritePatternSet &patterns) {
+  auto *ctx = patterns.getContext();
+  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd>(ctx);
+}
+
+void mlir::populateSparseTensorPostRewriting(RewritePatternSet &patterns,
+                                             bool enableRT, bool enableForeach,
+                                             bool enableConvert) {
+  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
                ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
   if (enableForeach)
     patterns.add<ForeachRewriter>(patterns.getContext());
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
-// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: mlir-opt %s --sparse-tensor-post-rewrite="enable-runtime-library=false enable-foreach=false" \
 // RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
 
 #SparseVector = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 
-// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: mlir-opt %s --sparse-tensor-post-rewrite="enable-runtime-library=false enable-foreach=false" \
 // RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
 
 #SparseVector = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -6,7 +6,7 @@
 // RUN: mlir-opt %s --sparse-tensor-conversion="s2s-strategy=0" \
 // RUN:    --canonicalize --cse | FileCheck %s -check-prefixes=CHECK-AUTO,CHECK
 
-// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-foreach=false" \
+// RUN: mlir-opt %s --sparse-tensor-post-rewrite="enable-runtime-library=false enable-foreach=false" \
 // RUN: --canonicalize --cse | FileCheck %s --check-prefix=CHECK-RWT
 
 #SparseVector64 = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/rewriting.mlir
--- a/mlir/test/Dialect/SparseTensor/rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-rewrite | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-post-rewrite | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = ["compressed"]
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" |\
+// RUN: mlir-opt %s -sparse-tensor-post-rewrite="enable-runtime-library=false enable-convert=false" |\
 // RUN: FileCheck %s
 
 #CSR = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: mlir-opt %s --sparse-tensor-post-rewrite="enable-runtime-library=false enable-convert=false" \
 // RUN: --sparsification | FileCheck %s
 
 #DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-tensor-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-tensor-pre-rewrite --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
 
 #DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
 // RUN: mlir-opt %s --sparse-tensor-conversion --cse --canonicalize | FileCheck %s --check-prefix=CHECK-CONV
-// RUN: mlir-opt %s --sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" \
+// RUN: mlir-opt %s --sparse-tensor-post-rewrite="enable-runtime-library=false enable-convert=false" \
 // RUN: --cse --canonicalize  | FileCheck %s --check-prefix=CHECK-RWT
 
 #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s  --tensor-copy-insertion --sparse-tensor-rewrite --sparsification --cse | FileCheck %s
+// RUN: mlir-opt %s  --tensor-copy-insertion --sparse-tensor-pre-rewrite --sparsification --cse | FileCheck %s
 
 #SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>