This is an archive of the discontinued LLVM Phabricator instance.

[mlir] fix crash in PybindAdaptors.h
ClosedPublic

Authored by ftynse on Jan 14 2022, 8:26 AM.

Details

Summary

The constructor function was being defined without indicating its "__init__"
name, which made Pybind treat it as a regular function rather than a
constructor. When overload resolution failed, Pybind would attempt to print the
arguments actually passed to the function, including "self", which is not
initialized since the constructor couldn't be called. This would result in
"__repr__" being called with "self" referencing an uninitialized MLIR C API
object, which in turn would cause undefined behavior when attempting to print
in C++.

Fix this by specifying the correct name.
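A pure-Python analogue of this failure mode, for illustration only: the class `Attr` and the helper `invoked_with` are hypothetical, and the message is only loosely modeled on the "Invoked with:" text that pybind11 appends to overload-resolution errors. In Python, reading state from an object whose `__init__` never ran fails cleanly with `AttributeError`; in the C++ bindings the same read touches an uninitialized MLIR C API object, which is undefined behavior.

```python
class Attr:
    """Hypothetical wrapper whose __repr__ depends on state set in __init__."""

    def __init__(self, handle):
        self._handle = handle

    def __repr__(self):
        return f"Attr({self._handle!r})"


def invoked_with(*args):
    """Format call arguments for an error message by calling repr() on each
    one, including self, as an overload-resolution error report might."""
    return "Invoked with: " + ", ".join(repr(a) for a in args)


# Normal case: self was fully initialized, so formatting succeeds.
message = invoked_with(Attr(42), "bad-arg")

# Failure case: self was allocated but __init__ never ran, so __repr__
# reads state that does not exist.
uninitialized = Attr.__new__(Attr)
failure = None
try:
    invoked_with(uninitialized, "bad-arg")
except AttributeError as err:
    failure = err
```

The error path is the dangerous part: it runs precisely when construction did not happen, yet it still formats "self".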

This in turn uncovers the fact that the mechanism used by PybindAdaptors.h to
bind constructors directly as "__init__" functions taking "self" is deprecated
by Pybind. Instead, leverage the fact that the adaptors are intended for
attributes/types that cannot have additional data members and are all ultimately
instances of the "PyAttribute"/"PyType" C++ classes. In constructors of derived
classes, construct an instance of the base class first, then steal its internal
pointer to the C++ object to construct the instance of the derived class.

On top of that, the function was incorrectly declared as a method of the "None"
object instead of a method of its parent class. This caused a second problem
when Pybind attempted to print warnings pointing to the parent class, since
"None" does not have a "__name__" attribute or its C API equivalent.

Fix this by specifying the correct parent class by looking it up by name in the
parent module.
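A pure-Python analogue of why the scope matters, as a sketch only: `qualified_name`, `parent_module`, and `MyAttr` are hypothetical and do not reproduce pybind11's actual warning code. Any diagnostic that labels a method by its parent's `__name__` fails when the recorded parent is `None`, and succeeds once the real class is looked up by name in its module.

```python
import types


def qualified_name(scope, method_name):
    """Build a "Scope.method" label from the scope's __name__, the way a
    diagnostic that refers to a method's parent class might."""
    return f"{scope.__name__}.{method_name}"


# Hypothetical parent module holding the class that owns the method.
parent_module = types.ModuleType("parent_module")


class MyAttr:  # hypothetical attribute subclass
    pass


parent_module.MyAttr = MyAttr

# Buggy scope: None has no __name__, so building the label fails.
try:
    qualified_name(None, "__init__")
    none_scope_failed = False
except AttributeError:
    none_scope_failed = True

# Fixed scope: look the parent class up by name in its parent module.
scope = getattr(parent_module, "MyAttr")
label = qualified_name(scope, "__init__")
```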


Event Timeline

ftynse created this revision (Jan 14 2022, 8:26 AM).
ftynse requested review of this revision (Jan 14 2022, 8:26 AM).
stellaraccident accepted this revision (Jan 14 2022, 8:46 AM).

This is gross. I'm sorry it caused you some hard debugging, but thank you.

This use case may be worth an issue/discussion with the pybind folks. I don't have the details anymore, but I didn't just make this up: I came across the approach as a suggestion from core pybind devs. Once we land the fix, let's file an issue with them, describe what we are trying to do, and get advice.

This revision is now accepted and ready to land (Jan 14 2022, 8:46 AM).
ftynse edited the summary of this revision (Jan 17 2022, 1:27 AM).
This revision was automatically updated to reflect the committed changes.
ftynse reopened this revision (Jan 17 2022, 5:47 AM).

This didn't work out in some variants. Let's try a different mechanism.

This revision is now accepted and ready to land (Jan 17 2022, 5:47 AM).
ftynse updated this revision to Diff 400501 (Jan 17 2022, 5:49 AM).

Use pybind11 details to implement superclass initialization

stellaraccident accepted this revision (Jan 17 2022, 7:07 AM).

Ugh, ugh. Let's definitely ask the pybind folks after landing. I'm also wondering whether this part should just be implemented with the Python C API. The concept is simpler than the code.

ftynse edited the summary of this revision (Jan 18 2022, 1:20 AM).
This revision was automatically updated to reflect the committed changes.

I've started seeing a nondeterministic crash in the JAX test suite that I bisected to this revision.

Here's a typical failure from our CI:
https://source.cloud.google.com/results/invocations/914b1d68-3b59-401c-8417-4080cc48cc35/targets/jax%2Ftesting%2Fcpu%2Fpresubmit_github/log

Among the Python source lines where pytest reports the crash, one is a call to .verify() and another is a call to .print(). I don't have a more reduced reproduction yet, and it will be hard to get one: the crash is rare, and I only see it when running a large test suite, never when running any of the tests in isolation. My guess is that there is some sort of memory corruption problem.

Can we revert this in the meantime?


I'm fine reverting on such evidence.

My suspicion is a refcount/object lifetime issue, but I can't quite see the mechanism. JAX adds a lot of GC pressure and GIL context switches that don't show up in isolation, and this would not be the first time it has triggered such a nondeterministic failure due to a latent refcount bug. When I have reduced these in the past, I've done it by adding tests with explicit reference release and gc (look for "gc" in the tests to see examples).

This patch intersects with what JAX does in only a few places: building custom MHLO attributes (which, IIRC, is only done for convolution at the moment). I would isolate those code sequences and put them in a unit test that uses explicit GC in anger to fuzz them.
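A minimal sketch of the kind of test suggested above, written in pure Python so it runs on its own. `gc_stress`, `Context`, `GoodAttr`, `LeakyAttr`, and the `make_*` factories are hypothetical stand-ins, not MLIR or JAX APIs, and the harness assumes the owner object supports weak references. In a real reduction, `make` would build one of the custom MHLO attributes through the Python bindings, and `use` would exercise it the way the failing tests do (for example via `.verify()` or `.print()`).

```python
import gc
import weakref


def gc_stress(make, use, iterations=100):
    """Repeatedly build an object, release the explicit reference to its
    owner, force a full collection, and only then use the object.

    `make` returns (owner, obj), where `obj` is expected to keep `owner`
    alive by itself. Raises AssertionError if the owner is freed early.
    Returns the number of iterations completed.
    """
    for _ in range(iterations):
        owner, obj = make()
        owner_ref = weakref.ref(owner)
        del owner       # drop the explicit reference to the owner
        gc.collect()    # collect now rather than at some arbitrary later point
        if owner_ref() is None:
            raise AssertionError("owner freed while obj still needs it")
        use(obj)        # exercise the object after the collection
        del obj
        gc.collect()
    return iterations


class Context:
    """Hypothetical stand-in for an owning context."""


class GoodAttr:
    """Hypothetical attribute that holds a strong reference to its context."""

    def __init__(self, context):
        self.context = context

    def __repr__(self):
        return "GoodAttr()"


class LeakyAttr:
    """Hypothetical attribute with a lifetime bug: it holds only a weak
    reference, so its context can be freed while the attribute is alive."""

    def __init__(self, context):
        self.context = weakref.ref(context)

    def __repr__(self):
        return "LeakyAttr()"


def make_good():
    ctx = Context()
    return ctx, GoodAttr(ctx)


def make_leaky():
    ctx = Context()
    return ctx, LeakyAttr(ctx)
```

With these stand-ins, `gc_stress(make_good, repr)` completes, while `gc_stress(make_leaky, repr)` raises on its first iteration. Forcing `gc.collect()` at fixed points is what turns a lifetime bug that only shows up under JAX's GC pressure into one that fails the same way on every run.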