diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -69,12 +69,27 @@
 template <typename T, unsigned N>
 void applyPermutationToVector(SmallVector<T, N> &inVec,
                               ArrayRef<int64_t> permutation) {
+  assert(inVec.size() == permutation.size());
   SmallVector<T, N> auxVec(inVec.size());
   for (const auto &en : enumerate(permutation))
     auxVec[en.index()] = inVec[en.value()];
   inVec = auxVec;
 }
 
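+/// Inverts applyPermutationToVector for the same `permutation`: the element
+/// at position `i` of `inVec` is moved back to position `permutation[i]`.
+/// E.g., applying permutation = [2, 0, 1] to [a, b, c] yields [c, a, b];
+/// undoing it with the same permutation restores [a, b, c].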
+template <typename T, unsigned N>
+void undoPermutationToVector(SmallVector<T, N> &inVec,
+                             ArrayRef<int64_t> permutation) {
+  assert(inVec.size() == permutation.size());
+  SmallVector<T, N> auxVec = llvm::to_vector(inVec);
+  for (const auto &en : llvm::enumerate(permutation))
+    auxVec[en.value()] = inVec[en.index()];
+  inVec = auxVec;
+}
+
 /// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                     unsigned dropBack = 0);
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -57,10 +57,12 @@
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
+  MLIRDialectUtils
   MLIRIR
   MLIRLinalgDialect
   MLIRSCFDialect
   MLIRSupport
   MLIRTensorDialect
+  MLIRTensorUtils
   MLIRTilingInterface
   )
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -12,6 +12,8 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
 
 using namespace mlir;
@@ -68,6 +70,205 @@
   }
 };
 
+/// Returns the permutation based on `dimsPos` and `rank`. `dimsPos` is a
+/// projected permutation, i.e., a list of dimension positions. The method
+/// returns the reverse map of the projected permutation.
+/// E.g., given dimsPos = [2, 0] and rank = 3, the method returns [1, 0].
+static SmallVector<int64_t>
+computeInversePermutationForDimsPos(ArrayRef<int64_t> dimsPos, int64_t rank) {
+  SmallVector<int64_t> inVec;
+  inVec.reserve(dimsPos.size());
+  // First map dims and their position. For example, dimsPos = [2, 0] will map
+  // to:
+  // [
+  //  [ key: 2, value: 0]
+  //  [ key: 0, value: 1]
+  // ]
+  // where the key is a dim from dims_pos and the value is its index in it.
+  DenseMap<int64_t, int64_t> dimsAndPosMapping;
+  for (auto dimsIdx : llvm::seq<int64_t>(0, dimsPos.size()))
+    dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx;
+
+  // Scan the dimensions in increasing order; for each dimension present in
+  // the map, append its index to form the interchange vector.
+  for (auto dimsIdx : llvm::seq<int64_t>(0, rank))
+    if (dimsAndPosMapping.count(dimsIdx))
+      inVec.push_back(dimsAndPosMapping[dimsIdx]);
+
+  return inVec;
+}
+
+/// Builds an `affine_min` of `v1` and `v2`.
+static OpFoldResult buildMin(OpBuilder &b, Location loc, OpFoldResult v1,
+                             OpFoldResult v2) {
+  return makeComposedFoldedAffineMin(
+      b, loc, AffineMap::getMultiDimIdentityMap(2, loc.getContext()), {v1, v2});
+}
+
+/// Builds an `affine_apply` which subtracts `v2` from `v1`, i.e., "v1 - v2".
+static OpFoldResult buildSub(OpBuilder &b, Location loc, OpFoldResult v1,
+                             OpFoldResult v2) {
+  AffineExpr dim0, dim1;
+  bindDims(loc.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  return makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
+}
+
+/// Builds an `affine_apply` which multiplies `v1` and `v2`. Note that `v1` is
+/// bound to a dim expression and `v2` to a symbol expression.
+static OpFoldResult buildMul(OpBuilder &b, Location loc, OpFoldResult v1,
+                             OpFoldResult v2) {
+  MLIRContext *ctx = loc.getContext();
+  AffineExpr i, tile;
+  bindDims(ctx, i);
+  bindSymbols(ctx, tile);
+  return makeComposedFoldedAffineApply(b, loc, i * tile,
+                                       ArrayRef<OpFoldResult>{v1, v2});
+}
+
+struct PackOpTiling
+    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
+
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    // Note that here we only consider untiled dimensions and outer tiled data
+    // dimensions; the inner tiled data dimensions are materialized when
+    // building the body of the operation.
+    auto packOp = cast<PackOp>(op);
+    SmallVector<utils::IteratorType> iteratorTypes(
+        packOp.getSourceRank(), utils::IteratorType::parallel);
+    return iteratorTypes;
+  }
+
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
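+    // The domain spans the outer dimensions of the packed layout. E.g.,
+    // packing tensor<128x256xf32> into tensor<4x8x32x32xf32> with 32x32
+    // inner tiles yields the 2-D domain [0, 4) x [0, 8).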
+    OpBuilder::InsertionGuard guard(b);
+    auto packOp = cast<PackOp>(op);
+    Location loc = packOp.getLoc();
+    int64_t rank = packOp.getSourceRank();
+    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+    ReifiedRankedShapedTypeDims resultShape;
+    (void)packOp.reifyResultShapes(b, resultShape);
+    SmallVector<Range> loopRanges(rank);
+    for (auto dim : llvm::seq<int64_t>(0, rank)) {
+      loopRanges[dim].offset = zero;
+      loopRanges[dim].stride = one;
+      loopRanges[dim].size = resultShape[0][dim];
+    }
+    return loopRanges;
+  }
+
+  SmallVector<Operation *>
+  getTiledImplementation(Operation *op, OpBuilder &b,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes) const {
+    auto packOp = cast<PackOp>(op);
+    Location loc = packOp.getLoc();
+
+    // The tiling is applied to the interchanged dimensions. We have to undo
+    // the interchange to map sizes and offsets back to the original input.
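+    // E.g., with outer_dims_perm = [1, 0] on a 2-D source, the iteration
+    // offsets (%c, %k) over the destination map back to the source offsets
+    // (%k, %c).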
+    int64_t inputRank = packOp.getSourceRank();
+    ArrayRef<int64_t> dimsToOuterBlock(packOp.getOuterDimsPerm());
+    SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
+    SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
+    if (!dimsToOuterBlock.empty()) {
+      SmallVector<int64_t> vec =
+          computeInversePermutationForDimsPos(dimsToOuterBlock, inputRank);
+      undoPermutationToVector<OpFoldResult>(origOffsets, vec);
+      undoPermutationToVector<OpFoldResult>(origSizes, vec);
+    }
+
+    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+        packOp.getDimAndTileMapping();
+    SmallVector<OpFoldResult> srcDimValues =
+        tensor::createDimValues(b, loc, packOp.getSource());
+    SmallVector<OpFoldResult> inputIndices, inputSizes;
+    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
+      if (dimAndTileMapping.count(dim)) {
+        // If the data dimension is tiled, the i-th index is the product of
+        // offset_i and tile_i, and the i-th size is the product of sizes_i and
+        // tile_i.
+        inputIndices.push_back(
+            buildMul(b, loc, origOffsets[dim], dimAndTileMapping[dim]));
+        inputSizes.push_back(
+            buildMul(b, loc, origSizes[dim], dimAndTileMapping[dim]));
+      } else {
+        inputIndices.push_back(origOffsets[dim]);
+        inputSizes.push_back(origSizes[dim]);
+      }
+
+      // Limit the size of the input operand for incomplete tiles.
+      OpFoldResult dimSize = srcDimValues[dim];
+      inputSizes.back() =
+          buildMin(b, loc, inputSizes.back(),
+                   buildSub(b, loc, dimSize, inputIndices.back()));
+    }
+
+    auto oneAttr = b.getI64IntegerAttr(1);
+    SmallVector<OpFoldResult> strides(inputRank, oneAttr);
+
+    SmallVector<Value> tiledOperands;
+    tiledOperands.push_back(b.create<ExtractSliceOp>(
+        loc, packOp.getSource(), inputIndices, inputSizes, strides));
+
+    SmallVector<OpFoldResult> outputOffsets, outputSizes;
+    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
+                                     outputSizes)))
+      return {};
+
+    strides.append(packOp.getDestRank() - inputRank, oneAttr);
+    auto extractSlice = b.create<ExtractSliceOp>(
+        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
+    tiledOperands.push_back(extractSlice);
+
+    if (auto val = packOp.getPaddingValue())
+      tiledOperands.push_back(val);
+    for (auto tile : packOp.getInnerTiles())
+      tiledOperands.push_back(tile);
+
+    Operation *tiledPackOp = b.create<PackOp>(
+        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+
+    return {tiledPackOp};
+  }
+
+  LogicalResult
+  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+                        ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
+    // The iteration domain is over the outer dimensions of the packed layout.
+    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
+    // The inner dimensions of `resultOffsets` are zeros because tiling is not
+    // applied to them.
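+    // E.g., for a 2-D pack into a rank-4 destination with inner_tiles =
+    // [32, 32], offsets [%i, %j] and sizes [2, 4] yield resultOffsets =
+    // [%i, %j, 0, 0] and resultSizes = [2, 4, 32, 32].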
+    auto packOp = cast<PackOp>(op);
+    int64_t inputRank = packOp.getSourceRank();
+    int64_t outputRank = packOp.getDestRank();
+    auto zeroAttr = b.getI64IntegerAttr(0);
+    resultOffsets.assign(offsets.begin(), offsets.end());
+    resultOffsets.append(outputRank - inputRank, zeroAttr);
+
+    ReifiedRankedShapedTypeDims outputShape;
+    (void)packOp.reifyResultShapes(b, outputShape);
+    resultSizes.assign(sizes.begin(), sizes.end());
+    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
+      resultSizes.push_back(getAsOpFoldResult(outputShape[0][dataTileDim]));
+
+    return success();
+  }
+};
+
 } // namespace
 
 Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
@@ -282,5 +472,6 @@
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
     tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
+    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -0,0 +1,214 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -canonicalize -cse -split-input-file | FileCheck %s
+
+// CHECK-DAG:   #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG:   #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -32 + 128, 64)>
+// CHECK-DAG:   #[[MAP2:.+]] = affine_map<(d0) -> (d0 * -32 + 256, 128)>
+// CHECK:       func.func @NC_to_NCnc
+// CHECK-SAME:    %[[IN:.*]]: tensor<128x256xf32>,
+// CHECK-SAME:    %[[OUT:.*]]: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[RES0:.*]] = scf.for %[[N:.*]] = %[[C0]] to %[[C4]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<4x8x32x32xf32>) {
+// CHECK:           %[[RES1:.+]] = scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<4x8x32x32xf32>) {
+// CHECK-DAG:         %[[IN_N:.+]] = affine.apply #[[MAP0]](%[[N]])
+// CHECK-DAG:         %[[IN_N_SZ:.*]] = affine.min #[[MAP1]]
+// CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP0]](%[[C]])
+// CHECK-DAG:         %[[IN_C_SZ:.*]] = affine.min #[[MAP2]]
+// CHECK:             %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_N]], %[[IN_C]]] [%[[IN_N_SZ]], %[[IN_C_SZ]]] [1, 1] : tensor<128x256xf32> to tensor<?x?xf32>
+// CHECK:             %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[N]], %[[C]], 0, 0] [2, 4, 32, 32] [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<2x4x32x32xf32>
+// CHECK:             %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]]
+// CHECK:             %[[SUB_RES:.*]] = tensor.pack
+// CHECK-SAME:          %[[SUB_IN]] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %[[CAST_OUT]]
+// CHECK:             %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]]
+// CHECK:             %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]]
+// CHECK:             scf.yield %[[INSERT]] : tensor<4x8x32x32xf32>
+// CHECK:           }
+// CHECK:           scf.yield %[[RES1:.*]] : tensor<4x8x32x32xf32>
+// CHECK:         }
+// CHECK:         return %[[RES0:.*]] : tensor<4x8x32x32xf32>
+// CHECK:       }
+func.func @NC_to_NCnc(%arg0: tensor<128x256xf32>, %arg1: tensor<4x8x32x32xf32>) -> tensor<4x8x32x32xf32> {
+  %0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32>
+  return %0 : tensor<4x8x32x32xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
+// CHECK:       #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK:       #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -8 + 256, 16)>
+// CHECK:       func.func @KC_to_CKkc
+// CHECK-SAME:    %[[IN:[A-Za-z0-9]+]]:
+// CHECK-SAME:    %[[OUT:[A-Za-z0-9]+]]:
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C32:.+]] = arith.constant 32 : index
+// CHECK:         scf.for %[[C:.+]] = %[[C0]] to %[[C32]] step %[[C2]]
+// CHECK-DAG:         %[[IN_C:.+]] = affine.apply #[[MAP0]](%[[C]])
+// CHECK-DAG:         %[[IN_C_SZ:.+]] = affine.min #[[MAP1]](%[[C]])
+// CHECK:             %[[INPUT_SLICE:.+]] = tensor.extract_slice %[[IN]]
+// CHECK-SAME:          [0, %[[IN_C]]] [128, %[[IN_C_SZ]]]
+// CHECK:             %[[CAST_IN:.+]] = tensor.cast %[[INPUT_SLICE]]
+// CHECK:             %[[OUTPUT_SLICE:.+]] = tensor.extract_slice %{{.+}}[%[[C]], 0, 0, 0] [2, 4, 32, 8]
+// CHECK:             %[[CAST_OUT:.+]] = tensor.cast %[[OUTPUT_SLICE]]
+// CHECK:             tensor.pack
+// CHECK-SAME:          %[[CAST_IN]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
+// CHECK-SAME:          into %[[CAST_OUT]]
+func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>) -> tensor<32x4x32x8xf32> {
+  %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<128x256xf32> -> tensor<32x4x32x8xf32>
+  return %0 : tensor<32x4x32x8xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
+// CHECK-DAG:     #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG:     #[[MAP1:.+]] = affine_map<(d0) -> (d0 * -2 + 15, 8)>
+// CHECK:         func.func @pad_and_pack_static(
+// CHECK-SAME:      %[[IN:.*]]: tensor<13x15xf32>,
+// CHECK-SAME:      %[[OUT:.*]]: tensor<2x8x8x2xf32>,
+// CHECK-SAME:      %[[PAD:.*]]: f32) -> tensor<2x8x8x2xf32> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[RES0:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[OUT]]) -> (tensor<2x8x8x2xf32>) {
+// CHECK-DAG:         %[[IN_J:.*]] = affine.apply #[[MAP0]](%[[J]])
+// CHECK-DAG:         %[[IN_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])
+// CHECK:             %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][0, %[[IN_J]]] [13, %[[IN_J_SZ]]] [1, 1]
+// CHECK:             %[[CAST_IN:.*]] = tensor.cast %[[SUB_IN]]
+// CHECK:             %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][0, %[[J]], 0, 0] [2, 4, 8, 2] [1, 1, 1, 1]
+// CHECK:             %[[CAST_OUT:.*]] = tensor.cast %[[SUB_OUT]]
+// CHECK:             %[[SUB_RES:.*]] = tensor.pack
+// CHECK-SAME:          %[[CAST_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME:          into %[[CAST_OUT]]
+// CHECK:             %[[CAST_RES:.*]] = tensor.cast %[[SUB_RES]]
+// CHECK:             %[[INSERT:.*]] = tensor.insert_slice %[[CAST_RES]] into %[[ITER1]]
+// CHECK:             scf.yield %[[INSERT]] : tensor<2x8x8x2xf32>
+// CHECK:           }
+// CHECK:           return %[[RES0:.*]] : tensor<2x8x8x2xf32>
+// CHECK:         }
+func.func @pad_and_pack_static(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: f32) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
+// CHECK-DAG:     #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+// CHECK-DAG:     #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK-DAG:     #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 8)>
+// CHECK-DAG:     #[[MAP3:.+]] = affine_map<(d0, d1)[s0] -> (d1 * -8 + s0, d0 * 8)>
+// CHECK-DAG:     #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG:     #[[MAP5:.+]] = affine_map<(d0, d1)[s0] -> (d1 * -2 + s0, d0 * 2)>
+// CHECK:         func.func @pad_and_pack_partially_dynamic(
+// CHECK-SAME:      %[[IN:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[OUT:.*]]: tensor<?x?x8x2xf32>,
+// CHECK-SAME:      %[[PAD:.*]]: f32) -> tensor<?x?x8x2xf32> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor<?x?x8x2xf32>
+// CHECK-DAG:       %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor<?x?x8x2xf32>
+// CHECK:           %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<?x?x8x2xf32>) {
+// CHECK-DAG:         %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
+// CHECK:             %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<?x?x8x2xf32>) {
+// CHECK-DAG:           %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
+// CHECK-DAG:           %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])
+// CHECK-DAG:           %[[IN_I_SZ:.*]] = affine.min #[[MAP3]]
+// CHECK-DAG:           %[[IN_J:.*]] = affine.apply #[[MAP4]](%[[J]])
+// CHECK-DAG:           %[[IN_J_SZ:.*]] = affine.min #[[MAP5]]
+// CHECK:               %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK:               %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], 8, 2] [1, 1, 1, 1] : tensor<?x?x8x2xf32> to tensor<?x?x8x2xf32>
+// CHECK:               %[[SUB_RES:.*]] = tensor.pack
+// CHECK-SAME:            %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME:            into %[[SUB_OUT]]
+// CHECK:               %[[INSERT:.*]] = tensor.insert_slice %[[SUB_RES]] into %[[ITER1]]
+// CHECK:               scf.yield %[[INSERT]] : tensor<?x?x8x2xf32>
+// CHECK:             }
+// CHECK:             scf.yield %[[RES1:.*]] : tensor<?x?x8x2xf32>
+// CHECK:           }
+// CHECK:           return %[[RES0:.*]] : tensor<?x?x8x2xf32>
+// CHECK:         }
+func.func @pad_and_pack_partially_dynamic(%input: tensor<?x?xf32>, %output: tensor<?x?x8x2xf32>, %pad: f32) -> tensor<?x?x8x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %0 : tensor<?x?x8x2xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
+
+// -----
+
+// CHECK-DAG:     #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
+// CHECK-DAG:     #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK-DAG:     #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+// CHECK-DAG:     #[[MAP3:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0, -(d1 * s0) + s1)>
+// CHECK:         func.func @pad_and_pack_fully_dynamic(
+// CHECK-SAME:      %[[IN:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[OUT:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME:      %[[PAD:.*]]: f32,
+// CHECK-SAME:      %[[TILE_0:.*]]: index,
+// CHECK-SAME:      %[[TILE_1:.*]]: index) -> tensor<?x?x?x?xf32> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[OUT_D0:.*]] = tensor.dim %[[OUT]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK-DAG:       %[[OUT_D1:.*]] = tensor.dim %[[OUT]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK:           %[[RES0:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[OUT_D0]] step %[[C2]] iter_args(%[[ITER0:.*]] = %[[OUT]]) -> (tensor<?x?x?x?xf32>) {
+// CHECK:             %[[OUT_I_SZ:.*]] = affine.min #[[MAP0]](%[[I]])[%[[OUT_D0]]]
+// CHECK:             %[[RES1:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[OUT_D1]] step %[[C4]] iter_args(%[[ITER1:.*]] = %[[ITER0]]) -> (tensor<?x?x?x?xf32>) {
+// CHECK:               %[[OUT_J_SZ:.*]] = affine.min #[[MAP1]](%[[J]])[%[[OUT_D1]]]
+// CHECK:               %[[IN_D0:.*]] = tensor.dim %[[IN]], %[[C0]]
+// CHECK:               %[[IN_D1:.*]] = tensor.dim %[[IN]], %[[C1]]
+// CHECK:               %[[IN_I:.*]] = affine.apply #[[MAP2]](%[[I]])[%[[TILE_0]]]
+// CHECK:               %[[IN_I_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_I_SZ]], %[[I]])[%[[TILE_0]], %[[IN_D0]]]
+// CHECK:               %[[IN_J:.*]] = affine.apply #[[MAP2]](%[[J]])[%[[TILE_1]]]
+// CHECK:               %[[IN_J_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_J_SZ]], %[[J]])[%[[TILE_1]], %[[IN_D1]]]
+// CHECK:               %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK:               %[[OUT_D2:.+]] = tensor.dim %[[OUT]], %[[C2]]
+// CHECK:               %[[OUT_D3:.+]] = tensor.dim %[[OUT]], %[[C3]]
+// CHECK:               %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], %[[OUT_D2]], %[[OUT_D3]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK:               %[[PACK:.*]] = tensor.pack
+// CHECK-SAME:            %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_0]], %[[TILE_1]]]
+// CHECK-SAME:            into %[[SUB_OUT]]
+// CHECK:               %[[INSERT:.*]] = tensor.insert_slice %[[PACK]] into %[[ITER1]]
+// CHECK:               scf.yield %[[INSERT]] : tensor<?x?x?x?xf32>
+// CHECK:             }
+// CHECK:             scf.yield %[[RES1:.*]] : tensor<?x?x?x?xf32>
+// CHECK:           }
+// CHECK:           return %[[RES0:.*]] : tensor<?x?x?x?xf32>
+// CHECK:         }
+func.func @pad_and_pack_fully_dynamic(%source: tensor<?x?xf32>, %dest: tensor<?x?x?x?xf32>, %pad: f32, %tile_n : index, %tile_m : index) -> tensor<?x?x?x?xf32> {
+  %0 = tensor.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg1
+    %1, %loops:2 = transform.structured.tile_to_scf_for %0 [2, 4]
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5373,10 +5373,12 @@
     deps = [
         ":AffineDialect",
         ":ArithUtils",
+        ":DialectUtils",
         ":IR",
         ":LinalgDialect",
         ":SCFDialect",
         ":TensorDialect",
+        ":TensorUtils",
         ":TilingInterface",
         "//llvm:Support",
     ],