
[mlir][MemRefToLLVM] Remove the code for lowering collapse/expand_shape
Closed (Public)

Authored by qcolombet on Oct 21 2022, 12:37 PM.

Details

Summary

collapse/expand_shape are supposed to be expanded before we hit the lowering code.
The expansion is done with the pass called expand-strided-metadata.

This patch is NFC in spirit but not in practice, because expand-strided-metadata won't try to accommodate "invalid" strides for dynamic sizes that are 1 at runtime.

The previous code was broken in that respect too, but differently: it handled only the case of row-major layouts.
That whole part is being reworked separately.


Event Timeline

qcolombet created this revision.Oct 21 2022, 12:37 PM
qcolombet requested review of this revision.Oct 21 2022, 12:37 PM

Could you comment (potentially in the commit description) on why the generated IR checked by the tests gets significantly longer after what is described as a simplification? Were the previous tests just checking a subset of the IR operations that were actually generated?

Hi ,


Good point, I didn't repeat the explanation from https://reviews.llvm.org/D136377.

Here is the relevant part:

This patch is NFC in spirit but not in practice, because subview [here expand/collapse_shape] gets lowered into reinterpret_cast(extract_strided_metadata, <some math>), which lowers into two memref descriptors (one for reinterpret_cast and one for extract_strided_metadata). This creates some noise of the form extractvalue(unrealized_cast(extractvalue[0]))[0] that is currently not simplified within MLIR, even though it is really just a no-op in that case.

Note: This patch builds on top of https://reviews.llvm.org/D136377, so it suffers from the same problem as that one, i.e., the affine-to-std and arith-to-llvm dependencies.

As for the term "simplification", I should really have said "expansion". I'll fix that.

Cheers,
-Quentin

Oh, and I forgot:
Another reason for having more IR is that we kept each affine expression independent of the others, and some terms may be repeated between two expressions.
E.g.,

newSize = oldSize1 * oldSize2
finalOffset = oldOffset * oldSize1 * oldSize2

Here oldSize1 * oldSize2 could be reused (or CSE'd) but is currently expanded twice (once for each affine.apply expression).
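To make the duplication concrete, here is a toy model in Python (not MLIR code): each expression is lowered on its own into (opcode, lhs, rhs) triples, and a simple value-numbering pass, which is the idea behind CSE, folds the repeated product. The helper names lower_independently and cse are hypothetical, and the model assumes finalOffset is associated as oldOffset * (oldSize1 * oldSize2), which is what makes the product shareable.

```python
from typing import Dict, List, Tuple

# (opcode, lhs, rhs); operands are either inputs or names of earlier results.
Op = Tuple[str, str, str]
Named = Tuple[str, Op]

def lower_independently() -> List[Named]:
    """Toy lowering: each expression is expanded on its own, so nothing is shared."""
    return [
        ("t0", ("mul", "oldSize1", "oldSize2")),  # newSize
        ("t1", ("mul", "oldSize1", "oldSize2")),  # recomputed for finalOffset
        ("t2", ("mul", "oldOffset", "t1")),       # finalOffset
    ]

def cse(ops: List[Named]) -> Tuple[List[Named], Dict[str, str]]:
    """Value numbering: keep the first op computing each (opcode, lhs, rhs) key."""
    seen: Dict[Op, str] = {}
    rename: Dict[str, str] = {}
    kept: List[Named] = []
    for name, (opc, lhs, rhs) in ops:
        # Rewrite operands that were replaced by an earlier equivalent value.
        key = (opc, rename.get(lhs, lhs), rename.get(rhs, rhs))
        if key in seen:
            rename[name] = seen[key]  # duplicate: reuse the earlier result
        else:
            seen[key] = name
            kept.append((name, key))
    return kept, rename

ops = lower_independently()
deduped, renamed = cse(ops)
print(len(ops), "->", len(deduped))  # 3 -> 2
print(renamed)                       # {'t1': 't0'}
```

A real CSE would also canonicalize commutative operands; this sketch only matches syntactically identical operations.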

So if I understand correctly, we are now emitting more arithmetic that we are unable to simplify at the MLIR level. This sounds concerning and seems to defy the purpose of using affine maps, which should be easy-to-compose closed-form expressions. Do we expect LLVM's CSE to simplify this? Is there any indication of this actually happening or not? (In one of my previous projects, we've seen a performance improvement attributable to better/simpler address generation via memrefs, which we may now undo...)

nicolasvasilache accepted this revision.Oct 27 2022, 8:29 AM

This is similar to https://reviews.llvm.org/D136377, feel free to just land this with a similar solution once the first one is agreed on and landed.

This revision is now accepted and ready to land.Oct 27 2022, 8:29 AM

So if I understand correctly, we are now emitting more arithmetic that we are unable to simplify at the MLIR level.

That's correct.

This sounds concerning and seems to defy the purpose of using affine maps, which should be easy-to-compose closed-form expressions.

When I first wrote this lowering (see https://reviews.llvm.org/D133166 for details), we started by composing the affine maps, but I found that it made the resulting IR hard to reason about. Now that I am more familiar with MLIR, and affine maps in particular, maybe this is not as bad. I.e., we could revert this decision if it causes any problems.

Do we expect LLVM's CSE to simplify this?

Yes, I would expect LLVM's CSE to pick it up since we are dealing with simple math operations.

Is there any indication of this actually happening or not?

I haven't actually checked. I'll do that.
How do you go from the llvm dialect to actual LLVM IR?

That said, this is an interesting issue. Do we expect MLIR to produce code that is as optimized/concise as possible, or do/can we rely on the lower layers to do the cleanups?

If it is the former, then why, for instance, are we lowering dead code to begin with?

(In one of my previous projects, we've seen a performance improvement attributable to better/simpler address generation via memrefs, which we may now undo...)

Let's double check if the CSE happens right now or not.

Reporting on this:

Do we expect LLVM's CSE to simplify this? Is there any indication of this actually happening or not?

Yes, I confirmed that CSE is happening just fine.

Here is what I did:
Old mlir-opt:

mlir-opt -convert-memref-to-llvm -lower-affine -convert-arith-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts <input>.mlir -o <output>.mlir
mlir-translate -mlir-to-llvmir <output>.mlir -o - | opt -S -early-cse -o old-static.ll

New mlir-opt, i.e., with this patch:

# Run the expansion pass first (right now it is called simplify-extract-strided-metadata)
mlir-opt -simplify-extract-strided-metadata -convert-memref-to-llvm -lower-affine -convert-arith-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts <input>.mlir -o <output>.mlir
mlir-translate -mlir-to-llvmir <output>.mlir -o - | opt -S -early-cse -o new-static.ll

Result: the IR is semantically equivalent and as performant in both cases. The only difference is the extract_strided_metadata descriptor, which stays around if we don't run DCE.
E.g., with the function collapse_shape_static from memref-to-llvm.mlir:

--- old-static-cse.ll   2022-11-05 01:35:30.604898681 +0000
+++ new-static-cse.ll   2022-11-05 01:35:24.384293356 +0000
@@ -1,36 +1,38 @@
 ; ModuleID = '<stdin>'
 source_filename = "LLVMDialectModule"
 
 declare ptr @malloc(i64)
 
 declare void @free(ptr)
 
 define { ptr, ptr, i64, [3 x i64], [3 x i64] } @collapse_shape_static(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, i64 %7, i64 %8, i64 %9, i64 %10, i64 %11, i64 %12) {
   %14 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } undef, ptr %0, 0
   %15 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %14, ptr %1, 1
   %16 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %15, i64 %2, 2
   %17 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %16, i64 %3, 3, 0
   %18 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %17, i64 %8, 4, 0
   %19 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %18, i64 %4, 3, 1
   %20 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %19, i64 %9, 4, 1
   %21 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %20, i64 %5, 3, 2
   %22 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %21, i64 %10, 4, 2
   %23 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %22, i64 %6, 3, 3
   %24 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %23, i64 %11, 4, 3
   %25 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %24, i64 %7, 3, 4
   %26 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %25, i64 %12, 4, 4
-  %27 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } undef, ptr %0, 0
-  %28 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %27, ptr %1, 1
-  %29 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %28, i64 %2, 2
-  %30 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %29, i64 3, 3, 0
-  %31 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %30, i64 4, 3, 1
-  %32 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %31, i64 5, 3, 2
+  %27 = insertvalue { ptr, ptr, i64 } undef, ptr %0, 0
+  %28 = insertvalue { ptr, ptr, i64 } %27, ptr %1, 1
+  %29 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } undef, ptr %0, 0
+  %30 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %29, ptr %1, 1
+  %31 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %30, i64 0, 2
+  %32 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %31, i64 3, 3, 0
   %33 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %32, i64 20, 4, 0
-  %34 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %33, i64 5, 4, 1
-  %35 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %34, i64 1, 4, 2
-  ret { ptr, ptr, i64, [3 x i64], [3 x i64] } %35
+  %34 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %33, i64 4, 3, 1
+  %35 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %34, i64 5, 4, 1
+  %36 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %35, i64 5, 3, 2
+  %37 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %36, i64 1, 4, 2
+  ret { ptr, ptr, i64, [3 x i64], [3 x i64] } %37
 }
 
 !llvm.module.flags = !{!0}
 
 !0 = !{i32 2, !"Debug Info Version", i32 3}

And it becomes even more evident that the two are the same if you run instcombine instead of CSE:

--- old-static-instcombine.ll   2022-11-05 02:07:09.937594426 +0000
+++ new-static-instcombine.ll   2022-11-05 02:07:17.966374824 +0000
@@ -1,23 +1,23 @@
 ; ModuleID = '<stdin>'
 source_filename = "LLVMDialectModule"
 
 declare ptr @malloc(i64)
 
 declare void @free(ptr)
 
 define { ptr, ptr, i64, [3 x i64], [3 x i64] } @collapse_shape_static(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, i64 %7, i64 %8, i64 %9, i64 %10, i64 %11, i64 %12) {
   %14 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } undef, ptr %0, 0
   %15 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %14, ptr %1, 1
-  %16 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %15, i64 %2, 2
+  %16 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %15, i64 0, 2
   %17 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %16, i64 3, 3, 0
-  %18 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %17, i64 4, 3, 1
-  %19 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %18, i64 5, 3, 2
-  %20 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %19, i64 20, 4, 0
-  %21 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %20, i64 5, 4, 1
+  %18 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %17, i64 20, 4, 0
+  %19 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %18, i64 4, 3, 1
+  %20 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %19, i64 5, 4, 1
+  %21 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %20, i64 5, 3, 2
   %22 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %21, i64 1, 4, 2
   ret { ptr, ptr, i64, [3 x i64], [3 x i64] } %22
 }
 
 !llvm.module.flags = !{!0}
 
 !0 = !{i32 2, !"Debug Info Version", i32 3}

BTW, regarding the dynamic case, I noticed that the new code and the old code differ.

Here is the IR for collapse_shape_dynamic_with_non_identity_layout, after renaming the variables and reordering them for a clearer diff:

--- old.ll      2022-11-05 06:19:12.140322531 +0000
+++ new.ll      2022-11-05 06:22:09.989617895 +0000
@@ -1,31 +1,23 @@
 ; ModuleID = '<stdin>'
 source_filename = "LLVMDialectModule"
 
 declare ptr @malloc(i64)
 
 declare void @free(ptr)
 
 define { ptr, ptr, i64, [2 x i64], [2 x i64] } @collapse_shape_dynamic_with_non_identity_layout(ptr %arg, ptr %arg1, i64 %arg2, i64 %arg3, i64 %arg4, i64 %arg5, i64 %arg6, i64 %arg7, i64 %arg8) {
 bb:
-  %test.not = icmp eq i64 %arg5, 1
-  br i1 %test.not, label %bb34, label %bb36
-
-bb34:                                             ; preds = %bb
-  br label %bb36
-
-bb36:                                             ; preds = %bb34, %bb
-  %stride1 = phi i64 [ %arg7, %bb34 ], [ %arg8, %bb ]
   %desc = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } undef, ptr %arg, 0
   %desc0 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %desc, ptr %arg1, 1
   %desc1 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %desc0, i64 %arg2, 2
   %desc2 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %desc1, i64 4, 3, 0
   %size1 = mul i64 %arg4, %arg5
   %desc3 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %desc2, i64 %size1, 3, 1
   %desc4 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %desc3, i64 %arg6, 4, 0
-  %desc5 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %desc4, i64 %stride1, 4, 1
+  %desc5 = insertvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %desc4, i64 1, 4, 1
   ret { ptr, ptr, i64, [2 x i64], [2 x i64] } %desc5
 }
 
 !llvm.module.flags = !{!0}
 
 !0 = !{i32 2, !"Debug Info Version", i32 3}

The new code drops the check for dimensions of size one.
That's interesting because the old code was trying to find a "better" stride for the collapsed dimensions, but I believe this is not generally correct, and if I understand the "specification" of collapse_shape correctly, this shouldn't be required at all:

Collapsing non-contiguous dimensions is undefined behavior.

Put differently, if we collapse dimensions that are not contiguous then, unless I am missing something, the collapsed stride would have to skip over gaps within the same dimension, which we don't support right now. Hence, I don't think the old code was doing the right thing here, but it's late, it's Friday, and I may not be thinking clearly anymore :).

That code with the check for stride == 1 comes from https://reviews.llvm.org/D124001.

@cathyzhyi, @springerm, what do you think?
CC'ing @ftynse and @nicolasvasilache for a more precise, "letter of the law" reading of the collapse_shape semantics :).

Collapsing non-contiguous dimensions is undefined behavior.

Non-contiguous dimensions of size 1 can be collapsed. But for dims of size 1 it doesn't really make sense to have a stride in the first place; so such a stride could be simplified, e.g., to just 1. Is that what's happening here?

(Collapsing non-contiguous dimensions of size >1 should crash at runtime, but we don't generate the assert at the moment.)
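As a rough sketch of what such a runtime check could verify (a hypothetical helper, not something MLIR emits today), assuming all sizes are at least 1 and ignoring the offset: within one reassociation group, skip the size-1 dims and require each remaining outer dim's stride to equal the next remaining inner dim's stride times that inner dim's size.

```python
from typing import Sequence

def group_is_collapsible(
    sizes: Sequence[int], strides: Sequence[int], group: Sequence[int]
) -> bool:
    """Hypothetical contiguity check for one reassociation group (sizes >= 1).

    Size-1 dims are skipped: their stride never contributes to an address.
    """
    non_unit = [d for d in group if sizes[d] != 1]
    # Each remaining outer dim must step over exactly one full inner dim.
    for outer, inner in zip(non_unit, non_unit[1:]):
        if strides[outer] != strides[inner] * sizes[inner]:
            return False
    return True

print(group_is_collapsible([3, 1], [8, 1], [0, 1]))   # True: the size-1 dim is ignored
print(group_is_collapsible([4, 5], [10, 1], [0, 1]))  # False: rows are 10 apart, not 5
```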

@springerm thanks for your answer.

But for dims of size 1 it doesn't really make sense to have a stride in the first place; so such a stride could be simplified, e.g., to just 1. Is that what's happening here?

With the new code, when we collapse dimensions, we take the stride of the innermost dimension as the stride of the whole collapsed dimension, since the underlying memory is supposed to be contiguous.
e.g.,

collapse_shape <5x?x?x?x?xi16, strided<[?, ?, ?, ?, ?]>>, [[0, 1], [2, 3, 4]] -> <?x?xi16>

=>

dim(0) == 5 x orig_shape.dim(1)
dim(1) == orig_shape.dim(2) x  orig_shape.dim(3) x orig_shape.dim(4)
----
stride(0) == orig_shape.stride(1)
stride(1) == orig_shape.stride(4)
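A minimal Python sketch of this rule, assuming concrete (runtime) sizes and strides are available; collapse_innermost is a hypothetical name, not the MLIR implementation. The example instantiates the ? sizes of the collapse_shape above as 5x2x3x4x6 with a row-major layout.

```python
from math import prod
from typing import List, Sequence, Tuple

def collapse_innermost(
    sizes: Sequence[int],
    strides: Sequence[int],
    reassociation: Sequence[Sequence[int]],
) -> Tuple[List[int], List[int]]:
    """Collapse each reassociation group into one dim.

    Size: product of the group's sizes.
    Stride: stride of the group's innermost (last) dim.
    """
    new_sizes = [prod(sizes[d] for d in group) for group in reassociation]
    new_strides = [strides[group[-1]] for group in reassociation]
    return new_sizes, new_strides

# 5x?x?x?x? with the ? instantiated as 2, 3, 4, 6 and a row-major layout.
print(collapse_innermost([5, 2, 3, 4, 6], [144, 72, 24, 6, 1], [[0, 1], [2, 3, 4]]))
# ([10, 72], [72, 1])
```

For a contiguous row-major input this gives the expected row-major result; the rule only diverges from the old behavior when an innermost dim of a group has size 1.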

Now, this doesn't match what we were doing before this patch, where in codegen we would explicitly skip dimensions of size 1.
I.e., the old version would generate:

dim(0) == 5 x orig_shape.dim(1)
dim(1) == orig_shape.dim(2) x  orig_shape.dim(3) x orig_shape.dim(4)
----
stride(0) == orig_shape.dim(1) != 1? orig_shape.stride(1) : orig_shape.stride(0)
stride(1) == orig_shape.dim(4) != 1?
    orig_shape.stride(4) :
    (orig_shape.dim(3) != 1?
        orig_shape.stride(3):
        orig_shape.stride(2))
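The old selection logic above can be sketched in Python as follows (hypothetical helper name, for illustration only): walk the group from the innermost dim outwards, take the stride of the first dim whose size is not 1, and fall back to the outermost dim of the group. The example uses the strided<[?, 4, 1]> layout from collapse_shape_dynamic_with_non_identity_layout, with dim(2) picked as 4 or 1.

```python
from typing import Sequence

def collapsed_stride_skip_unit(
    sizes: Sequence[int], strides: Sequence[int], group: Sequence[int]
) -> int:
    """Old behavior as sketched above: ignore innermost size-1 dims in the group."""
    # Check dims from innermost to outermost, excluding the outermost dim itself.
    for d in reversed(list(group[1:])):
        if sizes[d] != 1:
            return strides[d]
    # Every inner dim has size 1: fall back to the outermost dim of the group.
    return strides[group[0]]

# memref<4x?x?xf32, strided<[?, 4, 1]>> collapsed with [[0], [1, 2]].
print(collapsed_stride_skip_unit([4, 5, 4], [20, 4, 1], [1, 2]))  # 1
print(collapsed_stride_skip_unit([4, 5, 1], [20, 4, 1], [1, 2]))  # 4
```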

I understand that strides of dimensions of size 1 don't really make sense, but I found that the old generated code papers over something that is underspecified. Put differently, having to ignore the strides of size-1 dimensions actively harms codegen, and I was wondering whether that is intended.

On my first reading of the spec, I expected strides to be contiguous even for size-1 dimensions.

To take a concrete example from @collapse_shape_dynamic_with_non_identity_layout:

%0 = memref.collapse_shape %arg0 [[0], [1, 2]]:
  memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> into
  memref<4x?xf32, strided<[?, ?], offset: ?>>

Here I was expecting that it is okay to infer that the innermost stride (stride(1)) is 1, since we collapse dimensions with strides 4 and 1 respectively. And if that's not true, we are in undefined-behavior territory.
However, with what you're saying, the stride could be either 1 or 4 depending on whether orig_shape.dim(2) == 1 or not.

Cheers,
-Quentin

A problematic case would be:

memref<3x1xi16, strided<[8, 1]>>

This memref can be collapsed to memref<3xi16, strided<[8]>>. Using stride 1 (as your new code computes) would be incorrect here; we have to skip the stride of the dim of size 1.
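To check this counterexample mechanically, here is a small brute-force sketch (a hypothetical helper, not part of MLIR) that enumerates every index of the original memref, ignores the offset, and compares the original address against the address computed from a candidate stride for the fully collapsed 1-D memref.

```python
import itertools
from typing import Sequence

def collapse_all_is_valid(sizes: Sequence[int], strides: Sequence[int], candidate: int) -> bool:
    """Brute force: does collapsing *all* dims into one dim with stride `candidate`
    address exactly the same elements as the original layout? (Offset ignored.)"""
    for index in itertools.product(*(range(s) for s in sizes)):
        # Row-major linearization: the position of `index` in the collapsed dim.
        linear = 0
        for i, s in zip(index, sizes):
            linear = linear * s + i
        original_addr = sum(i * st for i, st in zip(index, strides))
        if original_addr != linear * candidate:
            return False
    return True

# memref<3x1xi16, strided<[8, 1]>> collapsed to a 1-D memref of size 3.
print(collapse_all_is_valid([3, 1], [8, 1], 8))  # True: skips the size-1 dim
print(collapse_all_is_valid([3, 1], [8, 1], 1))  # False: innermost stride
```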

Note, the code that you are looking at is for the "dynamic" case, where strides etc. must be computed at runtime based on data in the memref descriptor. We also have a "static" case, where we infer/verify the static layout map in case there are no ? strides and/or dims. (If there are a few ?, only part of the layout map can be inferred/verified.) This code path mirrors the code path that you are looking at and is probably easier to understand and experiment with than the dynamic case. But it implements (or at least should implement) the same logic.

In particular, there is this comment in computeCollapsedLayoutMap (MemRefOps.cpp):

The result stride of a reassociation group is the stride of the last entry
of the reassociation. (...) Dimensions of size 1 should be skipped, because
their strides are meaningless and could have any arbitrary value.

I agree that many things related to strides, layout maps, etc. are poorly documented, maybe even underspecified. I fixed multiple bugs in expand_shape/collapse_shape due to this about a year ago. It is probably still not as good as it could be, so any improvements are appreciated!

qcolombet retitled this revision from [mlir][MemRefToLLVM] Reuse existing lowering for collapse/expand_shape to [mlir][MemRefToLLVM] Remove the code for lowering collapse/expand_shape.
qcolombet edited the summary of this revision.
  • Rebase
  • Move the expand/collapse_shape tests to expand-then-convert-to-llvm.mlir, since they now require running the expansion pass beforehand
  • Update PR description
  • Run clang-format

Now that we have confirmed that memref.reinterpret_cast does what we need (see https://github.com/llvm/llvm-project/issues/59896), I feel comfortable moving forward with this patch again.

@ftynse what do you think?

ftynse accepted this revision.Fri, Jan 20, 2:17 AM