Skip to content

Commit

Permalink
Fix some busted batching rules in lax.linalg.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726543703
  • Loading branch information
dfm authored and Google-ML-Automation committed Feb 13, 2025
1 parent 7f99929 commit ea4e324
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1976,7 +1976,7 @@ def _householder_product_batching_rule(batched_args, batch_dims):
a, taus = batched_args
b_a, b_taus, = batch_dims
return householder_product(batching.moveaxis(a, b_a, 0),
batching.moveaxis(taus, b_taus, 0)), (0,)
batching.moveaxis(taus, b_taus, 0)), 0

def _householder_product_lowering_rule(ctx, a, taus):
aval_out, = ctx.avals_out
Expand Down Expand Up @@ -2865,7 +2865,7 @@ def _hessenberg_batching_rule(batched_args, batch_dims):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return hessenberg(x), 0
return hessenberg(x), (0, 0)

batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule

Expand Down Expand Up @@ -2961,7 +2961,7 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
return tridiagonal(x, lower=lower), 0
return tridiagonal_p.bind(x, lower=lower), (0, 0, 0, 0, 0)

batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule

Expand Down
8 changes: 8 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,6 +1766,10 @@ def testHessenberg(self, shape, dtype, calc_q):
check_dtypes=not calc_q)
self._CompileAndCheck(jsp_func, args_maker)

if len(shape) == 3:
args = args_maker()
self.assertAllClose(jax.vmap(jsp_func)(*args), jsp_func(*args))

@jtu.sample_product(
shape=[(1, 1), (2, 2, 2), (4, 4), (10, 10), (2, 5, 5)],
dtype=float_types + complex_types,
Expand Down Expand Up @@ -1798,6 +1802,10 @@ def sp_func(a):
self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-4, atol=1e-4,
check_dtypes=False)

if len(shape) == 3:
args = args_maker()
self.assertAllClose(jax.vmap(jax_func)(*args), jax_func(*args))

@jtu.sample_product(
n=[1, 4, 5, 20, 50, 100],
dtype=float_types + complex_types,
Expand Down

0 comments on commit ea4e324

Please sign in to comment.