This is an archive of the discontinued LLVM Phabricator instance.

[mlir][linalg] Fix incorrect bound calculation for tiling conv
ClosedPublic

Authored by antiagainst on Sep 30 2021, 9:40 AM.

Details

Summary

For convolution, the input window dimension's access affine map
is of the form (d0 * s0 + d1), where d0 and d1 are the output and
filter window dimensions, respectively, and s0 is the stride.
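For reference, here is a short derivation of the input window size. O (output window size) and K (filter window size) are symbols introduced here for illustration and do not appear in the patch. Since d0 ranges over [0, O) and d1 over [0, K), and the stride s0 is positive, the largest input index touched is reached when both loop indices are at their maximum:

```latex
\begin{aligned}
i &= d_0 s_0 + d_1, \qquad 0 \le d_0 \le O - 1,\quad 0 \le d_1 \le K - 1 \\
i_{\max} &= (O - 1)\, s_0 + (K - 1) \\
\text{size} &= i_{\max} + 1 = (O - 1)\, s_0 + K
\end{aligned}
```

With O = 112, s0 = 2 and K = 3 this gives 111 * 2 + 3 = 225.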

When tiling, https://reviews.llvm.org/D109267 changed the way
dimension sizes are acquired. Instead of directly querying them
using *.dim ops on the original convolution op, we now compute
them by applying the access affine map to the loop upper bounds.
This is fine for dimensions with single-dimension affine maps,
like matmul, but not for the convolution input, where it causes
incorrect computation and out-of-bounds accesses. As a concrete
example, say we have a 1x225x225x3 (NHWC) input, a 3x3x3x32
(HWCF) filter, and a 1x112x112x32 (NHWC) output with stride 2.
Then (112 * 2 + 3) would be 227, which differs from the correct
input window dimension size of 225.
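To make the arithmetic concrete, here is a minimal Python sketch (the helper names are hypothetical, not from the patch) that compares applying the map directly to the exclusive loop upper bounds, as the code after D109267 effectively did, against enumerating every input index the loops actually touch:

```python
# Hypothetical helper: evaluates the conv input access map (d0 * s0 + d1).
def input_index(d0: int, d1: int, stride: int) -> int:
    """Input window index for output position d0 and filter position d1."""
    return d0 * stride + d1


def size_from_upper_bounds(out_size: int, filter_size: int, stride: int) -> int:
    """Buggy approach: apply the map directly to the exclusive loop upper bounds."""
    return input_index(out_size, filter_size, stride)


def size_from_enumeration(out_size: int, filter_size: int, stride: int) -> int:
    """Ground truth: enumerate every index the loops touch and take the extent."""
    touched = {
        input_index(d0, d1, stride)
        for d0 in range(out_size)
        for d1 in range(filter_size)
    }
    return max(touched) + 1


print(size_from_upper_bounds(112, 3, 2))  # 227
print(size_from_enumeration(112, 3, 2))   # 225
```

The enumeration is only a reference point for checking the arithmetic; it is far too expensive to use for real tile sizes.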

Instead, we should first calculate the max index for each loop
(the exclusive upper bound minus one), apply the affine map to
those indices, and then add one to get the dimension size. Note
this makes no difference for matmul-like ops, since for them it
effectively computes (d0 - 1) + 1 = d0.
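A minimal sketch of the corrected recipe, in Python rather than MLIR's C++. The names size_via_max_indices, conv_input, and identity are hypothetical, and the maps are modeled as plain functions that are assumed to be non-decreasing in every loop index (true for affine maps with non-negative coefficients):

```python
from typing import Callable, Sequence


def size_via_max_indices(access: Callable[..., int], upper_bounds: Sequence[int]) -> int:
    """Exclusive bounds -> inclusive max indices -> apply map -> exclusive size.

    Assumes the access function is non-decreasing in every loop index, so the
    largest touched index is reached at the largest loop indices.
    """
    max_indices = [ub - 1 for ub in upper_bounds]  # exclusive -> inclusive
    return access(*max_indices) + 1                # apply, then inclusive -> exclusive


# Hypothetical conv input map (d0 * s0 + d1) with stride s0 = 2.
def conv_input(d0: int, d1: int) -> int:
    return d0 * 2 + d1


# Hypothetical matmul-like map: a single loop dimension, d0 -> d0.
def identity(d0: int) -> int:
    return d0


print(size_via_max_indices(conv_input, [112, 3]))  # 225
print(size_via_max_indices(identity, [112]))       # 112, i.e. (d0 - 1) + 1
```

The second call shows why matmul-like ops are unaffected: for an identity map the subtract-one and add-one steps cancel out.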

Event Timeline

antiagainst created this revision. Sep 30 2021, 9:40 AM
antiagainst requested review of this revision. Sep 30 2021, 9:40 AM
nicolasvasilache accepted this revision. Sep 30 2021, 9:45 AM

Yes, exclusive -> inclusive -> apply -> exclusive is the proper way, thanks!
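The same exclusive -> inclusive -> apply -> exclusive recipe extends to every result of a multi-dimensional map. The sketch is a hypothetical Python model, not MLIR's AffineMap API: each result is a list of per-loop coefficients plus a constant, and the loop order (n, oh, ow, f, kh, kw, c) is an assumption for a 2-D NHWC/HWCF convolution with stride 2, using the shapes from the summary.

```python
from typing import List, Sequence, Tuple

# Hypothetical model of a pure affine map result: per-loop coefficients plus a
# constant term. Symbols, mod, and floordiv are not modeled here.
AffineResult = Tuple[List[int], int]


def apply_result(result: AffineResult, indices: Sequence[int]) -> int:
    coeffs, const = result
    return sum(c * i for c, i in zip(coeffs, indices)) + const


def operand_shape(results: Sequence[AffineResult], loop_upper_bounds: Sequence[int]) -> List[int]:
    """Size of each operand dimension touched by loops with exclusive upper bounds.

    Assumes all coefficients are non-negative, so every result is maximized at
    the maximum loop indices.
    """
    max_indices = [ub - 1 for ub in loop_upper_bounds]          # exclusive -> inclusive
    return [apply_result(r, max_indices) + 1 for r in results]  # apply -> exclusive


# Assumed loop order: (n, oh, ow, f, kh, kw, c); stride 2 in H and W.
LOOP_UPPER_BOUNDS = [1, 112, 112, 32, 3, 3, 3]
CONV_INPUT_MAP: List[AffineResult] = [
    ([1, 0, 0, 0, 0, 0, 0], 0),  # N: n
    ([0, 2, 0, 0, 1, 0, 0], 0),  # H: oh * 2 + kh
    ([0, 0, 2, 0, 0, 1, 0], 0),  # W: ow * 2 + kw
    ([0, 0, 0, 0, 0, 0, 1], 0),  # C: c
]

print(operand_shape(CONV_INPUT_MAP, LOOP_UPPER_BOUNDS))  # [1, 225, 225, 3]
```

Because the map has no negative coefficients, each result attains its maximum at the largest loop indices; a map with negative coefficients would need per-term min/max handling, which this sketch does not attempt.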

Just note however that linalg.conv is going to be retired soon.

This revision is now accepted and ready to land. Sep 30 2021, 9:45 AM

Also @gysit, who is looking at retiring conv.

> Just note however that linalg.conv is going to be retired soon.

+100. It has needed cleaning up for a long time! Thanks for that, @gysit!

gysit added inline comments. Sep 30 2021, 11:14 AM
mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir, line 237

@antiagainst good catch, and thanks for fixing. 10 + 8 is not always = 18 :). I hope this second 18 disappears if I adapt the sizes of %arg0. I will have a look at it tomorrow. Otherwise, there may be a similar error somewhere.