Add a pattern to the pass that simplifies extract_strided_metadata(other_op(memref)), for the case where other_op is expand_shape.
The new pattern gets rid of the expand_shape operation while materializing its effects on the sizes and strides of the base object.
In other words, this simplification replaces:

    baseBuffer, offset, sizes, strides =
        extract_strided_metadata(expand_shape(memref))

with:

    baseBuffer, baseOffset, baseSizes, baseStrides =
        extract_strided_metadata(memref)
    sizes#reassIdx =
        baseSizes#reassDim / product(expandShapeSizes#j,
                                     for j in group excluding
                                       reassIdx)
    strides#reassIdx =
        baseStrides#reassDim * product(expandShapeSizes#j,
                                       for j in
                                         reassIdx+1..
                                           group.back())

where offset is simply baseOffset (expand_shape does not change the offset), reassIdx is a reassociation index in the group at reassDim, group.back() is the last reassociation index of that group, and expandShapeSizes#j is either:
- The constant size at dimension j, derived directly from the result type of the expand_shape op, or
- An affine expression: baseSizes#reassDim / product of all the constant sizes in the group.
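To make the size and stride formulas concrete, here is a minimal Python sketch that evaluates them on plain integers. It assumes a fully static base memref (baseSizes and baseStrides known as integers) and encodes a dynamic result size as None; the function name and this encoding are hypothetical, and the actual pattern emits IR (e.g. affine expressions) rather than computing numbers.

```python
from math import prod
from typing import List, Optional, Sequence, Tuple

# Hypothetical marker for a dynamic size in the expand_shape result type.
DYNAMIC = None


def expanded_sizes_and_strides(
    base_sizes: Sequence[int],
    base_strides: Sequence[int],
    reassociation: Sequence[Sequence[int]],
    result_shape: Sequence[Optional[int]],
) -> Tuple[List[int], List[int]]:
    """Sizes/strides of expand_shape(memref), derived from the base metadata.

    reassociation[reass_dim] lists the result dims (reassociation indices)
    that expand base dimension reass_dim.
    """
    rank = len(result_shape)
    sizes = [0] * rank
    strides = [0] * rank
    for reass_dim, group in enumerate(reassociation):
        # Product of the constant sizes in this group (taken from the result type).
        static_product = prod(
            result_shape[j] for j in group if result_shape[j] is not DYNAMIC
        )
        for j in group:
            if result_shape[j] is DYNAMIC:
                # Dynamic size: baseSizes#reassDim / product of the constant sizes.
                sizes[j] = base_sizes[reass_dim] // static_product
            else:
                # Constant size, read directly from the result type.
                sizes[j] = result_shape[j]
        # strides#reassIdx = baseStrides#reassDim * product of the sizes of the
        # dims after reassIdx in the group; walk the group from last to first.
        running = base_strides[reass_dim]
        for j in reversed(group):
            strides[j] = running
            running *= sizes[j]
    return sizes, strides
```

For example, expanding a 24x5 base with strides [5, 1] using reassociation [[0, 1], [2]] into a ?x4x5 result gives sizes [6, 4, 5] (the dynamic size is 24 / 4) and strides [20, 5, 1]: the last dimension of each group keeps the base stride, and each earlier dimension multiplies in the sizes that follow it in the group.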