This change adds a set of utilities to replace the result of a
tensor.collapse_shape -> tensor.extract_slice chain with the
equivalent result formed by aggregating slices of the
tensor.collapse_shape source. In general, it is not possible to
commute extract_slice and collapse_shape if linearized dimensions
are sliced. The i-th dimension of the tensor.collapse_shape
result is a "linearized sliced dimension" if:
- The reassociation group at position i of tensor.collapse_shape has more than one index (multiple dimensions of the input are collapsed).
- The i-th dimension is sliced by tensor.extract_slice.
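The two conditions above can be sketched as a small check. This is an illustrative model in Python, not the utility added by this change: the function name linearized_sliced_dims is hypothetical, and it assumes a dimension counts as "sliced" whenever the slice does not take the whole dimension (offset 0, full size, stride 1).

```python
import math


def linearized_sliced_dims(reassociation, source_shape, offsets, sizes, strides):
    """Return the collapsed-result dims that are both linearized and sliced.

    Hypothetical helper for illustration only; it models static shapes.
    """
    result = []
    for i, group in enumerate(reassociation):
        # Condition 1: more than one source dim is collapsed into result dim i.
        is_linearized = len(group) > 1
        # Extent of result dim i is the product of the source dims in its group.
        extent = math.prod(source_shape[d] for d in group)
        # Condition 2 (assumption): the dim is sliced unless taken whole.
        is_sliced = not (offsets[i] == 0 and sizes[i] == extent and strides[i] == 1)
        if is_linearized and is_sliced:
            result.append(i)
    return result


# Parameters taken from the scf.for example in this description.
dims = linearized_sliced_dims(
    reassociation=[[0, 1, 2], [3]],
    source_shape=[3, 7, 11, 10],
    offsets=[13, 0],
    sizes=[10, 10],
    strides=[2, 1],
)
print(dims)  # [0]
```

Only dimension 0 qualifies: dimension 1 maps to a single source dimension and is taken whole.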
We can work around this by stitching together the result of
tensor.extract_slice by iterating over any linearized sliced dimensions.
This is equivalent to "tiling" the linearized-and-sliced dimensions of
the tensor.collapse_shape operation in order to manifest the result
tile (the result of the tensor.extract_slice). The user of the
utilities must provide the mechanism to create the tiling (e.g. a loop).
The tests demonstrate how to apply the utilities using either
scf.for or scf.foreach_thread.
The example below illustrates the pattern using scf.for:
    %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
    %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<231x10xf32>
    %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
We can construct %2 by generating the following IR:
    %dest = linalg.init_tensor() : tensor<10x10xf32>
    %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %dest) -> tensor<10x10xf32> {
      // Step 1: Map this output idx (%iv) to a multi-index for the input (%3):
      %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 13)>(%iv)
      %3:3 = arith.delinearize_index %linear_index into (3, 7, 11)
      // Step 2: Extract the slice from the input
      %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
        tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
      %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
        tensor<1x1x1x10xf32> into tensor<1x10xf32>
      // Step 3: Insert the slice into the destination
      %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
        tensor<1x10xf32> into tensor<10x10xf32>
      scf.yield %6 : tensor<10x10xf32>
    }
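The three steps of the loop can be checked against a "collapse first, then slice" reference with a small model. This is a sketch in plain Python using nested lists as tensors; the helper names delinearize and slice_by_stitching are hypothetical and only mirror the intent of the IR, using the slice parameters [13, 0] [10, 10] [2, 1] from the example.

```python
def delinearize(index, basis):
    """Split a linear index into a row-major multi-index over `basis`."""
    multi = []
    for extent in reversed(basis):
        multi.append(index % extent)
        index //= extent
    return tuple(reversed(multi))


def slice_by_stitching(source, basis, offset, size, stride):
    """Build the slice row by row, reading only from the uncollapsed source."""
    dest = []
    for iv in range(size):
        # Step 1: map the output index to a multi-index into the source.
        linear_index = offset + iv * stride
        i, j, k = delinearize(linear_index, basis)
        # Step 2: take the 1x1x1xN slice; indexing already drops the unit dims.
        row = source[i][j][k]
        # Step 3: insert the row into the destination at position iv.
        dest.append(list(row))
    return dest


basis = (3, 7, 11)
# A 3x7x11x10 source whose elements are all distinct.
source = [[[[((i * 7 + j) * 11 + k) * 10 + c for c in range(10)]
            for k in range(11)] for j in range(7)] for i in range(3)]

# Reference: collapse dims 0..2 first, then slice rows 13, 15, ..., 31.
collapsed = [source[i][j][k] for i in range(3) for j in range(7) for k in range(11)]
reference = [collapsed[13 + 2 * n] for n in range(10)]

stitched = slice_by_stitching(source, basis, offset=13, size=10, stride=2)
assert stitched == reference
```

The second dimension is not linearized and is taken whole here, so the model needs a loop only over the linearized sliced dimension.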
The pattern was discussed in the RFC here: https://discourse.llvm.org/t/rfc-tensor-extracting-slices-from-tensor-collapse-shape/64034