Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[better_errors] Add debug info to more Jaxprs and Wrappedfun (step 1) #26078

Merged
merged 1 commit into from
Feb 4, 2025

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Jan 24, 2025

The plan is for all core.Jaxpr and lu.WrappedFun to carry
non-None debug info.

We change lu.wrap_init to construct the result paths thunk
whenever it is passed a debug_info. The goal is to make sure that
all WrappedFun have a debug info with result paths support.

We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a WrappedFun or
a Jaxpr.

We obtain several improvements in presence of debug infos
in debug_info_test.py

@gnecula gnecula self-assigned this Jan 24, 2025
@gnecula gnecula added the pull ready Ready for copybara import and testing label Jan 24, 2025
@gnecula gnecula marked this pull request as draft January 25, 2025 14:57
@gnecula gnecula force-pushed the debug_info_jaxpr branch 2 times, most recently from 9b7e598 to cfd34b2 Compare January 26, 2025 09:09
@gnecula gnecula changed the title [better_errors] Add debug info to more Jaxprs [better_errors] Add debug info to more Jaxprs and Wrappedfun Jan 26, 2025
@gnecula gnecula changed the title [better_errors] Add debug info to more Jaxprs and Wrappedfun [better_errors] Add debug info to more Jaxprs and Wrappedfun (step 1) Feb 3, 2025
@gnecula gnecula marked this pull request as ready for review February 3, 2025 22:21
@gnecula gnecula requested a review from dfm February 3, 2025 22:26
@gnecula gnecula force-pushed the debug_info_jaxpr branch 2 times, most recently from e489625 to 7f00c56 Compare February 4, 2025 07:50
The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry
non-None debug info.

We change `lu.wrap_init` to construct the result paths thunk
whenever it is passed a `debug_info`. The goal is to make sure that
all `WrappedFun` have a debug info with result paths support.

We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a `WrappedFun` or
a `Jaxpr`.

We obtain several improvements in presence of debug infos
in debug_info_test.py
@copybara-service copybara-service bot merged commit 414449e into jax-ml:main Feb 4, 2025
22 of 23 checks passed
@gnecula gnecula deleted the debug_info_jaxpr branch February 4, 2025 18:57
gnecula added a commit to gnecula/jax that referenced this pull request Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 5, 2025
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These changes ensure that all the lu.wrap_init and Jaxpr are called with debug_info in the api_test.py:CustomTransposeTest,
api_test.py:CustomVmapTest and api_test.py:RematTest.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, ...
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, ...
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, tests.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, tests.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 8, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, tests.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 9, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 9, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

Here I changed the `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).

Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 11, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 12, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 12, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 12, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula added a commit to gnecula/jax that referenced this pull request Feb 12, 2025
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant