This is an archive of the discontinued LLVM Phabricator instance.

[mlir][MemRef] Make `getDroppedDims` on `memref.subview` account for preserved unit dimensions.
Needs Review · Public

Authored by mravishankar on Aug 28 2023, 11:16 AM.

Details

Summary

When computing dropped dimensions, if any unit-dimension is preserved,
then that should not be accounted for in the dropped dimensions.

A side effect of this is that the choice of dropped dimensions is
sometimes ambiguous. The convention seems to be that the outermost
dimensions are preferred for dropping.

Fixes #60091
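
To make the ambiguity concrete, here is a simplified Python model of the occurrence-counting idea discussed in this review. It is a sketch, not the actual MLIR implementation: the function name `infer_dropped_dims`, the `DYNAMIC` placeholder for a `?` stride, and the list-based representation of sizes and strides are all assumptions made for illustration.

```python
from collections import Counter

DYNAMIC = None  # hypothetical stand-in for a '?' stride in this sketch


def infer_dropped_dims(sizes, src_strides, result_strides):
    """Return the set of source dims treated as dropped, or None on failure."""
    rank_reduction = len(src_strides) - len(result_strides)
    # Only unit-sized dims are candidates for dropping.
    candidates = [d for d, s in enumerate(sizes) if s == 1]
    if len(candidates) == rank_reduction:
        return set(candidates)

    unaccounted = Counter(src_strides)  # stride -> occurrences not yet explained
    kept = Counter(result_strides)      # stride -> occurrences in the result type
    dropped = set()
    # Candidates are visited outermost first, so ties resolve toward outer dims.
    for d in candidates:
        stride = src_strides[d]
        if unaccounted[stride] > kept[stride]:
            # This stride occurs more often in the source: treat the dim as dropped.
            unaccounted[stride] -= 1
            dropped.add(d)
        elif unaccounted[stride] < kept[stride]:
            # The result type has a stride the source cannot provide.
            return None
        # Equal counts: the stride survives, so this unit dim is preserved.
    return dropped if len(dropped) == rank_reduction else None


# Two unit dims share stride 4; either could be the dropped one.
print(infer_dropped_dims([1, 1, 4], [4, 4, 1], [4, 1]))  # {0}
```

In this model the outermost candidate wins only because of the iteration order, which is the kind of implicit convention the summary says should be recorded somewhere.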

Diff Detail

Event Timeline

mravishankar created this revision. Aug 28 2023, 11:16 AM
Herald added a project: Restricted Project. Aug 28 2023, 11:16 AM
mravishankar requested review of this revision. Aug 28 2023, 11:16 AM

@springerm This fixes https://github.com/llvm/llvm-project/issues/60091 , but it seems like we have some ambiguity in which dimensions get dropped when some unit-dimensions are preserved. From tests it seems like the preference is to have the outer dimensions dropped. This should be recorded somewhere. Suggestions?

Apart from some not-so-great lit tests, this change seems fine. https://github.com/openxla/iree/pull/14851 tests this in IREE, with no errors.

springerm added a comment. Edited · Aug 29 2023, 8:27 AM

> @springerm This fixes https://github.com/llvm/llvm-project/issues/60091 , but it seems like we have some ambiguity in which dimensions get dropped when some unit-dimensions are preserved. From tests it seems like the preference is to have the outer dimensions dropped. This should be recorded somewhere. Suggestions?

It seems to me that it is impossible to implement this function correctly because we don't know which dimension corresponds to which stride. What we have right now is best effort but we found cases where we can't be sure.

An even bigger concern to me are dynamic strides. We should do the exact same thing regardless of whether something is static or dynamic. The only difference is whether it can be done at compile time or at run time. That would mean that the dropped dims cannot be determined at compile time, and this function would need to bail out in case of dynamic strides (which is likely a big problem for existing transformations). Either that, or maintain a (non-dynamic) mapping from dims to strides in the MemRefType, which would replace the entire getNumOccurences-based logic. Then there would also be no more ambiguities w.r.t. unit strides.

> @springerm This fixes https://github.com/llvm/llvm-project/issues/60091 , but it seems like we have some ambiguity in which dimensions get dropped when some unit-dimensions are preserved. From tests it seems like the preference is to have the outer dimensions dropped. This should be recorded somewhere. Suggestions?

> It seems to me that it is impossible to implement this function correctly because we don't know which dimension corresponds to which stride. What we have right now is best effort but we found cases where we can't be sure.

We need to change the definition of the operations that allow rank reduction so that they explicitly take the dropped dimensions. Without that, this is going to be hard.

> An even bigger concern to me are dynamic strides. We should do the exact same thing regardless of whether something is static or dynamic. The only difference is whether it can be done at compile time or at run time. That would mean that the dropped dims cannot be determined at compile time, and this function would need to bail out in case of dynamic strides (which is likely a big problem for existing transformations). Either that, or maintain a (non-dynamic) mapping from dims to strides in the MemRefType, which would replace the entire getNumOccurences-based logic. Then there would also be no more ambiguities w.r.t. unit strides.

It's hard to make this argument in the abstract. The issue really comes down to this: if you drop dimensions, then the corresponding static strides need to be dropped as well to be consistent, which is what the logic here is for. If it is dynamic, it actually doesn't matter. It's all ?s anyway. Whichever you drop, it is statically consistent, and things get adjusted automatically at runtime. In any case, that is an orthogonal issue. This change seems fine AFAICS.
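
The static-consistency argument can be illustrated with a small brute-force check. This is only a sketch under assumed names (`consistent_drop_choices`, `DYNAMIC`), not code from MLIR: it enumerates every way of dropping unit dims and keeps the choices whose remaining strides match the declared result type.

```python
from itertools import combinations

DYNAMIC = None  # hypothetical stand-in for a '?' stride in this sketch


def consistent_drop_choices(sizes, src_strides, result_strides):
    """List every set of unit dims whose removal reproduces result_strides."""
    n_drop = len(src_strides) - len(result_strides)
    unit_dims = [d for d, s in enumerate(sizes) if s == 1]
    choices = []
    for dropped in combinations(unit_dims, n_drop):
        # Strides that remain once the chosen dims are removed, in order.
        remaining = [st for d, st in enumerate(src_strides) if d not in dropped]
        if remaining == list(result_strides):
            choices.append(dropped)
    return choices


# Static strides pin down a single choice.
print(consistent_drop_choices([1, 1, 1], [5, 6, 7], [5, 6]))  # [(2,)]
# With all-dynamic strides, every choice matches the static type.
print(consistent_drop_choices([1, 1, 1], [DYNAMIC] * 3, [DYNAMIC] * 2))
# [(0,), (1,), (2,)]
```

The check only compares static types. Whether the strides picked at runtime match what the user intended is a separate question, which is the point still under discussion in this thread.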

springerm added a comment. Edited · Aug 30 2023, 2:58 AM

I found a test case that breaks with this change (and passed before):

// RUN: mlir-opt %s -generate-runtime-verification -expand-strided-metadata

func.func @static_case(%arg0: memref<?x?x?xf32, strided<[5, 6, 7], offset: ?>>, %arg1: index, %idx: index) -> f32 {
  // Force the last dim to be rank-reduced by picking the first and second stride.
  %s = memref.subview %arg0[0, 0, 0] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32, strided<[5, 6, 7], offset: ?>> to memref<1x1xf32, strided<[5, 6], offset: ?>>
  %l = memref.load %s[%idx, %idx] : memref<1x1xf32, strided<[5, 6], offset: ?>>
  return %l : f32
}

It fails in the verifier:

/usr/local/google/_blaze_springerm/9abb62e345f5287adac38e0018bb89c9/execroot/google3/blaze-out/k8-dbg/bin/third_party/llvm/llvm-project/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir.test.runfiles/google3/third_party/llvm/llvm-project/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir:37:8: error: expected result type with stride = 6 instead of 5 in dim = 0
  %s = memref.subview %arg0[0, 0, 0] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32, strided<[5, 6, 7], offset: ?>> to memref<1x1xf32, strided<[5, 6], offset: ?>>
       ^
/usr/local/google/_blaze_springerm/9abb62e345f5287adac38e0018bb89c9/execroot/google3/blaze-out/k8-dbg/bin/third_party/llvm/llvm-project/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir.test.runfiles/google3/third_party/llvm/llvm-project/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir:37:8: note: see current operation: %1 = "memref.reinterpret_cast"(%0#0, %0#1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1, 1>, static_strides = array<i64: 6, 7>}> : (memref<f32>, index) -> memref<1x1xf32, strided<[5, 6], offset: ?>>
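
The log is consistent with dim 0 having been treated as the dropped dimension: the generated reinterpret_cast carries static_strides 6, 7, while the declared result type has strides 5, 6. A toy Python check that reproduces the numbers in the error message is sketched below. The helper name `first_stride_mismatch` and the per-dimension comparison are assumptions for illustration, not the verifier's actual code.

```python
def first_stride_mismatch(src_strides, dropped_dim, declared_strides):
    """Return (dim, expected, declared) for the first mismatch, or None."""
    # Strides the result would have if `dropped_dim` were removed.
    expected = [s for d, s in enumerate(src_strides) if d != dropped_dim]
    for dim, (exp, got) in enumerate(zip(expected, declared_strides)):
        if exp != got:
            return dim, exp, got
    return None


# Dropping dim 0 leaves strides [6, 7], which disagree with [5, 6] in dim 0.
print(first_stride_mismatch([5, 6, 7], 0, [5, 6]))  # (0, 6, 5)
# Dropping dim 2 leaves strides [5, 6], which agree with the result type.
print(first_stride_mismatch([5, 6, 7], 2, [5, 6]))  # None
```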

I found this test case while trying to construct two identical test cases whose only difference is whether the strides are dynamic or static, such that different dims are dropped in each. This would mean that by inserting a static->dynamic stride cast, I can make getDroppedDims, computeMemRefRankReductionMask (and its callers) behave differently and circumvent the getNumOccurences-based logic.


For reference, this is the dynamic case:

func.func @dynamic_case(%arg0: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %arg1: index, %idx: index) -> f32 {
  // Cannot decide rank-reduced dim based on strides, because they are all the same. Will fall back to "first dim is rank-reduced".
  %s = memref.subview %arg0[0, 0, 0] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<1x1xf32, strided<[?, ?], offset: ?>>
  %l = memref.load %s[%idx, %idx] : memref<1x1xf32, strided<[?, ?], offset: ?>>
  return %l : f32
}

Expands to:

func.func @dynamic_case(%arg0: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %arg1: index, %arg2: index) -> f32 {
  %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %arg0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [%offset], sizes: [1, 1], strides: [%strides#1, %strides#2] : memref<f32> to memref<1x1xf32, strided<[?, ?], offset: ?>>
  %0 = memref.load %reinterpret_cast[%arg2, %arg2] : memref<1x1xf32, strided<[?, ?], offset: ?>>
  return %0 : f32
}

(This shows the strides do not get adjusted at runtime. We always pick the leading dimensions, regardless of what the runtime strides are.)

> I found a test case that breaks with this change (and passed before):
>
> // RUN: mlir-opt %s -generate-runtime-verification -expand-strided-metadata
>
> func.func @static_case(%arg0: memref<?x?x?xf32, strided<[5, 6, 7], offset: ?>>, %arg1: index, %idx: index) -> f32 {
>   // Force the last dim to be rank-reduced by picking the first and second stride.
>   %s = memref.subview %arg0[0, 0, 0] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32, strided<[5, 6, 7], offset: ?>> to memref<1x1xf32, strided<[5, 6], offset: ?>>
>   %l = memref.load %s[%idx, %idx] : memref<1x1xf32, strided<[5, 6], offset: ?>>
>   return %l : f32
> }
>
> It fails in the verifier:
>
> /usr/local/google/_blaze_springerm/9abb62e345f5287adac38e0018bb89c9/execroot/google3/blaze-out/k8-dbg/bin/third_party/llvm/llvm-project/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir.test.runfiles/google3/third_party/llvm/llvm-project/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir:37:8: error: expected result type with stride = 6 instead of 5 in dim = 0
>   %s = memref.subview %arg0[0, 0, 0] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32, strided<[5, 6, 7], offset: ?>> to memref<1x1xf32, strided<[5, 6], offset: ?>>
>        ^
> /usr/local/google/_blaze_springerm/9abb62e345f5287adac38e0018bb89c9/execroot/google3/blaze-out/k8-dbg/bin/third_party/llvm/llvm-project/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir.test.runfiles/google3/third_party/llvm/llvm-project/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir:37:8: note: see current operation: %1 = "memref.reinterpret_cast"(%0#0, %0#1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1, 1>, static_strides = array<i64: 6, 7>}> : (memref<f32>, index) -> memref<1x1xf32, strided<[5, 6], offset: ?>>

Good catch. Let me dig into this a bit, but it will be a while.

To reiterate, the only real way to fix all of this is to change the memref.subview operation to explicitly take the list of dropped dimensions (and not rely on auto-magic inference, which IMO can always be tripped up).
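
As a rough illustration of what an explicit list of dropped dimensions would buy, here is a minimal Python sketch. The function `rank_reduce` and its list-based inputs are hypothetical and do not correspond to any existing or proposed MLIR API. With the dropped dims given explicitly, the result sizes and strides follow mechanically, and the same dims are removed whether the strides are static or dynamic.

```python
DYNAMIC = None  # hypothetical stand-in for a '?' stride in this sketch


def rank_reduce(sizes, strides, dropped_dims):
    """Return (sizes, strides) of the result for an explicit dropped-dim list."""
    dropped = set(dropped_dims)
    for d in dropped:
        if not 0 <= d < len(sizes):
            raise ValueError(f"dim {d} is out of range")
        if sizes[d] != 1:
            raise ValueError(f"dim {d} has size {sizes[d]}; only unit dims can be dropped")
    kept = [d for d in range(len(sizes)) if d not in dropped]
    return [sizes[d] for d in kept], [strides[d] for d in kept]


# Static and dynamic strides drop the same dim, with no inference involved.
print(rank_reduce([1, 1, 1], [5, 6, 7], [2]))      # ([1, 1], [5, 6])
print(rank_reduce([1, 1, 1], [DYNAMIC] * 3, [2]))  # ([1, 1], [None, None])
```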