diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -549,31 +549,38 @@ "require inner tile sizes being static"); } - // 1. Use rank-reduced tensor.extract_slice op to extract the tile. + // 1. Use rank-reduced tensor.extract_slice op to extract the tile. This also + // creates a tensor.cast op right before the rank-reduced + // tensor.extract_slice. This is a known information because all the outer + // dims are 1s in packed domain; it is extremely helpful for other analysis. Location loc = packOp.getLoc(); Attribute zeroIdxAttr = rewriter.getIndexAttr(0); Attribute oneIdxAttr = rewriter.getIndexAttr(1); SmallVector readOffsets(srcRank, zeroIdxAttr); SmallVector readStrides(srcRank, oneIdxAttr); SmallVector readSizes; - SmallVector readShape; + SmallVector readShape, castShape; DenseMap dimAndTileMapping = packOp.getDimAndTileMapping(); for (auto i : llvm::seq(0, srcRank)) { if (!dimAndTileMapping.count(i)) { readSizes.push_back(oneIdxAttr); + castShape.push_back(1); continue; } readSizes.push_back(dimAndTileMapping[i]); readShape.push_back(getConstantIntValue(dimAndTileMapping[i]) .value_or(ShapedType::kDynamic)); + castShape.push_back(readShape.back()); } Type elemType = packOp.getSourceType().getElementType(); auto readType = RankedTensorType::get(readShape, elemType); Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); + Value cast = rewriter.create( + loc, RankedTensorType::get(castShape, elemType), input); Value tile = rewriter.create( - loc, readType, input, readOffsets, readSizes, readStrides); + loc, readType, cast, readOffsets, readSizes, readStrides); // 2. Transpose the tile to match the inner tile order. SmallVector perm = diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir --- a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir +++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir @@ -19,8 +19,9 @@ // CHECK: %[[IN_S_SZ:.+]] = affine.min #[[MAP3]](%[[S]]) // CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]] // CHECK-SAME: [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, %[[IN_R_SZ]], %[[IN_S_SZ]]] [1, 1, 1, 1] -// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]] -// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x?x?xf32> to tensor<32x8xf32> +// CHECK: %[[CAST:.+]] = tensor.cast %[[SRC_SLICE]] : tensor<1x1x?x?xf32> to tensor<1x1x32x8xf32> +// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[CAST]] +// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32> // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32> // CHECK: %[[TRANSP:.+]] = linalg.transpose // CHECK-SAME: ins(%[[TILE]] @@ -85,8 +86,7 @@ // CHECK-DAG: %[[IN_C_SZ:.+]] = affine.min #[[MAP3]](%[[C]]) // CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]] // CHECK-SAME: [%[[IN_K]], %[[IN_C]]] [%[[IN_K_SZ]], %[[IN_C_SZ]]] [1, 1] -// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]] -// CHECK-SAME: [0, 0] [32, 8] [1, 1] : tensor to tensor<32x8xf32> +// CHECK: %[[TILE:.+]] = tensor.cast %[[SRC_SLICE]] : tensor to tensor<32x8xf32> // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x8xf32> // CHECK: %[[TRANSP:.+]] = linalg.transpose // CHECK-SAME: ins(%[[TILE]]