Parallelizing linalg.eig using sharding fails #26379
Comments
Yeah, this is actually a known limitation of the current implementation of all linalg functions - none of them partition properly! I've been working on fixing that, but in the meantime, the recommended approach is to use `shard_map`:

```python
from functools import partial
from jax.experimental.shard_map import shard_map

@partial(shard_map, mesh=mesh, in_specs=pspec, out_specs=pspec, check_rep=False)
def fun(x):
    return jnp.linalg.eig(x)[1]
```

This should have the performance you expect.
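For completeness, the `mesh` and `pspec` in that snippet still need to be defined; a minimal end-to-end setup might look like the following (the device count and matrix shape here are illustrative assumptions, not part of the original suggestion):

```python
import jax
import jax.numpy as jnp
from functools import partial
from jax.experimental.shard_map import shard_map

# Shard the leading batch axis across all available devices.
n_dev = len(jax.devices())
mesh = jax.make_mesh((n_dev,), ("batch",))
pspec = jax.sharding.PartitionSpec("batch")
sharding = jax.sharding.NamedSharding(mesh, pspec)

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=pspec, out_specs=pspec, check_rep=False)
def fun(x):
    # Inside shard_map, x is the per-device block; eig runs independently on each block.
    return jnp.linalg.eig(x)[1]

m = jax.random.normal(jax.random.PRNGKey(0), (n_dev, 128, 128))
eigvecs = jax.block_until_ready(fun(jax.device_put(m, sharding)))
```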
Great, thank you!
By the way, there seems to be some unexpected overhead with the implementation you shared, at least when compared to using torch. Here is my code comparing the two:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=56"

import jax
import jax.numpy as jnp
import numpy as onp
import torch
from functools import partial
from jax.experimental.shard_map import shard_map

torch.set_num_threads(1)
torch.set_num_interop_threads(56)

print(jax.__version__)
print(f"cpu_count={len(jax.devices('cpu'))}")

def test_jax(batch_size, dim=128):
    m = jax.random.normal(jax.random.PRNGKey(0), (batch_size, dim, dim))
    mesh = jax.make_mesh((batch_size,), ("batch",))
    pspec = jax.sharding.PartitionSpec("batch")
    sharding = jax.sharding.NamedSharding(mesh, pspec)
    mp = jax.device_put(m, sharding)

    @jax.jit
    @partial(shard_map, mesh=mesh, in_specs=pspec, out_specs=pspec, check_rep=False)
    def fun(x):
        return jnp.linalg.eig(x)[1]

    eigvec = jax.block_until_ready(fun(mp))
    %timeit jax.block_until_ready(fun(mp))

def test_torch(batch_size, dim=128):
    m = jax.random.normal(jax.random.PRNGKey(0), (batch_size, dim, dim))
    m = torch.as_tensor(onp.array(m))

    @torch.jit.script
    def _eig_torch_parallelized(x):
        futures = [torch.jit.fork(torch.linalg.eig, x[i]) for i in range(x.shape[0])]
        return [torch.jit.wait(fut) for fut in futures]

    %timeit _eig_torch_parallelized(m)

print("\ntesting jax eig")
test_jax(batch_size=1)
test_jax(batch_size=7)
test_jax(batch_size=14)
test_jax(batch_size=28)
test_jax(batch_size=56)

print("\ntesting torch eig")
test_torch(batch_size=1)
test_torch(batch_size=7)
test_torch(batch_size=14)
test_torch(batch_size=28)
test_torch(batch_size=56)
```

Here are my timings:
As a note, this seems to be machine-dependent. On my Mac it scales as expected, while on my Intel machine I get the timings above. However, it does appear to be overhead (as opposed to just machine-related scaling), given the results I am getting with torch.
Description
I am trying to perform nonsymmetric eigendecomposition using linalg.eig on several matrices in parallel, since linalg.eig only uses a single CPU core. I don't get the expected acceleration, and in fact the sharding of the input array is not carried over to the output. It seems that despite the sharding, only a single core is being used.

Here is the code I am using to attempt parallel eigendecomposition.
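A minimal sketch of what such a direct sharded-`eig` attempt looks like (batch size and matrix dimension are illustrative assumptions, not the original values):

```python
import jax
import jax.numpy as jnp

batch_size, dim = 8, 128  # illustrative sizes
m = jax.random.normal(jax.random.PRNGKey(0), (batch_size, dim, dim))

# Shard the batch axis across devices and call eig directly, without shard_map.
mesh = jax.make_mesh((batch_size,), ("batch",))
pspec = jax.sharding.PartitionSpec("batch")
sharding = jax.sharding.NamedSharding(mesh, pspec)
mp = jax.device_put(m, sharding)

eigvecs = jax.jit(lambda x: jnp.linalg.eig(x)[1])(mp)
```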
I visualized the input and output shardings as follows:
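Presumably something along these lines, using `jax.debug.visualize_array_sharding` (which only accepts 1D or 2D arrays, hence the 2D slices; the exact calls are an assumption):

```python
# Inspect how the batch of matrices and the eigenvector output are laid out across devices.
jax.debug.visualize_array_sharding(mp[:, :, 0])       # input matrices
jax.debug.visualize_array_sharding(eigvecs[:, :, 0])  # output eigenvectors
```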
This generates the following visualization for the input matrix and the output eigenvectors.
And here is a baseline, no-sharding case for reference:
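A minimal no-sharding baseline of this kind (sizes again illustrative):

```python
# Baseline: batched eig with no sharding, running on a single device.
m = jax.random.normal(jax.random.PRNGKey(0), (8, 128, 128))  # illustrative sizes
eig_fn = jax.jit(lambda x: jnp.linalg.eig(x)[1])
eigvecs_baseline = jax.block_until_ready(eig_fn(m))
```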
System info (python version, jaxlib version, accelerator, etc.)