# RFC: PyTorch/XLA `scan` operator and `scan_layers`

## Problem statement
Many LLMs have a few dozen decoder layers that are applied in a for loop. When we trace the forward function of the model, the decoder layers are unrolled and inlined into one large computation. This can lead to compilation time scaling linearly or super-linearly with the number of decoder layers, compromising the user experience.
We hope to introduce a mechanism to compile just a single layer and use that in an XLA `While` op. The hypothesis is that compilation time will then stay constant as the number of decoder layers increases. The proposed implementation will also reduce the number of times the decoder layer is traced, helping to reduce tracing overhead and improve the performance of PyTorch/XLA on TPUs with small per-device batch sizes. There is prior art in JAX called `scan`: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html.
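For reference, a minimal `jax.lax.scan` example (this mirrors the `scan` usage example in the interface design section below):

```python
import jax
import jax.numpy as jnp

def fn(carry, x):
    return carry + 1, x + carry  # (new carry, per-step output)

carry, ys = jax.lax.scan(fn, jnp.array(0), jnp.array([1, 2, 3]))
# carry == 3, ys == [1, 3, 5]
```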
We also need to support backward/gradient propagation through the scan operator. This is needed so people can use `scan` during training.
## Example usage
We have prototyped two functions, `scan` and `scan_layers`, to be elaborated in the next section. Users would generally use `scan_layers`, which wraps `scan` under the hood.
Given a sequence of layers `self.layers`, such as a `torch.nn.ModuleList`, rather than doing
```python
hidden_states = inputs
for decoder_layer in self.layers:
    hidden_states = decoder_layer(hidden_states)
outputs = hidden_states
```
Users would write
```python
# self.layers: Iterable[torch.nn.Module] (compatible with `torch.nn.ModuleList`, lists of modules, etc.)
# inputs: torch.Tensor
outputs = scan_layers(self.layers, inputs)
```
where the `scan_layers` function will trace a single layer, then apply that computation sequentially over the layers, filling in each layer's weights and biases and passing the output of the previous layer as input to the next. The `scan_layers` function will check that the layers have identical structure in terms of weights and biases:
- They all have the same set of dictionary keys (parameter names).
- The corresponding parameters have the same shapes.
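For illustration, the structural check might look like this minimal sketch (`check_same_structure` is a hypothetical helper, not the actual implementation):

```python
def check_same_structure(layers):
    # Map parameter names to shapes for the first layer...
    reference = {name: p.shape for name, p in layers[0].named_parameters()}
    for layer in layers[1:]:
        params = {name: p.shape for name, p in layer.named_parameters()}
        # ...and require every other layer to match both names and shapes.
        assert params.keys() == reference.keys(), "parameter names differ"
        assert all(params[k] == reference[k] for k in reference), "parameter shapes differ"
```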
If we want to further ensure the layers perform identically structured computations, we can trace each layer and compare their HLO. Tracing should be much faster than end-to-end compilation, though it would mean `scan` traces `n` times given an input sequence of length `n`.
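A sketch of how such a comparison might be obtained, using an internal PyTorch/XLA helper (not necessarily how we would ship the check):

```python
import torch_xla

def layer_hlo_text(layer, example_input):
    # Trace the layer and dump the HLO of its output tensor (internal API).
    output = layer(example_input)
    return torch_xla._XLAC._get_xla_tensors_hlo([output])

# Layers with identical structure should yield matching HLO text, since
# weights enter the computation as parameters rather than constants.
```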
## High level design
We break this problem down into two operations: `scan` and `scan_layers`.
### `scan_layers`
We have a sequence of layers with identical structure (such as a bunch of `nn.Linear(128, 128)` or decoder layers). We would like to use the XLA `While` op to loop over and apply the layers. Specifically, in every iteration we:
1. Index into the layers and obtain layer `i`.
2. Pass the hidden state from the previous layer as input and run the layer.
3. Obtain the output and use it as the hidden state/input for the next iteration.
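In Python terms, this is the loop we want XLA to run (a sketch):

```python
hidden_state = inputs
for i in range(len(layers)):
    layer = layers[i]             # 1. index into the layers
    output = layer(hidden_state)  # 2. run the layer on the previous hidden state
    hidden_state = output         # 3. feed the output into the next iteration
```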
The challenge is figuring out how to express this in terms of XLA ops. The type system of XLA, as specified in its operation semantics, supports tensors and tuples. One cannot, for example, define a C-style structure with named fields and define HLO operations on it. We need to find an appropriate representation of a sequence of `torch.nn.Module`s in terms of these data types so that the body computation inside the `While` op can obtain a specific layer using a scalar index.
We propose to stack the weights and biases of these layers into larger tensors, where the index of the layer is given by the first dimension. The `module.named_parameters()` method lets us obtain all the parameters of a module in the form of a `dict`. Let's say we have 3 linear layers and their state is:
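```python
# Illustrative sketch; the exact shapes are assumptions (e.g. nn.Linear(128, 128)).
params = [dict(layer.named_parameters()) for layer in self.layers]
# params[0] == {'weight': f32[128, 128], 'bias': f32[128]}  (layer 0)
# params[1] == {'weight': f32[128, 128], 'bias': f32[128]}  (layer 1)
# params[2] == {'weight': f32[128, 128], 'bias': f32[128]}  (layer 2)
```

We'll stack them into:

```python
# One dict whose leaves gain a leading "num layers" dimension of size 3.
param = {
    'weight': torch.stack([p['weight'] for p in params]),  # f32[3, 128, 128]
    'bias': torch.stack([p['bias'] for p in params]),      # f32[3, 128]
}
```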
If you're familiar with the JAX scan function, you may notice that this data structure matches the format expected by `jax.lax.scan`, which takes a PyTree where leaves are tensors, and calls a user-supplied function with a slice of the PyTree, indexing into the leading dimension of each leaf. So e.g. if the dict above is `param`, it would call the user function `fn` with:
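```python
# Assumed illustration: one call per slice along the leading dimension.
fn(carry, {'weight': param['weight'][0], 'bias': param['bias'][0]})
fn(carry, {'weight': param['weight'][1], 'bias': param['bias'][1]})
```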
etc. We'll implement `scan` to similarly generalize to arbitrary PyTrees. Because the XLA type system only consists of tensors and tuples, `scan` needs to flatten the dictionary into a list of tensors and supply those as parameters to the XLA computation, and similarly for the output. This way `scan_layers` can pass this stacked tensor dictionary as input. The `fn` supplied by `scan_layers` will rebuild a `Linear`, plugging in the parameters at layer `i`, and invoke the layer on the inputs. Fortunately, PyTorch already provides enough utilities around module parameter handling that we can implement this with minimal hassle. See the scan_layers implementation.
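For illustration, a minimal sketch of that `fn` using `torch.func.functional_call` (here `template_layer` is a hypothetical stand-in for, say, `self.layers[0]`; the real implementation may differ):

```python
from torch.func import functional_call

def one_layer(carry, params_slice):
    # carry: the hidden state from the previous layer.
    # params_slice: {'weight': f32[128, 128], 'bias': f32[128]} for layer i.
    # Run the template module's computation with layer i's parameters swapped in.
    output = functional_call(template_layer, params_slice, (carry,))
    return output, None  # new carry; this sketch produces no stacked outputs
```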
### `scan`
The design of `scan` follows the JAX version very closely. I'll describe the noteworthy aspects of interfacing with HLO/XLA. At a high level, this function will:
1. Trace the user-supplied `fn` using fake inputs (`torch.empty` tensors) to obtain an `XlaComputation` object.
2. Inspect the computation to get all the referenced tensors (`xla::device_data` nodes) and their ordering when supplied as parameters. There may be more parameters than in the method arguments of `fn`, because `fn` may capture more tensors in its closure, or it may create more tensors during execution. Example:
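As a hypothetical example (the names and shapes here are assumptions), suppose `fn` creates a tensor `foo` in its body:

```python
import torch
import torch_xla.core.xla_model as xm

def fn(carry, x):
    # foo is created inside fn, so it is not among fn's arguments, yet it still
    # shows up as a parameter of the traced XLA computation.
    foo = torch.ones(8, device=xm.xla_device())  # carry and x are f32[8] too
    return carry + foo, x + foo
```

Gets lowered into an HLO like this (a sketch, eliding the body):

```
HloModule fn

ENTRY %fn (p0: f32[8], p1: f32[8], p2: f32[8]) -> (f32[8], f32[8]) {
  ...
}
```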
One of the `f32[8]` parameters corresponds to the `foo` tensor created within the body of `fn`. We need to identify that parameter and supply the correct tensor value when building the computation.
3. Build a `While` op with a `cond_fn` and a `body_fn`. The `cond_fn` checks whether the current iteration counter (represented as another parameter to the computation) has reached zero, at which point we exit. The `body_fn` calls the `fn` computation with the correct parameter ordering and supplies the additional tensor parameters as necessary.
The backward of `scan` is also implemented using `scan`. At a high level, we scan the backward version of `fn` in reverse order, from the last input to the first. There are several ways to extract the backward of `fn`. We'll start with AOTAutograd and explore using Dynamo to extract the backward.
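As a sketch of the AOTAutograd direction (assumed usage; the exact plumbing will differ), we can capture the forward and backward graphs of `fn` like this:

```python
from functorch.compile import aot_function

graphs = {}

def fw_compiler(gm, example_inputs):
    graphs["fw"] = gm  # forward graph, including outputs saved for backward
    return gm

def bw_compiler(gm, example_inputs):
    graphs["bw"] = gm  # backward graph; this is what the reverse scan will run
    return gm

fn_aot = aot_function(fn, fw_compiler, bw_compiler)
```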
## Interface design
```python
def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
    """Apply a function over the leading dimension of tensors while carrying along state.

    This is similar to the JAX `jax.lax.scan` function found in [1]. You may use
    it to loop over the leading dimension of tensors efficiently. If `xs` is a
    single tensor, this function is roughly equal to the following Python code:

        def scan(fn, init, xs):
            ys = []
            carry = init
            for i in range(xs.size(0)):
                carry, y = fn(carry, xs[i])
                ys.append(y)
            return carry, torch.stack(ys, dim=0)

    In the general case, `Carry`, `X`, and `Y` can be arbitrary PyTrees. This
    function will iterate through the leading dimension of every leaf element of
    `xs` simultaneously, and pass a slice of those elements to `fn` as another
    PyTree. This means you may scan over multiple tensors and produce multiple
    output tensors at once.

    Args:
        fn: a Python callable that accepts two PyTrees of tensors: the carry
            object and the slices of `xs` along its leading dimension. It should
            return two PyTrees: the carry object and the slices of the output.
            The returned carry object will be passed to the next invocation of
            `fn`.
        init: the initial carry object passed to the first invocation of `fn`.
        xs: the input PyTree to scan over. If `xs` is a tensor, then `fn` will
            get slices along the leading dimension (`xs[i]`). If `xs` is some
            other PyTree (e.g. a tuple of tensors), `fn` will get PyTrees of
            slices. In that case the leading dimension size of the leaves in the
            PyTree must be the same.

    Returns:
        (carry, ys): A tuple where `carry` is the last carry object returned by
        `fn`, and `ys` is a PyTree with the same structure as the `y` output of
        `fn`, but where the leaves are formed by stacking the leaf outputs of
        `fn` respectively. This means if your `fn` returns `(carry, (y1, y2))`
        then this function will return
        `(carry, (torch.stack(all_y1), torch.stack(all_y2)))`.

    [1]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
    """
    ...
```
```python
init = torch.tensor(0)
xs = torch.tensor([1, 2, 3])
fn = lambda carry, x: (carry + 1, x + carry)
scan(fn, init, xs)
# Returns:
# (3, [1, 3, 5])
```
```python
def scan_layers(layers: Iterable[torch.nn.Module], input_data: torch.Tensor):
    """Applies each layer in `layers` to `input_data` sequentially.

    `input_data` is provided as input to the first layer in `layers`. The output
    of one layer is provided as input to the next layer.

    This function is equivalent to

        sequential = torch.nn.Sequential(*layers)
        sequential(input_data)

    This function can be faster to compile since it reuses the XLA computation
    of the first layer to perform the computation of all other layers.
    """
    ...
```
## Alternatives
An alternative implementation of `scan_layers` is to combine the weights and biases of different layers into an XLA `Tuple`, as opposed to stacking them into a `Tensor`. The body computation of the `While` op would index into the `Tuple` using `get_tuple_element`, as opposed to indexing into the `Tensor` using `dynamic_slice`. This implies that `scan_layers` would not use `scan`, which is designed around indexing into tensors.
I have not prototyped this approach, as I'm not sure whether XLA supports nested tuples of distinctly shaped tensors. One may suspect the `Tuple` has better performance than the `Tensor`, since it avoids stacking tensors into a larger tensor, but given that XLA backends aggressively optimize memory layouts, I'm not sure how much this would pay off.
One advantage of the main proposal is that we'll expose a familiar `scan` operator that has near-identical semantics to the scan operator found in JAX, lowering the learning barrier.
In any case, these are things we can optimize in future versions of PyTorch/XLA without changing the signature of `scan_layers`.
In the first iteration of `torch_xla.experimental.scan`, we're using AOTAutograd to figure out the backward of the combine function, in order to implement the backward of the `scan` function. This is the easiest way to make this functionality available early, in experimental status. It has some drawbacks:
- The combine function cannot use higher-order ops (e.g. `torch.utils.checkpoint`).
- AOTAutograd is not intended by PyTorch for external libraries to use.
Next, we should focus on general availability. The constraint is that we should keep user code changes to a minimum when people adopt `scan`. In particular, we should support functions that use `torch.utils.checkpoint`, as opposed to AOT rematerialization.
It's likely that we'll end up supporting this by lowering the upstream PyTorch `torch.scan` higher-order operator inside a Dynamo backend. Dynamo will take care of decomposing `torch.utils.checkpoint` automatically and lifting global tensor references into function arguments, so external in-place mutations are transformed into internal in-place mutations.
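For example (assumed user code; `expensive_block` is a hypothetical submodule), we eventually want combine functions like this to work under `scan`:

```python
import torch.utils.checkpoint

def fn(carry, x):
    # Recompute expensive_block's activations during the backward pass instead
    # of saving them, via gradient checkpointing.
    y = torch.utils.checkpoint.checkpoint(expensive_block, carry + x, use_reentrant=False)
    return y, y
```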
After this milestone, we can move `scan` out of experimental.