
Parallelizing linalg.eig using sharding fails #26379

Open
mfschubert opened this issue Feb 7, 2025 · 4 comments
Labels: bug (Something isn't working)

@mfschubert
Description

I am trying to perform nonsymmetric eigendecomposition using linalg.eig on several matrices in parallel, since linalg.eig only uses a single CPU core. I don't get the expected acceleration, and in fact the sharding of the input array is not carried over to the output. It seems that despite the sharding, only a single core is being used.

Here is the code I am using to attempt parallel eigendecomposition.

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import jax
import jax.numpy as jnp

print(f"cpu_count={len(jax.devices('cpu'))}")
m = jax.random.normal(jax.random.PRNGKey(0), (8, 256, 256))

mesh = jax.make_mesh((8,), ("batch",))
pspec = jax.sharding.PartitionSpec("batch")
sharding = jax.sharding.NamedSharding(mesh, pspec)
mp = jax.device_put(m, sharding)

jit_eigvec = jax.jit(lambda x: jnp.linalg.eig(x)[1])
eigvec = jit_eigvec(mp).block_until_ready()
%timeit jax.block_until_ready(jit_eigvec(mp))

# cpu_count=8
# 197 ms ± 1.94 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

I visualized the input and output shardings as follows:

jax.debug.visualize_array_sharding(mp[:, 0, 0])
jax.debug.visualize_array_sharding(eigvec[:, 0, 0])

which generates the following visualization for the input matrix and output eigenvectors.

[Image: sharding visualization of the input matrix]

[Image: sharding visualization of the output eigenvectors]
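
The same mismatch can also be checked programmatically; a minimal sketch (not part of the original report), assuming mp and eigvec from the snippet above:

# Sketch: inspect the shardings directly instead of visualizing them.
print(mp.sharding)      # NamedSharding over the "batch" axis
print(eigvec.sharding)  # per the report, the batch sharding is not preserved here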

And here is a baseline, no-sharding case for reference:

import jax
import jax.numpy as jnp

print(f"cpu_count={len(jax.devices('cpu'))}")
m = jax.random.normal(jax.random.PRNGKey(0), (8, 256, 256))

jit_eigvec = jax.jit(lambda x: jnp.linalg.eig(x)[1])
jit_eigvec(m).block_until_ready()
%timeit jax.block_until_ready(jit_eigvec(m))

# cpu_count=1
# 166 ms ± 1.39 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.35
numpy:  1.26.0
python: 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:41:52) [Clang 15.0.7 ]
device info: cpu-8, 8 local devices
process_count: 1
platform: uname_result(system='Darwin', node='Mac.attlocal.net', release='24.2.0', version='Darwin Kernel Version 24.2.0: Fri Dec  6 18:56:34 PST 2024; root:xnu-11215.61.5~2/RELEASE_ARM64_T6020', machine='arm64')
mfschubert added the bug (Something isn't working) label on Feb 7, 2025
@dfm (Collaborator) commented on Feb 7, 2025

Yeah, this is actually a known limitation of the current implementation of all linalg functions - none of them partition properly! I've been working on fixing that, but in the meantime, the recommended approach is to use shard_map:

from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map

# Reuses the `mesh` and `pspec` defined in the snippet above.
@partial(shard_map, mesh=mesh, in_specs=pspec, out_specs=pspec, check_rep=False)
def fun(x):
  return jnp.linalg.eig(x)[1]

This should give you the performance you expect.
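
For reference, a minimal usage sketch, assuming `mesh`, `pspec`, and the sharded array `mp` from the original report are in scope:

# Usage sketch: apply the shard_map-wrapped function to the sharded array.
jit_fun = jax.jit(fun)
eigvec = jax.block_until_ready(jit_fun(mp))
print(eigvec.sharding)  # with shard_map, this should be the NamedSharding over "batch"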

dfm self-assigned this on Feb 7, 2025
@mfschubert (Author)

Great, thank you!

@mfschubert (Author)

By the way, there seems to be some unexpected overhead with the implementation you shared, at least when compared to using torch. Here is my code comparing the two:

import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=56"

import jax
import jax.numpy as jnp
import numpy as onp
import torch
from functools import partial
from jax.experimental.shard_map import shard_map

torch.set_num_threads(1)
torch.set_num_interop_threads(56)

print(jax.__version__)
print(f"cpu_count={len(jax.devices('cpu'))}")

def test_jax(batch_size, dim=128):
    m = jax.random.normal(jax.random.PRNGKey(0), (batch_size, dim, dim))
    mesh = jax.make_mesh((batch_size,), ("batch",))
    pspec = jax.sharding.PartitionSpec("batch")
    sharding = jax.sharding.NamedSharding(mesh, pspec)
    mp = jax.device_put(m, sharding)

    @jax.jit
    @partial(shard_map, mesh=mesh, in_specs=pspec, out_specs=pspec, check_rep=False)
    def fun(x):
        return jnp.linalg.eig(x)[1]

    eigvec = jax.block_until_ready(fun(mp))
    %timeit jax.block_until_ready(fun(mp))


def test_torch(batch_size, dim=128):
    m = jax.random.normal(jax.random.PRNGKey(0), (batch_size, dim, dim))
    m = torch.as_tensor(onp.array(m))

    @torch.jit.script
    def _eig_torch_parallelized(x):
        futures = [torch.jit.fork(torch.linalg.eig, x[i]) for i in range(x.shape[0])]
        return [torch.jit.wait(fut) for fut in futures]
    
    %timeit _eig_torch_parallelized(m)

print("\ntesting jax eig")
test_jax(batch_size=1)
test_jax(batch_size=7)
test_jax(batch_size=14)
test_jax(batch_size=28)
test_jax(batch_size=56)

print("\ntesting torch eig")
test_torch(batch_size=1)
test_torch(batch_size=7)
test_torch(batch_size=14)
test_torch(batch_size=28)
test_torch(batch_size=56)

Here are my timings:

0.4.35
cpu_count=56

testing jax eig
8.71 ms ± 881 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
23.1 ms ± 2.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
52.1 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
67.7 ms ± 3.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
138 ms ± 14.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

testing torch eig
13.7 ms ± 1.11 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
17 ms ± 1.99 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
16.2 ms ± 308 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.4 ms ± 151 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
11 ms ± 228 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@mfschubert (Author)

As a note, this seems to be machine-dependent. On my Mac it scales as expected, while on my Intel machine I get the timings above. However, the torch results suggest that this is genuine overhead rather than just machine-related scaling.
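
One rough way to separate per-call overhead from machine-dependent scaling: time a single unsharded eig of the same size and compare batch_size times that per-matrix baseline against the shard_map timings. A sketch (not part of the measurements above):

# Sketch: time one unsharded 128x128 eig to get a per-matrix baseline,
# then compare batch_size * baseline against the shard_map timings above.
single = jax.random.normal(jax.random.PRNGKey(0), (1, 128, 128))
jit_eig = jax.jit(lambda x: jnp.linalg.eig(x)[1])
jax.block_until_ready(jit_eig(single))  # warm up / compile
%timeit jax.block_until_ready(jit_eig(single))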
