diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -156,7 +156,8 @@
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
         hasSparseOut(false),
-        dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)) {}
+        dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)),
+        loopIdxToDim(t + 1, std::vector<Optional<unsigned>>(l, llvm::None)) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -257,10 +258,33 @@
     return getDimLevelType(tensor(b), index(b));
   }
 
+  /// Gets the dimension number of the `i`th loop of the `t`th tensor.
+  Optional<unsigned> getDimNum(unsigned t, unsigned i) const {
+    assert(t < numTensors && i < numLoops);
+    return loopIdxToDim[t][i];
+  }
+
+  /// Gets the dimension number of `b`.
+  Optional<unsigned> getDimNum(unsigned b) const {
+    return getDimNum(tensor(b), index(b));
+  }
+
   /// Sets the dimension level type of the `i`th loop of the `t`th tensor.
-  void setDimLevelType(unsigned t, unsigned i, DimLevelType d) {
-    assert(isValidDLT(d));
-    dimTypes[t][i] = d;
+  void setDimLevelType(unsigned t, unsigned i, unsigned dim, DimLevelType dlt) {
+    assert(isValidDLT(dlt));
+    dimTypes[t][i] = dlt;
+    loopIdxToDim[t][i] = dim;
+  }
+
+  /// Iterates over the set bits in `bits`, invoking `cb` with each bit, its
+  /// tensor id, its dimension number (if any), and its dimension level type.
+  void foreachTidDimPairInBits(
+      const BitVector &bits,
+      function_ref<void(unsigned b, unsigned tid, Optional<unsigned> dim,
+                        DimLevelType dlt)>
+          cb) {
+    for (unsigned b : bits.set_bits())
+      cb(b, tensor(b), getDimNum(b), getDimLevelType(b));
   }
 
   // Has sparse output tensor setter.
@@ -310,7 +334,11 @@
   const unsigned numTensors;
   const unsigned numLoops;
   bool hasSparseOut;
+  // Map that converts pair<tensor id, loop id> to the corresponding dimension
+  // level type.
   std::vector<std::vector<DimLevelType>> dimTypes;
+  // Map that converts pair<tensor id, loop id> to the corresponding dimension.
+  std::vector<std::vector<Optional<unsigned>>> loopIdxToDim;
   llvm::SmallVector<TensorExp, 32> tensorExps;
   llvm::SmallVector<LatPoint, 16> latPoints;
   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -40,8 +40,6 @@
 
 namespace {
 
-constexpr unsigned INVALID_ID = std::numeric_limits<unsigned>::max();
-
 // Iteration graph sorting.
 enum SortMask {
   kSparseOnly = 0x0,
@@ -83,14 +81,6 @@
   // Topsort (reference should remain in scope).
   std::vector<unsigned> &topSort;
 
-  // From tensor id + loop id => dim id.
-  // TODO: This map should probably be maintained by Merger (it can be set up
-  // together with dimLvlType Map).
-  std::vector<std::vector<unsigned>> loopIdxToDim;
-
-  // Initialize the above two mapping.
-  void buildLoopIdxToDimMap(linalg::GenericOp op);
-
   Value getLoopIdxValue(size_t loopIdx) const {
     for (unsigned lv = 0; lv < topSort.size(); lv++)
       if (topSort[lv] == loopIdx)
@@ -100,30 +90,6 @@
   }
 };
 
-void CodeGen::buildLoopIdxToDimMap(linalg::GenericOp op) {
-  size_t numLoops = op.getNumLoops();
-  size_t numTensors = op.getNumOperands();
-  loopIdxToDim.assign(numTensors, std::vector<unsigned>(numLoops, INVALID_ID));
-
-  for (OpOperand &t : op->getOpOperands()) {
-    auto map = op.getMatchingIndexingMap(&t);
-    auto enc = getSparseTensorEncoding(t.get().getType());
-    // Scan all dimensions of current tensor.
-    unsigned tid = t.getOperandNumber();
-    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
-      auto a = map.getResult(toOrigDim(enc, d)).dyn_cast<AffineDimExpr>();
-      if (a) {
-        unsigned loopId = a.getPosition();
-        // Fills the mapping.
-        loopIdxToDim[tid][loopId] = d;
-      }
-      // Else a compound affine, do nothing. (at least we are good for
-      // now, as we only support compound affine expr on non-annoated dense
-      // tensors).
-    }
-  }
-}
-
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -151,8 +117,9 @@
 /// Helper method to inspect affine expressions. Rejects cases where the
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
-static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
-                       DimLevelType dim, bool setLvlFormat = true) {
+static bool findAffine(Merger &merger, unsigned tensor, unsigned dim,
+                       AffineExpr a, DimLevelType dlt,
+                       bool setLvlFormat = true) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
@@ -160,21 +127,21 @@
       return false; // used more than once
 
     if (setLvlFormat)
-      merger.setDimLevelType(tensor, idx, dim);
+      merger.setDimLevelType(tensor, idx, dim, dlt);
     return true;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    if (!isDenseDLT(dim))
+    if (!isDenseDLT(dlt))
       return false; // compound only in dense dim
     auto binOp = a.cast<AffineBinaryOpExpr>();
     // We do not set dim level format for affine expresssion like d0 + d1 on
     // both loop index at d0 and d1,
-    return findAffine(merger, tensor, binOp.getLHS(), dim, false) &&
-           findAffine(merger, tensor, binOp.getRHS(), dim, false);
+    return findAffine(merger, tensor, dim, binOp.getLHS(), dlt, false) &&
+           findAffine(merger, tensor, dim, binOp.getRHS(), dlt, false);
   }
   case AffineExprKind::Constant:
