[RFC] scan operator and scan_layers #8620

Open

tengyifei opened this issue Jan 23, 2025 · 1 comment

RFC: PyTorch/XLA scan operator and scan_layers

Problem statement

Many LLMs have a few dozen decoder layers that are applied in a for loop. When we trace the forward function of the model, the decoder layers are unrolled and inlined into one large computation. This can lead to compilation time scaling linearly or super-linearly with the number of decoder layers, compromising the user experience.

We hope to introduce a mechanism that compiles just a single layer and uses it in an XLA While op. The hypothesis is that compilation time will then stay constant as the number of decoder layers increases. The proposed implementation will also reduce the number of times the decoder layer is traced, helping to reduce tracing overhead and improve the performance of PyTorch/XLA on TPUs with small per-device batch sizes. There is prior art in JAX called scan: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html.

We also need to support backward/gradient propagation through the scan operator so that scan can be used during training.

Example usage

We have prototyped two functions, scan and scan_layers, to be elaborated in the next section. Users would generally use scan_layers, which wraps scan under the hood.

Given a sequence of layers self.layers, such as a torch.nn.ModuleList, rather than doing

hidden_states = inputs
for decoder_layer in self.layers:
  hidden_states = decoder_layer(hidden_states)
outputs = hidden_states

Users would write

# self.layers: Iterable[torch.nn.Module] (compatible with `torch.nn.ModuleList` and lists of modules etc.)
# inputs: torch.Tensor
outputs = scan_layers(self.layers, inputs)

where the scan_layers function will trace a single layer, then apply that computation sequentially over the layers, filling in different weights and biases, passing the output from the previous layer as input into the next layer. The scan_layers function will check that the layers have identical structure in terms of weights and biases:

  • They all have the same set of dictionary keys (parameter names)
  • The parameters have the same shapes.

If we want to further ensure the layers perform identically structured computations, we can trace each layer and compare their HLO. Tracing should be faster than end-to-end compilation. That would mean scan traces n times for an input sequence of length n.
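
A minimal sketch of the basic structural check described above, assuming the layers are ordinary torch.nn.Module instances (the helper name check_same_structure is illustrative, not part of the proposed API):

import torch

def check_same_structure(layers):
  # Use the first layer's parameter names and shapes as the reference.
  layers = list(layers)
  reference = {name: p.shape for name, p in layers[0].named_parameters()}
  for i, layer in enumerate(layers[1:], start=1):
    shapes = {name: p.shape for name, p in layer.named_parameters()}
    if shapes.keys() != reference.keys():
      raise ValueError(f"layer {i} has different parameter names than layer 0")
    for name, shape in shapes.items():
      if shape != reference[name]:
        raise ValueError(
            f"parameter {name!r} of layer {i} has shape {shape}, expected {reference[name]}")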

High level design

We break this problem down into two operations: scan and scan_layers.

scan_layers

We have a sequence of layers with identical structure (such as a bunch of nn.Linear(128, 128) or decoder layers). We would like to use the XLA While op to loop over and apply the layers. Specifically, in every iteration we:

  • Index into the layers and obtain layer i.
  • Pass the hidden state from the previous layer as input to run the layer.
  • Obtain the output and use it as the hidden state/input in the next iteration.

The challenge is figuring out how to express this in terms of XLA ops. The type system of XLA as specified in its operation semantics supports tensors and tuples. One cannot for example define a C-style structure with named fields and define HLO operations on those. We need to find an appropriate representation of a sequence of torch.nn.Modules in terms of these data types so that the body computation inside the While op can obtain a specific layer using a scalar index.

We propose to stack the weights and biases of these layers into larger tensors, where the first dimension indexes the layer. The module.named_parameters() method yields all the parameters of a module as (name, parameter) pairs, which we can collect into a dict. Let's say we have 3 linear layers and their state is:

[
  { "weight": torch.empty(64, 64), "bias": torch.empty(64) },
  { "weight": torch.empty(64, 64), "bias": torch.empty(64) },
  { "weight": torch.empty(64, 64), "bias": torch.empty(64) },
]

We'll stack them into:

{
  "weight": torch.empty(3, 64, 64), "bias": torch.empty(3, 64),
}
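
A sketch of this stacking step, using plain nn.Linear layers for illustration (stack_params is a hypothetical helper, not the proposed API):

import torch

def stack_params(layers):
  # Stack corresponding parameters of all layers along a new leading dimension.
  names = [name for name, _ in layers[0].named_parameters()]
  return {
      name: torch.stack([dict(layer.named_parameters())[name] for layer in layers])
      for name in names
  }

layers = [torch.nn.Linear(64, 64) for _ in range(3)]
stacked = stack_params(layers)
# stacked["weight"].shape == (3, 64, 64); stacked["bias"].shape == (3, 64)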

If you're familiar with the JAX scan function, you may notice that this data structure matches the format expected by jax.lax.scan, which takes a PyTree whose leaves are tensors and calls a user-supplied function with a slice of the PyTree, indexing into the leading dimension of each leaf. For example, if the dict above is param, it would call the user function fn with

fn({
  "weight": param["weight"][i],
  "bias": param["bias"][i],
})

etc. We'll implement scan to support arbitrary PyTrees in the same way. Because the XLA type system consists only of tensors and tuples, scan needs to flatten the dictionary into a list of tensors and supply those as parameters to the XLA computation, and similarly for the output. This way, scan_layers can pass the stacked tensor dictionary as input. The fn supplied by scan_layers will rebuild the layer (e.g. a Linear), plugging in the parameters at layer i, and invoke it on the inputs. Fortunately, PyTorch already provides enough utilities around module parameter handling that we can implement this with minimal hassle. See the scan_layers implementation.
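
For illustration, one way to plug the parameters of layer i into a module is torch.func.functional_call, which runs a module with a substituted parameter dict; the stacked dict and shapes below are made up for the example:

import torch
from torch.func import functional_call

layer = torch.nn.Linear(64, 64)
stacked = {"weight": torch.randn(3, 64, 64), "bias": torch.randn(3, 64)}

i = 1  # index of the layer to run
params_at_i = {name: tensor[i] for name, tensor in stacked.items()}
hidden = torch.randn(8, 64)
# Runs `layer` with the parameters of layer `i` substituted for its own.
output = functional_call(layer, params_at_i, (hidden,))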

scan

The design of scan follows the JAX version very closely. I'll describe the noteworthy aspects of interfacing with HLO/XLA. At a high level, this function will:

  • Trace the user supplied fn using fake inputs (torch.empty tensors) to obtain an XlaComputation object.

  • Inspect the computation to get all the referenced tensors (xla::device_data nodes) and their ordering when supplied as parameters. There may be more parameters than arguments of fn, because fn may capture additional tensors in its closure or create new tensors during execution. For example:

      def fn(carry, x):
        foo = torch.zeros(8)
        return carry, x + foo
    

This gets lowered into an HLO module like the following:

  HloModule FnComputation, entry_computation_layout={
    (f32[8], f32[8], f32[8]) -> (f32[8], f32[8])
  }
One of the `f32[8]` parameters corresponds to the `foo` tensor created within the body of `fn`. We need to identify that parameter and supply the correct tensor value when building the computation.
  • Build a While op with a cond_fn and a body_fn. The cond_fn checks whether the current iteration counter (represented as another parameter to the computation) has reached zero, at which point the loop exits. The body_fn calls the fn computation with the correct parameter ordering and supplies additional tensor parameters as necessary.

The backward of scan is itself implemented using scan. At a high level, we scan the backward version of fn in reverse order, from the last input to the first. There are several ways to extract the backward of fn; we'll start with AOTAutograd and later explore using Dynamo.
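
Assuming the proposed scan participates in autograd like a regular PyTorch op, training usage would look roughly like the sketch below; the import path matches the torch_xla.experimental.scan module mentioned in the comment further down, and device placement is omitted for brevity:

import torch
from torch_xla.experimental.scan import scan  # experimental module, see comment below

init = torch.zeros(8, 64)
xs = torch.randn(10, 8, 64, requires_grad=True)

def fn(carry, x):
  new_carry = torch.tanh(carry + x)
  return new_carry, new_carry

final_carry, ys = scan(fn, init, xs)
ys.sum().backward()
# xs.grad now holds gradients propagated back through all 10 iterations.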

Interface design

def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
  """Apply a function over leading dimension of tensors while carrying along state.
  
  This is similar to the JAX `jax.lax.scan` function found in [1].
  
  You may use it to loop over the leading dimension of tensors efficiently. If `xs`
  is a single tensor, this function is roughly equivalent to the following Python code:

    def scan(fn, init, xs):
      ys = []
      carry = init
      for i in range(xs.size(0)):
        carry, y = fn(carry, xs[i])
        ys.append(y)
      return carry, torch.stack(ys, dim=0)
  
  In the general case, `Carry`, `X`, and `Y` can be arbitrary PyTrees. This function
  will iterate through the leading dimension of every leaf element of `xs` simultaneously,
  and pass a slice of those elements to `fn` as another PyTree. This means you may
  scan over multiple tensors and produce multiple output tensors at once.
  
  Args:

    fn: a Python callable that accepts two PyTrees of tensors: the carry object and the
        slices of `xs` along its leading dimension. It should return two PyTrees: the carry
        object and the slices of the output. The returned carry object will be passed to
        the next invocation of `fn`.

    init: the initial carry object passed to the first invocation of `fn`.
    
    xs: the input PyTree to scan over. If `xs` is a tensor, then `fn` will get slices along
        the leading dimension (`xs[i]`). If `xs` is some other PyTree (e.g. a tuple of
        tensors), `fn` will get PyTrees of slices. In that case the leading dimension size
        of the leaves in the PyTree must be the same.

  Returns:

    (carry, ys): A tuple where `carry` is the last carry object returned by `fn`, and
    `ys` is a PyTree with the same structure as the second value returned by `fn`, where
    each leaf is formed by stacking that leaf's output across all iterations. This means
    if your `fn` returns `(carry, (y1, y2))` then this function will return
    `(carry, (torch.stack(all_y1), torch.stack(all_y2)))`.

  [1]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
  """
  ...


init = torch.tensor(0)
xs = torch.tensor([1, 2, 3])
fn = lambda carry, x: (carry + 1, x + carry)
scan(fn, init, xs)
# Returns:
# (3, [1, 3, 5])
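
In the general case we can scan over a PyTree; a sketch with a tuple of tensors and made-up shapes:

init = torch.zeros(4)
xs = (torch.randn(10, 4), torch.randn(10, 4))

def fn(carry, x):
  a, b = x
  return carry + a, (a * b, a - b)

final_carry, (prod, diff) = scan(fn, init, xs)
# final_carry.shape == (4,); prod.shape == diff.shape == (10, 4)
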
def scan_layers(layers: Iterable[torch.nn.Module], input_data: torch.Tensor):
  """Applies each layer in `layers` to `input_data` sequentially.

  `input_data` is provided as input to the first layer in `layers`. The output of one
  layer is provided as input to next layer. This function is equivalent to

    sequential = torch.nn.Sequential(*layers)
    sequential(input_data)

  This function can be faster to compile since it reuses the XLA computation of the
  first layer to perform the computation of all other layers.
  """
  ...
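
For example (a sketch; the layer count and shapes are made up):

layers = [torch.nn.Linear(128, 128) for _ in range(16)]
inputs = torch.randn(32, 128)
outputs = scan_layers(layers, inputs)  # same result as applying each layer in turn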

Alternatives

An alternative implementation of scan_layers is to combine the weights and biases of different layers into an XLA Tuple, as opposed to stacking them into a Tensor. The body computation of the While op will index into the Tuple using get_tuple_element as opposed to indexing into the Tensor using dynamic_slice. This implies that scan_layers won't use scan, which is designed around indexing into tensors.

I have not prototyped this approach as I'm not sure whether XLA supports nested tuples of distinctly shaped tensors. One may suspect the Tuple approach performs better than the Tensor approach, since it avoids copying tensors into a larger stacked tensor, but given that XLA backends aggressively optimize memory layouts, I'm not sure how much this would pay off.

One advantage of the main proposal is that we'll expose a familiar scan operator that has near-identical semantics to the scan operator found in JAX, lowering the learning barrier.

In any case, these are things we can optimize in future versions of PyTorch/XLA without changing the signature of scan_layers.

tengyifei self-assigned this Jan 23, 2025

tengyifei (Collaborator, Author) commented:

In the first iteration of torch_xla.experimental.scan, we're using AOTAutograd to figure out the backward of the combine function, in order to implement the backward of the scan function. This is the easiest way to make this functionality available early, in experimental status. It has some drawbacks:

  • The combine function cannot use higher order ops (e.g. torch.utils.checkpoint)
  • AOTAutograd is not intended by PyTorch for external libraries to use

Next, we should focus on general availability:

  • The constraint is that we should keep user code changes to a minimum when people adopt scan.
  • Support functions that use torch.utils.checkpoint, as opposed to AOT rematerialization.
  • This will also let us access Torch upstream selective checkpoint features and memory budget based checkpointing, c.f. https://static.sched.com/hosted_files/pytorch2024/59/New%20Activation%20Checkpointing%20APIs%20in%20PyTorch.pdf
  • It's likely that we'll end up supporting this by lowering the PyTorch upstream torch.scan higher order operator inside a Dynamo backend. Dynamo will take care of decomposing torch.utils.checkpoint automatically and lifting global tensor references into function arguments, so external in-place mutations are transformed into internal in-place mutations.

After this milestone, we can move scan out of experimental.

See discussion in #7901 (comment)
