Add a pattern to the pass that simplifies extract_strided_metadata(other_op(memref)).
The new pattern gets rid of the expand_shape operation while materializing its effects on the sizes and strides of the base object.
In other words, this simplification replaces:
baseBuffer, offset, sizes, strides = extract_strided_metadata(expand_shape(memref))
With
baseBuffer, baseOffset, baseSizes, baseStrides = extract_strided_metadata(memref)
sizes#reassIdx = baseSizes#reassDim / product(expandShapeSizes#j, for j in group excluding reassIdx)
strides#reassIdx = baseStrides#reassDim * product(expandShapeSizes#j, for j in reassIdx+1..reassIdx+group.size-1)
Where reassIdx is a reassociation index for the group at reassDim and expandShapeSizes#j is either:
- The constant size at dimension j, derived directly from the result type of the expand_shape op, or
- An affine expression: baseSizes#reassDim / product of all constant sizes in expandShapeSizes.
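
As a rough illustration of the arithmetic described above (not the pattern's actual implementation), the sizes and strides could be computed as in the Python sketch below. The function name expanded_metadata, the DYNAMIC marker, and the list-based inputs are hypothetical and exist only for this example. The sketch assumes at most one dynamic size per reassociation group and reads the stride product as running over the dimensions that follow reassIdx within its group.

```python
from math import prod

# Hypothetical marker for a dynamic dimension in the expand_shape result type.
DYNAMIC = None


def expanded_metadata(base_sizes, base_strides, reassociation, result_shape):
    """Sketch: sizes/strides of expand_shape(memref) from the base metadata.

    base_sizes / base_strides: runtime sizes and strides of the source memref.
    reassociation: one list of result dimensions per source dimension.
    result_shape: static result sizes, with DYNAMIC for unknown ones
                  (assumed: at most one DYNAMIC entry per group).
    """
    sizes = [0] * len(result_shape)
    strides = [0] * len(result_shape)
    for reass_dim, group in enumerate(reassociation):
        # Product of all constant sizes in this group.
        const_product = prod(
            result_shape[j] for j in group if result_shape[j] is not DYNAMIC
        )
        for j in group:
            if result_shape[j] is DYNAMIC:
                # Dynamic size: base size divided by the constant sizes.
                sizes[j] = base_sizes[reass_dim] // const_product
            else:
                # Constant size: taken directly from the result type.
                sizes[j] = result_shape[j]
        # Stride of a dimension: base stride times the sizes of the
        # dimensions that follow it within the same group.
        running = 1
        for j in reversed(group):
            strides[j] = base_strides[reass_dim] * running
            running *= sizes[j]
    return sizes, strides


# Base memref of runtime shape 6x12 with strides [12, 1], expanded to
# 2x?x3x4 with reassociation [[0, 1], [2, 3]].
sizes, strides = expanded_metadata(
    [6, 12], [12, 1], [[0, 1], [2, 3]], [2, DYNAMIC, 3, 4]
)
```

The offset is not recomputed in this sketch since the formulas above only touch sizes and strides.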