This is an archive of the discontinued LLVM Phabricator instance.

[mlir][bufferization] Prevent crash in one shot bufferization with unranked tensor cast
Closed · Public

Authored by Lewuathe on Apr 25 2023, 11:38 PM.

Details

Summary

One-Shot Bufferize does not support bufferizing a cast between unranked tensors, and it crashes on such a cast instead of reporting an error. To prevent the crash, we can check in advance whether the result type is compatible. Reported in https://github.com/llvm/llvm-project/issues/62369.


Event Timeline

Lewuathe created this revision. Apr 25 2023, 11:38 PM
Herald added a project: Restricted Project.
Lewuathe requested review of this revision. Apr 25 2023, 11:38 PM
Lewuathe updated this revision to Diff 517073. Apr 25 2023, 11:39 PM

Add newline.

springerm added inline comments. Apr 26 2023, 12:01 AM
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
97–101

This can be bufferized. The crash is caused by the fact that unranked->unranked casting is valid on tensors but, for some reason, not on memrefs. (I think we should handle it the same way on memrefs and tensors.) In that case, we can directly call replaceOpWithBufferizedValues(rewriter, op, *resultBuffer).

Lewuathe added a comment (edited). Apr 26 2023, 11:42 PM

@springerm Thanks for the advice.
Lowering a memref.cast op whose input and output are both unranked memrefs to another dialect (e.g. LLVM) is not supported for now.

./bin/mlir-opt cast.mlir
cast.mlir:2:8: error: 'memref.cast' op operand type 'memref<4x?xi32>' and result type 'memref<4x?xf32>' are cast incompatible
  %0 = memref.cast %arg : memref<4x?xi32> to memref<4x?xf32>
       ^
cast.mlir:2:8: note: see current operation: %0 = "memref.cast"(%arg0) : (memref<4x?xi32>) -> memref<4x?xf32>

We could support bufferizing an unranked-to-unranked cast somehow, but it may cause runtime errors later in the lowering pipeline. Casting an unranked tensor to an unranked tensor seems to just pass the problem on to the bufferized unranked memref, so I did not find a good approach from that direction.

So should we make the pass return a failure instead of hitting the assertion?

springerm added a comment.

> Lowering memref.cast op with unranked memref for input and output to another dialect (e.g. LLVM) is not allowed for now.

Really? I thought the reason why we have unranked memrefs is so that we can call external functions. How does this work if we cannot lower it to LLVM?

> Casting unranked tensor to unranked tensor looks like just passing the responsibility to the bufferization with unranked memref so I did not find any good hint with that.

Casting unranked -> unranked is a no-op. Such casts should fold away. This could be done during bufferization to work around the inconsistency that unranked->unranked is allowed on tensors but not on memrefs:

if (resultBuffer->getType() == *resultMemRefType) {
  // This cast is a no-op.
  replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
  return success();
}
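To make the suggested no-op check concrete, here is a minimal, self-contained C++ sketch of the decision a tensor.cast bufferization pattern could make. It deliberately does not use the MLIR API: BufferType, areCastCompatible, bufferizeCast, Outcome and kDynamic are hypothetical stand-ins, and the compatibility check is only a simplified model of the memref.cast rules discussed in this review (element types must match, ranks and static sizes must agree, ranked <-> unranked is allowed, unranked -> unranked is rejected). Layouts and memory spaces are ignored.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical sentinel for a dynamic ('?') dimension in this model.
constexpr std::int64_t kDynamic = -1;

// Hypothetical, simplified stand-in for a memref type (not the MLIR API).
// Layout maps and memory spaces are deliberately left out.
struct BufferType {
  bool ranked;                      // false models memref<*x...>
  std::vector<std::int64_t> shape;  // empty when unranked
  std::string elementType;          // e.g. "f32", "i32"

  bool operator==(const BufferType &other) const {
    return ranked == other.ranked && shape == other.shape &&
           elementType == other.elementType;
  }
};

// Simplified model of the memref.cast compatibility rules.
bool areCastCompatible(const BufferType &from, const BufferType &to) {
  if (from.elementType != to.elementType)
    return false;  // e.g. memref<4x?xi32> -> memref<4x?xf32> is rejected
  if (!from.ranked && !to.ranked)
    return false;  // unranked -> unranked is not a valid memref.cast
  if (from.ranked && to.ranked) {
    if (from.shape.size() != to.shape.size())
      return false;  // ranks must agree
    for (std::size_t i = 0; i < from.shape.size(); ++i) {
      bool eitherDynamic = from.shape[i] == kDynamic || to.shape[i] == kDynamic;
      if (!eitherDynamic && from.shape[i] != to.shape[i])
        return false;  // conflicting static sizes
    }
  }
  return true;  // ranked <-> unranked, or compatible ranked shapes
}

// What a hypothetical tensor.cast bufferization would do with the buffer.
enum class Outcome { ReuseSourceBuffer, EmitMemRefCast, Fail };

Outcome bufferizeCast(const BufferType &sourceBuffer,
                      const BufferType &resultType) {
  // No-op cast (e.g. unranked -> unranked with the same element type):
  // fold it away by reusing the source buffer, so no memref.cast is built.
  if (sourceBuffer == resultType)
    return Outcome::ReuseSourceBuffer;
  if (areCastCompatible(sourceBuffer, resultType))
    return Outcome::EmitMemRefCast;
  // Report a failure instead of building an invalid memref.cast.
  return Outcome::Fail;
}
```

In this model a same-type unranked cast returns before areCastCompatible is ever consulted, which is the point of the review suggestion: checking for type equality first sidesteps the tensor/memref inconsistency without changing what memref.cast accepts.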

Lewuathe added a comment.

@springerm

> Casting unranked -> unranked is a no-op. Such casts should fold away. This could be done during bufferization to work around the inconsistency that unranked->unranked is allowed on tensors but not on memrefs:

Thanks for the advice. I think that should work. I'll try to eliminate the cast as a no-op.

Lewuathe updated this revision to Diff 518119. Apr 28 2023, 7:18 PM

Canonicalize unranked-to-unranked casts as no-ops.

@springerm Treating the cast as a no-op works, as you said. Could you review this change when you get a chance?

@springerm Sorry to bother you again, but could you review this change when you get a chance?

springerm accepted this revision. May 18 2023, 3:03 AM

Sorry for the delay, I lost track of some reviews.

This revision is now accepted and ready to land. May 18 2023, 3:03 AM