-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[better_errors] Add debug info to more Jaxprs and Wrappedfun (step 1) #26078
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3f3eeaf
to
02185e7
Compare
02185e7
to
177548e
Compare
9b7e598
to
cfd34b2
Compare
cfd34b2
to
baa6aaa
Compare
baa6aaa
to
08b66f2
Compare
08b66f2
to
da44d6b
Compare
da44d6b
to
8457796
Compare
e489625
to
7f00c56
Compare
The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry non-None debug info. We change `lu.wrap_init` to construct the result paths thunk whenever it is passed a `debug_info`. The goal is to make sure that all `WrappedFun` have a debug info with result paths support. We change some calling conventions for internal functions to not pass along a separate debug_info if we have a `WrappedFun` or a `Jaxpr`. We obtain several improvements in presence of debug infos in debug_info_test.py
7f00c56
to
d12aead
Compare
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest, api_test.py:CustomVmapTest and api_test.py:RematTest.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info).
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info).
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, and jax2tf.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, and jax2tf.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, and jax2tf.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, and jax2tf.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, ...
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, ...
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, tests.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, key_reuse, ode, tests.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, key_reuse, ode, tests.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 9, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 9, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed the `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These are some leftover changes, in particular those needed when running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These are some leftover changes, in particular those needed when running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These are some leftover changes, in particular those needed when running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 12, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These are some leftover changes, in particular those needed when running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 12, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These are some leftover changes, in particular those needed when running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 12, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These are some leftover changes, in particular those needed when running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula
added a commit
to gnecula/jax
that referenced
this pull request
Feb 12, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. These are some leftover changes, in particular those needed when running with `JAX_USE_DIRECT_LINEARIZE=1`.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The plan is for all
core.Jaxpr
andlu.WrappedFun
to carrynon-
None
debug info.We change
lu.wrap_init
to construct the result paths thunkwhenever it is passed a
debug_info
. The goal is to make sure thatall
WrappedFun
have a debug info with result paths support.We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a
WrappedFun
ora
Jaxpr
.We obtain several improvements in presence of debug infos
in debug_info_test.py