
[RFC] torch_xla.step context manager #6751

Closed · will-cromar opened this issue Mar 14, 2024 · 4 comments
Labels: RFC, usability (Bugs/features related to improving the usability of PyTorch/XLA)

@will-cromar (Collaborator)
TL;DR

  • In most cases, replace mark_step with a context manager/decorator torch_xla.step to explicitly mark code that should be traced and then executed.
  • This both ensures that mark_step always runs when the block exits and provides an explicit context in which to raise errors.

Introduction

Users should be able to pick up and use PyTorch/XLA without having to understand the execution model. As much as possible, we should bury the implementation details such that you can wrap existing PyTorch code and have it "just work".

xm.mark_step is the most obvious example where we directly expose implementation details of how PyTorch/XLA works to the user: we actually require users to decide where to place synchronization barriers. These manual barriers are not required in JAX and TensorFlow, even though both of them implement a similar lazy execution model.

In practice, our current solution is to hide xm.mark_step in our preloading data loader implementation that calls it 1) after each batch is loaded and 2) at the end of iteration. If a user is sufficiently careful, they don't have to see this implementation detail at all. Take the following training loop for example:

for batch in loader:
  # Run model
  xm.optimizer_step(optimizer)  # Note: this may be optimizer.step() when using DDP

Training loops like the one above will run without issue. However, even slight deviations from this pattern can cause inscrutable problems such as the one below.

Weird xm.mark_step behaviors

Checkpointing before xm.mark_step

Take the following common pattern, where the master saves a checkpoint every n steps:

for batch in loader:
  # Run model
  xm.optimizer_step(optimizer)
  if is_master and step % 100 == 0:
    xm.save(model)

Can you spot the error? The above code will hang forever without printing an error at all, because the execution diverges between the master and non-master replicas. Compare the order of operations:

Replica 0:
  1. load data
  2. mark step
  3. run step
  4. checkpoint, hang forever waiting for other replicas

Other replicas:
  1. load data
  2. mark step
  3. run step
  4. load data
  5. mark step, hang forever waiting for replica 0

Concretely, a version of this same bug took hours to debug when it appeared in both Accelerate and Lightning. The solution in each case was to add an additional xm.mark_step on replica 0 before the checkpoint.
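
A minimal sketch of that workaround, reusing the variables from the example above (the checkpoint path and the state_dict call are placeholders):

for batch in loader:
  # Run model
  xm.optimizer_step(optimizer)
  if is_master and step % 100 == 0:
    # Flush the pending graph before checkpointing so execution on replica 0
    # does not diverge from the other replicas.
    xm.mark_step()
    xm.save(model.state_dict(), 'checkpoint.pt')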

Logging before xm.mark_step

Let's take a similar example where we log the loss of a model every n steps (either to the terminal or to TensorBoard):

for batch in loader:
  # Run model
  loss = ...
  xm.optimizer_step(optimizer)
  if is_master and step % 100 == 0:
    print(step, loss)

Although the above code looks innocent, it will (at the very least) lead to a substantial performance degradation: the forward pass will actually run twice on the master. When you move a tensor to the CPU before a mark_step, we immediately run any pending operations on that tensor, but we don't cache the result [1]. The forward pass (along with the backward pass) will then run again when loader calls mark_step.
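
Today, the way to log without re-running the forward pass is to defer the print with a step closure so it runs only after the loader's mark_step. A rough sketch, reusing the variables from the example above:

def _log(step, loss):
  # Runs after mark_step, when `loss` has already been materialized.
  print(step, loss.item())

for batch in loader:
  # Run model
  loss = ...
  xm.optimizer_step(optimizer)
  if is_master and step % 100 == 0:
    xm.add_step_closure(_log, args=(step, loss))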

Returning early before xm.mark_step

Because we rely on the MpDeviceLoader iterator for inserting mark_steps, breaking early or raising an exception becomes a dangerous operation. Take the following example with early stopping:

for batch in loader:
  # Run model
  loss = ...
  if loss < target:
    break

When the loop exits, the graph will still include the entire forward pass of the model, which will run the next time the user calls mark_step. This becomes problematic if, by then, the user has added another large executable to the graph. This was a latent bug with the HuggingFace transformers.Trainer and was only fixed recently; in that case, it led to OOMs during checkpointing.
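
One workaround today, sketched below with the same loop variables, is to flush whatever is left of the trace as soon as the loop exits so it cannot be folded into the next large executable:

for batch in loader:
  # Run model
  loss = ...
  if loss < target:
    break

# Flush the leftover forward-pass graph instead of leaving it pending.
xm.mark_step()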

xm.mark_step inside profile scope

The above examples deal with cases where we effectively require a mark_step. So let's say the user does add the mark_steps they need manually, but they are also profiling their code. In this case, they are going to be susceptible to the mistake below:

with xp.Trace('loss'):
  # Run model
  loss = ...
  xm.mark_step()
  if is_master and step % 100 == 0:
    print(step, loss)

Running code like this will result in an error like this one: RuntimeError: Expecting scope to be empty but it is train_loop.

We're essentially taking an indentation mistake (putting the log and the mark_step inside of a Trace) and turning it into a flat-out error with an unclear error message.
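
For comparison, the version the user presumably intended keeps the mark_step (and the logging) outside of the open trace scope:

with xp.Trace('loss'):
  # Run model
  loss = ...

xm.mark_step()
if is_master and step % 100 == 0:
  print(step, loss)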

Proposal

The "correct" way to avoid most of the above errors is to use a step closure to defer an operation until after the xm.mark_step runs. This again exposes implementation details directly to the user. We should not be framing our API to the user such that they have to carefully think through the execution order of their code, particularly because our execution model diverges from what most people are used to in torch.

Fundamentally, the problem here is that what you can and cannot call at a given line depends on global state that is not readily visible. Because of our API's imperative structure, we cannot tell what code falls between mark_steps and either select the correct behavior or raise a clear error message. Python gives us two tools to mark a limited context explicitly: context managers and decorators [2].

My proposal is simple: create a decorator/context manager [3] named something like torch_xla.step that explicitly marks out a segment of code that should be traced. If a user accidentally does something that they shouldn't (like move a tensor to CPU) in this context, we can loudly raise an error.
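
To make the idea concrete, here is a rough sketch of such a helper built on contextlib.ContextDecorator (see footnote [3]). This is illustrative only, not the eventual implementation, and the error-raising hook is just a placeholder:

import contextlib

import torch_xla.core.xla_model as xm


class step(contextlib.ContextDecorator):
  """Illustrative sketch: trace the wrapped block, then flush it with mark_step."""

  def __enter__(self):
    # A real implementation could record here that a step is in progress so
    # that dangerous operations (e.g. moving an XLA tensor to the CPU) can be
    # detected and reported with a clear error.
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    # Always runs, even on break/return/exception, so the traced graph is
    # executed exactly once per step.
    xm.mark_step()
    return False  # do not swallow exceptions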

Example code

The following example is modified from our ResNet example code.

Before:

for step, (data, target) in enumerate(loader):
  optimizer.zero_grad()
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()
  xm.optimizer_step(optimizer)
  
  xm.add_step_closure(
       train_update, args=(device, step, loss, tracker, epoch, writer))

After:

for step, (data, target) in enumerate(loader):
  with torch_xla.step():
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)
  
  train_update(device, step, loss, tracker, epoch, writer)

If the user puts train_update inside of torch_xla.step, we can raise an error early when moving the xla tensor for loss to the CPU. The same would be true if they tried checkpointing there. Likewise, if they return early from the loop, the context manager's __exit__() will still be called.
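
Because ContextDecorator also works as a decorator (footnote [3]), the same step body could be factored into a function. A hedged sketch, assuming the decorator form behaves like the context manager above:

@torch_xla.step()
def train_step(data, target):
  optimizer.zero_grad()
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()
  xm.optimizer_step(optimizer)
  return loss

for step, (data, target) in enumerate(loader):
  loss = train_step(data, target)
  train_update(device, step, loss, tracker, epoch, writer)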

Prior work: xp.StepTrace

We already have a similar context manager in our API: StepTrace, which calls mark_step when it exits. Since it's focused on profiling, StepTrace is rarely used in our documentation and examples. I want to take this idea further. Most importantly, torch_xla.step should be used as a context to proactively raise errors for incorrect or dangerous operations, and we should start using it in documentation that uses custom training loops.
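
For reference, the existing pattern looks roughly like this (step_num is an optional label used by the profiler):

import torch_xla.debug.profiler as xp

for step, (data, target) in enumerate(loader):
  with xp.StepTrace('train_step', step_num=step):
    # Run model; mark_step is called automatically when the scope exits.
    ...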

We still need xm.mark_step

We can't eliminate mark_step entirely. torch_xla.step will be useful for training loops that are written with PyTorch/XLA in mind, but it may be less useful with higher-level frameworks that control the training loop, where it's generally harder to insert a new context manager without modifying code than to insert a single call to xm.mark_step. Take our HuggingFace Diffusers example:

image = pipeline(prompt, callback=lambda *args: xm.mark_step(), generator=generator)

mark_step is a useful callback to pass in this case.

Interaction with MpDeviceLoader

We use MpDeviceLoader extensively throughout our documentation, so torch_xla.step needs to be compatible with it [4]. Since mark_step doesn't affect the host-to-device transfers that MpDeviceLoader starts, the effect of adding an extra mark_step will be minimal. If there were an unexpected interaction, there is already a mechanism to prevent MpDeviceLoader from dispatching computations, although I'd like to avoid relying on it. We should reduce the coupling between our features so that they can be adopted gradually as needed.
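
For concreteness, a sketch of the expected combination, assuming the proposed torch_xla.step: MpDeviceLoader still calls mark_step after each batch it loads, and the context manager adds one more when the step body finishes.

import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
loader = pl.MpDeviceLoader(train_loader, device)

for step, (data, target) in enumerate(loader):
  with torch_xla.step():
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)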

Closing thoughts

This is not a perfect solution. The problem as I see it is that the user must keep in mind some context about the XLA state, but we make that state invisible. My proposal here is to instead make that state visible. Explicit is better than implicit.

I'd rather the user not have to think about XLA at all, but I don't see a good way to do that entirely. I'm absolutely open to other ideas and ways of thinking about the problems above.

  • Can this be combined with our torch.compile backend somehow?
  • Can we come up with a better name? I'll happily update this proposal. Naming is hard.
  • What other mark_step horror stories did I miss?

Footnotes

  1. There are good reasons for this: caching intermediate results will make it harder to effectively cache the whole executable.

  2. TensorFlow and JAX have a similar challenge of translating an eager-looking API based on Numpy into an effectively lazy one, and decorators are the path that both of them take. See jax.jit and tensorflow.function.

  3. Python makes it easy to combine these two into one implementation: https://docs.python.org/3/library/contextlib.html#contextlib.ContextDecorator

  4. In my experimentation, MpDeviceLoader doesn't make much (if any) difference to MP workloads because transfers are now async anyway. That's a can of worms for another issue/proposal.

@will-cromar self-assigned this Mar 14, 2024
@will-cromar added the usability and RFC labels Mar 14, 2024
@JackCaoG (Collaborator)

> The forward pass will actually run twice on the master. When you move a tensor to the CPU, before a mark step, we immediately run any pending operations on that tensor, but we don't cache the result

I think I changed the default behavior in one of the recent PRs. Now, trying to access the result of a tensor will actually materialize that tensor. However, if you have

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # assumes an XLA device; otherwise execution is eager
a = torch.randn(5, 5, device=device)
b = torch.randn(5, 5, device=device)
c = a + b
d = c * a
print(d)  # only d is materialized here

The code above will only materialize d; a, b, and c will be materialized by a later mark_step().

@JackCaoG (Collaborator)

I like this idea in general; curious what the final UX looks like. We don't need to worry about torch.compile: torch.compile itself serves as a context manager in a sense and gives us enough freedom to call mark_step when entering/exiting the compiled function. For example

if len(input_tensors_to_sync) > 0:
  torch_xla._XLAC._xla_sync_multi(
      input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)

@vanbasten23 (Collaborator)

I like the proposal, especially the [Explicit is better than implicit](https://peps.python.org/pep-0020/) motivation.

When users develop their models using torch_xla, I feel it's very challenging to ignore torch_xla internals. Users tend to call print, which has subtle perf implications, so they end up having to understand some torch_xla internals.

To hide the torch_xla internals completely, maybe HF Accelerate/Lightning is the way to go?

Also, at "If the user puts train_update inside of torch_xla.step", it could be better if we included the implementation of train_update:

def _train_update(device, step, loss, tracker, epoch, writer):

@will-cromar (Collaborator, Author)

Marking this complete. step was added in #7068 (renamed to compile in #7750) and @JackCaoG implemented an optional guard against accidental execution in #7776.

This is already a significant upgrade over mark_step. We'll continue to build out compile as part of our move to eager mode in #7253.