-    return isDenseDLT(dim); // const only in dense dim
+    return isDenseDLT(dlt); // const only in dense dim
   default:
     return false;
   }
@@ -196,7 +163,7 @@
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned tensor = t.getOperandNumber();
       AffineExpr a = map.getResult(toOrigDim(enc, d));
-      if (!findAffine(merger, tensor, a, getDimLevelType(enc, d)))
+      if (!findAffine(merger, tensor, d, a, getDimLevelType(enc, d)))
         return false; // inadmissible affine expression
     }
   }
@@ -1024,8 +991,7 @@
     Value clause;
     if (isCompressedDLT(merger.getDimLevelType(b)) ||
         isSingletonDLT(merger.getDimLevelType(b))) {
-      auto dim = codegen.loopIdxToDim[tensor][idx];
-      assert(dim != INVALID_ID);
+      auto dim = merger.getDimNum(tensor, idx).value();
       Value op1 = codegen.loopEmitter.getCoord()[tensor][dim];
       Value op2 = codegen.getLoopIdxValue(idx);
       clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
@@ -1082,23 +1048,22 @@
   unsigned l0 = merger.set(lts)[0];
   bool needsUniv = false;
 
-  SmallVector<size_t, 4> ts;
-  SmallVector<size_t, 4> ds;
-  for (auto b : merger.lat(l0).bits.set_bits()) {
-    if (isDenseDLT(merger.getDimLevelType(b)) ||
-        isUndefDLT(merger.getDimLevelType(b))) {
-      needsUniv = true;
-    } else {
-      unsigned tensor = merger.tensor(b);
-      assert(idx == merger.index(b));
-      size_t dim = codegen.loopIdxToDim[tensor][idx];
-      assert(dim != INVALID_ID);
-      ts.push_back(tensor);
-      ds.push_back(dim);
-    }
-  }
+  SmallVector<size_t> tids;
+  SmallVector<size_t> dims;
+  merger.foreachTidDimPairInBits(
+      merger.lat(l0).bits,
+      [&](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
+        assert(merger.index(b) == idx);
+        if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+          needsUniv = true;
+        } else {
+          // Sparse/singleton dim levels.
+          tids.push_back(tid);
+          dims.push_back(dim.value());
+        }
+      });
 
-  codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), ts, ds);
+  codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), tids, dims);
 
   // Maintain the universal index only if it is actually
   // consumed by a subsequent lattice point.
