[RFC] `torch_xla.step` context manager #6751

Comments
I think I changed the default behavior in one of the recent PRs. Now trying to access the result of a tensor will actually materialize the tensor. However, with code like the snippet above, it will only materialize the tensor being accessed.
I like this idea in general; curious what the final UX looks like. We don't need to worry about `xla/torch_xla/core/dynamo_bridge.py`, lines 466 to 468 (at 6ac3223).
I like the proposal, especially the idea of raising errors inside an explicit context. When users develop their models using torch_xla, I feel it's very challenging to ignore torch_xla internals; users tend to reach for them directly. To hide the torch_xla internals completely, maybe HF Accelerate/Lightning is the way to go? Also, see `xla/test/spmd/test_train_spmd_imagenet.py`, line 139 (at 9ff0f31).
TL;DR

- Wrap `mark_step` with a context manager/decorator, `torch_xla.step`, to explicitly mark code that should be traced and then executed.
- `mark_step` is always executed upon exit and provides an explicit context to raise errors.

Introduction
Users should be able to pick up and use PyTorch/XLA without having to understand the execution model. As much as possible, we should bury the implementation details such that you can wrap existing PyTorch code and have it "just work".
`xm.mark_step` is the most obvious example of where we directly expose implementation details of how PyTorch/XLA works to the user: we actually require users to decide where to place synchronization barriers. These manual barriers are not required in JAX or TensorFlow, even though both implement a similar lazy execution model.

In practice, our current solution is to hide `xm.mark_step` in our preloading data loader implementation, which calls it 1) after each batch is loaded and 2) at the end of iteration. If a user is sufficiently careful, they never have to see this implementation detail at all. Take the standard training loop sketched below as an example. Training loops like it will run without issue; however, even slight deviations from this pattern can cause inscrutable problems such as the ones below.
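A minimal sketch of that pattern (the model, optimizer, and loss function names here are illustrative, not taken from the original example):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def train_epoch(model, train_loader, optimizer, loss_fn):
    device = xm.xla_device()
    # MpDeviceLoader preloads batches onto the device and calls mark_step
    # after each batch is loaded and at the end of iteration.
    loader = pl.MpDeviceLoader(train_loader, device)
    model.train()
    for data, target in loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
```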
Weird `xm.mark_step` behaviors

Checkpointing before `xm.mark_step`

Take the following common pattern, where the master saves a checkpoint every n steps:
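A sketch of that pattern follows; the checkpoint interval, path, and use of `xm.save` here are assumptions rather than the original snippet. `loader` is an `MpDeviceLoader` as above.

```python
import torch_xla.core.xla_model as xm

def train_epoch(model, loader, optimizer, loss_fn):
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        xm.optimizer_step(optimizer)

        # Only the master replica saves. xm.save moves the state dict to the
        # CPU, which forces the master to execute its pending graph here,
        # while the other replicas do not.
        if step % 100 == 0 and xm.is_master_ordinal():
            xm.save(model.state_dict(), '/tmp/checkpoint.pt')
```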
Can you spot the error? The above code will hang forever without printing an error at all, because execution diverges between the master and non-master replicas. Compare the order of operations: the master materializes its pending graph as soon as the checkpoint save moves the XLA tensors to the CPU and only reaches the data loader's `mark_step` afterwards, while the non-master replicas go straight to the loader's `mark_step`. The replicas are no longer executing the same sequence of synchronized operations, so they block waiting for each other.
Concretely, a version of this same bug took hours to debug when it appeared in both Accelerate and Lightning. The solution in each case was to add an additional `xm.mark_step` on replica 0 before the checkpoint.

Logging before `xm.mark_step`
Let's take a similar example where we log the loss of a model every n steps (either to the terminal or to TensorBoard):
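A sketch of that pattern (the logging interval and the use of `loss.item()` are illustrative):

```python
import torch_xla.core.xla_model as xm

def train_epoch(model, loader, optimizer, loss_fn):
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        xm.optimizer_step(optimizer)

        # loss.item() copies the tensor to the CPU, which immediately
        # executes the pending graph on the master.
        if step % 10 == 0 and xm.is_master_ordinal():
            print(f'step {step}: loss = {loss.item():.4f}')
```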
Although the above code looks innocent, it will (at the very least) lead to a substantial performance degradation: the forward pass will actually run twice on the master. When you move a tensor to the CPU before a mark step, we immediately run any pending operations on that tensor, but we don't cache the result.[^1] The forward pass (along with the backward pass) will run again when `loader` calls `mark_step`.

Returning early before `xm.mark_step`
Because we rely on the `MpDeviceLoader` iterator for inserting `mark_step`s, `break`ing early or raising an exception becomes a dangerous operation. Take the early-stopping example sketched below. When the loop exits, the graph will include the entire forward pass of the model, which will get run the next time the user calls `mark_step`. This becomes problematic if the user then adds another large executable to the graph before that `mark_step`. This was a latent bug with the HuggingFace `transformers.Trainer` and was only fixed recently. In this case, it led to OOMs during checkpointing.
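A sketch of such an early-stopping loop; `should_stop` is a placeholder for whatever stopping criterion the user supplies.

```python
import torch_xla.core.xla_model as xm

def train_epoch(model, loader, optimizer, loss_fn, should_stop):
    model.train()
    for data, target in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        xm.optimizer_step(optimizer)

        # Breaking out here means the loader never issues the mark_step for
        # this iteration, so the forward/backward graph traced above stays
        # pending until the next mark_step.
        if should_stop(loss):
            break
```

`xm.mark_step` inside profile scope

The above examples deal with cases where we effectively require a `mark_step`. So let's say the user does add the `mark_step`s they need manually, but they are also profiling their code. In this case, they are susceptible to the mistake sketched below, which results in an error like this one: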
RuntimeError: Expecting scope to be empty but it is train_loop.
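A sketch of the pattern that produces this error, assuming the profiler's `xp.Trace` scope and a manually added `mark_step` as described above:

```python
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_epoch(model, loader, optimizer, loss_fn):
    model.train()
    for step, (data, target) in enumerate(loader):
        with xp.Trace('train_loop'):
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            xm.optimizer_step(optimizer)

            if step % 10 == 0 and xm.is_master_ordinal():
                print(f'step {step}: loss = {loss.item():.4f}')
                # Mistake: mark_step runs while the 'train_loop' trace scope
                # is still open, producing the scope error above.
                xm.mark_step()
```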
We're essentially taking an indentation mistake (putting the log and the `mark_step` inside of a `Trace`) and turning it into a flat-out error with an unclear error message.

Proposal
The "correct" way to avoid most of the above errors is to use a step closure to defer an operation until after the `xm.mark_step` runs. This again exposes implementation details directly to the user. We should not be framing our API to the user such that they have to carefully think through the execution order of their code, particularly because our execution model diverges from what most people are used to in `torch`.

Fundamentally, the problem here is that what you can and cannot call at a given line depends on global state that is not readily visible. Because of our API's imperative structure, we cannot tell what code is between `mark_step`s and select the correct behavior or raise a clear error message. Python gives us two tools to mark a limited context explicitly: context managers and decorators.[^2]

My proposal is simple: create a decorator/context manager[^3] named something like `torch_xla.step` that explicitly marks out a segment of code that should be traced. If a user accidentally does something that they shouldn't (like move a tensor to CPU) in this context, we can loudly raise an error.
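To make the shape of the API concrete, here is a rough sketch of how such a combined context manager/decorator could look. This is only an illustration of the idea, not a proposed implementation; the error-raising behavior is left as a comment.

```python
import contextlib
import torch_xla.core.xla_model as xm

class step(contextlib.ContextDecorator):
    """Sketch: mark a region of code that should be traced, then executed on exit."""

    def __enter__(self):
        # A real implementation could record that we are inside a step here,
        # so that unsafe operations (e.g. device-to-host copies) raise loudly.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # mark_step always runs on exit, even if the body raised or returned early.
        xm.mark_step()
        return False  # do not swallow exceptions
```

It could then be used either as `with torch_xla.step():` around a loop body or as a `@torch_xla.step()` decorator on a step function.

Example code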
The following example code is modified from our ResNet example code.
Before:
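A simplified sketch of the "before" loop, loosely modeled on the ResNet example; `train_update` stands in for the example's logging helper and `log_steps` is an assumed flag.

```python
import torch_xla.core.xla_model as xm

def train_loop_fn(model, loader, optimizer, loss_fn, epoch, log_steps=20):
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)

        if step % log_steps == 0:
            # Defer logging until after the loader's implicit mark_step.
            xm.add_step_closure(train_update, args=(step, loss, epoch))
```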
After:
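And a sketch of the same loop using the proposed context manager, with `train_update` (again assumed) deliberately kept outside the `torch_xla.step` block:

```python
import torch_xla
import torch_xla.core.xla_model as xm

def train_loop_fn(model, loader, optimizer, loss_fn, epoch, log_steps=20):
    model.train()
    for step, (data, target) in enumerate(loader):
        with torch_xla.step():
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)

        # Outside the step context the graph has been executed, so reading
        # the loss for logging is safe.
        if step % log_steps == 0:
            train_update(step, loss, epoch)
```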
If the user puts `train_update` inside of `torch_xla.step`, we can raise an error early when moving the `xla` tensor for loss to the CPU. The same would be true if they tried checkpointing there. Likewise, if they return early from the loop, the context manager's `__exit__()` will still be called.

Prior work: `xp.StepTrace`
We already have a similar context manager in our API: `StepTrace`. `StepTrace` is already a context manager that calls `mark_step` when it exits. Since it's focused on profiling, `StepTrace` is rarely used in our documentation and examples. I want to take this idea further. Most importantly, `torch_xla.step` should be used as a context to proactively raise errors for incorrect or dangerous operations, and we should start using it in documentation that uses custom training loops.

We still need `xm.mark_step`
We can't eliminate `mark_step` entirely. `torch_xla.step` will be useful for training loops that are written with PyTorch/XLA in mind. It may be less useful with higher-level frameworks that control the training loop: it's generally harder to insert a new context manager without modifying code than to insert a single call to `xm.mark_step`. Take our HuggingFace Diffusers example: `mark_step` is a useful callback to pass in this case.
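For illustration, a sketch of that idea; the `callback`/`callback_steps` arguments assume the older diffusers pipeline API, and the model id is a placeholder.

```python
import torch_xla.core.xla_model as xm
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained('some/stable-diffusion-checkpoint')
pipe = pipe.to(xm.xla_device())

# Cut the graph after every denoising step by passing mark_step as the
# per-step callback (assumed callback signature: step, timestep, latents).
image = pipe(
    'a photo of an astronaut riding a horse',
    callback=lambda step, timestep, latents: xm.mark_step(),
    callback_steps=1,
).images[0]
```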
Interaction with `MpDeviceLoader`

We use `MpDeviceLoader` extensively throughout our documentation, so `torch_xla.step` needs to be compatible with it.[^4] Since `mark_step` doesn't affect the host-to-device transfers that `MpDeviceLoader` starts, the effect of adding an extra `mark_step` will be minimal. If there were an unexpected interaction, there is already a mechanism to prevent `MpDeviceLoader` from dispatching computations, although I'd like to avoid that. We should reduce the coupling between our features such that they can be added gradually as needed.

Closing thoughts
This is not a perfect solution. The problem as I see it is that the user must keep in mind some context about the XLA state, but we make that state invisible. My proposal here is to instead make that state visible. Explicit is better than implicit.
I'd rather the user not have to think about XLA at all, but I don't see a good way to do that entirely. I'm absolutely open to other ideas and ways of thinking about the problems above.
- Could this tie into the `torch.compile` backend somehow?
- What `mark_step` horror stories did I miss?

Footnotes
[^1]: There are good reasons for this: caching intermediate results would make it harder to effectively cache the whole executable.

[^2]: TensorFlow and JAX have a similar challenge of translating an eager-looking API based on NumPy into an effectively lazy one, and decorators are the path that both of them take. See `jax.jit` and `tensorflow.function`.

[^3]: Python makes it easy to combine these two into one implementation: https://docs.python.org/3/library/contextlib.html#contextlib.ContextDecorator

[^4]: In my experimentation, `MpDeviceLoader` doesn't make much (if any) difference to MP workloads because transfers are now async anyway. That's a can of worms for another issue/proposal.