@@ -1119,44 +1084,44 @@
                                        SmallVectorImpl<size_t> &condDims,
                                        SmallVectorImpl<size_t> &extraTids,
                                        SmallVectorImpl<size_t> &extraDims) {
-  const BitVector &simple = merger.lat(li).simple;
   const BitVector &all = merger.lat(li).bits;
-  assert(simple.size() == all.size());
-  // First converts bits to array + dim pair
-  for (unsigned b = 0, e = simple.size(); b < e; b++) {
-    size_t tid = merger.tensor(b);
+  const BitVector &simple = merger.lat(li).simple;
+
+  // Converts bits to array + dim pair
+  merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
+                                               Optional<unsigned> dim,
+                                               DimLevelType dlt) {
     if (simple.test(b)) {
-      // the simplified condition must be a subset of the original condition.
-      assert(all.test(b));
-      assert(merger.index(b) == idx);
-      if (isUndefDLT(merger.getDimLevelType(b))) {
-        // This could be a synthetic tensor (for invariants and sparse output
-        // tensor).
+      if (isUndefDLT(dlt)) {
+        // This could be a synthetic tensor (for invariants and sparse
+        // output tensor).
         // In both cases, we mean to generate loops over output tensor.
-        // e.g.,
-        // out[i][j] = invariant;
-        if (merger.getSynTensorID() == tid)
+        // E.g., out[i][j] = invariant;
+        if (merger.getSynTensorID() == tid) {
           tid = merger.getOutTensorID();
+          dim = merger.getDimNum(tid, idx);
+        }
+        // Skips invalid dim (e.g., when this is a zero-ranked tensor).
+        if (!dim)
+          return;
       }
-      auto dim = codegen.loopIdxToDim[tid][idx];
-      if (dim != INVALID_ID) {
-        // dim could be invalid if this is a zero ranked tensor
-        condTids.push_back(tid);
-        condDims.push_back(dim);
-      }
-    } else if ((all.test(b) || merger.isOutTensor(b, idx)) &&
-               isDenseDLT(merger.getDimLevelType(b))) {
-      assert(merger.index(b) == idx);
-      // Note that we generate dense indices of the output tensor
-      // unconditionally, since they may not appear in the lattice, but may be
-      // needed for linearized codegen.
-      // Only dense dimensions should be optimized from conditions.
-      assert(isDenseDLT(merger.getDimLevelType(b)));
-      auto dim = codegen.loopIdxToDim[tid][idx];
-      assert(dim != INVALID_ID);
+      condTids.push_back(tid);
+      condDims.push_back(dim.value());
+    } else if (isDenseDLT(dlt)) {
+      // TODO: get rid of extraTids and extraDims.
       extraTids.push_back(tid);
-      extraDims.push_back(dim);
+      extraDims.push_back(dim.value());
     }
+  });
+
+  if (isDenseDLT(merger.getDimLevelType(merger.getOutTensorID(), idx))) {
+    // Note that we generate dense indices of the output tensor
+    // unconditionally, since they may not appear in the lattice, but may be
+    // needed for linearized codegen.
+    // Only dense dimensions should be optimized from conditions.
+    auto dim = merger.getDimNum(merger.getOutTensorID(), idx).value();
+    extraTids.push_back(merger.getOutTensorID());
+    extraDims.push_back(dim);
   }
 }
 
@@ -1370,8 +1335,6 @@
     // Recursively generates code if admissible.
     CodeGen codegen(options, tensors, numTensors, numLoops, sparseOut,
                     outerParNest, topSort);
-    // TODO: maybe merger should be responsible of maintaining the map.
-    codegen.buildLoopIdxToDimMap(op);
     genBuffers(merger, codegen, rewriter, op);
     genStmt(merger, codegen, rewriter, op, exp, 0);
     genResult(merger, codegen, rewriter, op);
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -313,15 +313,15 @@
 
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimLevelType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
+    merger.setDimLevelType(t1, l0, 0, DimLevelType::Compressed);
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
+    merger.setDimLevelType(t2, l0, 0, DimLevelType::Dense);
   }
 };
 
@@ -338,19 +338,19 @@
 
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimLevelType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
+    merger.setDimLevelType(t1, l0, 0, DimLevelType::Compressed);
 
     // Tensor 2: sparse input vector
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
+    merger.setDimLevelType(t2, l0, 0, DimLevelType::Compressed);
 
     // Tensor 3: dense output vector
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
+    merger.setDimLevelType(t3, l0, 0, DimLevelType::Dense);
   }
 };
 
@@ -371,15 +371,15 @@
 
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimLevelType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
+    merger.setDimLevelType(t1, l0, 0, DimLevelType::Dense);
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
+    merger.setDimLevelType(t2, l0, 0, DimLevelType::Dense);
   }
 };
 
@@ -400,19 +400,19 @@
 
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
+    merger.setDimLevelType(t0, l0, 0, DimLevelType::Undef);
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
+    merger.setDimLevelType(t1, l0, 0, DimLevelType::Dense);
 
     // Tensor 2: undef input vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Undef);
+    merger.setDimLevelType(t2, l0, 0, DimLevelType::Undef);
 
     // Tensor 3: dense output vector.
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
+    merger.setDimLevelType(t3, l0, 0, DimLevelType::Dense);
   }
 };
 
@@ -436,15 +436,15 @@
 
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
+    merger.setDimLevelType(t0, l0, 0, DimLevelType::Undef);
 
     // Tensor 1: undef input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Undef);
+    merger.setDimLevelType(t1, l0, 0, DimLevelType::Undef);
 
     // Tensor 2: sparse output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
+    merger.setDimLevelType(t2, l0, 0, DimLevelType::Compressed);
   }
 };