insane shmap autodiff bug + shit + efficient transpose + eager shmap #26223

Open · wants to merge 2 commits into base: main
4 changes: 2 additions & 2 deletions jax/_src/config.py
@@ -1045,7 +1045,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:

 sharding_in_types = bool_state(
     name='jax_sharding_in_types',
-    default=False,
+    default=True,
     help=('When True, enables forward only sharding propagation in JAX and '
           'avals have sharding on them.'),
     include_in_jit_key=True)
@@ -1386,7 +1386,7 @@ def _update_disable_jit_thread_local(val):
     name = 'jax_traceback_filtering',
     enum_values=["off", "tracebackhide", "remove_frames", "quiet_remove_frames",
                  "auto"],
-    default="auto",
+    default="off",
     help="Controls how JAX filters internal frames out of tracebacks.\n\n"
          "Valid values are:\n"
          " * \"off\": disables traceback filtering.\n"
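The two flag flips above change library-wide defaults. For anyone trying the previous behavior locally, a minimal sketch using the public jax.config.update API (flag names taken from the definitions above; the sketch itself is not part of this PR):

import jax

# Sketch: restore the pre-PR defaults at runtime.
jax.config.update('jax_sharding_in_types', False)     # this PR flips the default to True
jax.config.update('jax_traceback_filtering', 'auto')  # this PR flips the default to "off"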
19 changes: 10 additions & 9 deletions jax/_src/core.py
@@ -382,6 +382,8 @@ class JaxprEqn:

   def __init__(self, invars, outvars, primitive, params, effects, source_info,
                ctx):
+    # if primitive.name == 'pjit' and not all(i.aval == o.aval for i, o in zip(params['jaxpr'].jaxpr.invars, invars)):
+    #   breakpoint()
     self.invars = invars
     self.outvars = outvars
     self.primitive = primitive
@@ -610,10 +612,10 @@ def check_avals_context_mesh(avals, prim_name):
   if config.sharding_in_types.value:
     cur_mesh = mesh_lib.get_abstract_mesh()
     for a in avals:
-      if a.sharding.mesh.empty or cur_mesh.empty:
-        continue
-      if a.sharding.mesh._are_all_axes_auto and cur_mesh._are_all_axes_auto:
-        continue
+      # if a.sharding.mesh.empty or cur_mesh.empty:
+      #   continue
+      # if a.sharding.mesh._are_all_axes_auto and cur_mesh._are_all_axes_auto:
+      #   continue
       if a.sharding.mesh != cur_mesh:
         raise ValueError(
             f"For primitive {prim_name}, context mesh {cur_mesh} should match"
@@ -1814,11 +1816,10 @@ def get_sharding(sharding, ndim):
"Length of sharding.spec must be equal to aval's ndim. Got"
f" sharding.spec {out_s.spec} and aval.ndim {ndim}")
   else:
-    cur_mesh = mesh_lib.get_abstract_mesh()
-    if cur_mesh.empty:
-      raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
+    # if cur_mesh.empty:
+    #   raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
     assert sharding is None
-    out_s = NamedSharding(cur_mesh, P(*[None] * ndim))
+    out_s = NamedSharding(mesh_lib.empty_abstract_mesh, P(*[None] * ndim))
   if not isinstance(out_s.mesh, mesh_lib.AbstractMesh):
     raise ValueError("Mesh of an aval must be an AbstractMesh. "
                      f"Got {out_s.mesh} of type {type(out_s.mesh)}")
@@ -1884,7 +1885,7 @@ def str_short(self, short_dtypes=False):
     dt_str = dt_str.replace('void', 'float0')
     if hasattr(self, 'sharding') and self.sharding is not None:
       shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)  # type: ignore
-      return f'{dt_str}[{shapestr}]'
+      return f'{dt_str}[{shapestr}]({self.sharding.mesh.axis_types})'
     else:
       shapestr = ','.join(map(str, self.shape))
       return f'{dt_str}[{shapestr}]'
12 changes: 6 additions & 6 deletions jax/_src/interpreters/ad.py
@@ -545,11 +545,13 @@ class JVPTracer(Tracer):
   __slots__ = ['primal', 'tangent']
 
   def __init__(self, trace, primal, tangent):
-    if config.enable_checks.value:
-      _primal_tangent_shapes_match(primal, tangent)
+    # if config.enable_checks.value:
+    _primal_tangent_shapes_match(primal, tangent)
     self._trace = trace
     self.primal = primal
     self.tangent = tangent
+    # if not isinstance(self.primal, Tracer) and self.primal is 0.2:
+    #   breakpoint()
 
   @property
   def aval(self):
@@ -569,11 +571,9 @@ def get_referent(self):

 def _primal_tangent_shapes_match(primal, tangent):
   if type(tangent) is not Zero:
-    primal_aval = get_aval(primal).strip_weak_type()
+    expected_tangent_aval = get_aval(primal).strip_weak_type().to_tangent_aval()
     tangent_aval = get_aval(tangent).strip_weak_type()
-    assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape)
-    expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
-    assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
+    assert tangent_aval == expected_tangent_aval, breakpoint()
 
 call_param_updaters: dict[core.Primitive, Callable] = {}
 call_linearize_param_updaters: dict[core.Primitive, Callable] = {}
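The tightened check above compares the full tangent aval (via to_tangent_aval) rather than only shape and dtype, so sharding mismatches between primal and tangent now trip the assertion too. As a reminder of the invariant it guards, a minimal sketch of the primal/tangent pairing in user code (standard jax.jvp usage, not part of this diff):

import jax
import jax.numpy as jnp

x = jnp.ones((4,), dtype=jnp.float32)       # primal
t = jnp.full((4,), 0.5, dtype=jnp.float32)  # tangent: same shape/dtype (and sharding) as the primal
y, y_dot = jax.jvp(jnp.sin, (x,), (t,))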
5 changes: 5 additions & 0 deletions jax/_src/interpreters/partial_eval.py
@@ -727,6 +727,8 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
       config.threefry_partitionable.value,
       xla_metadata_lib.current_xla_metadata(),
   )
+  # if primitive.name == 'pjit' and not all(i.aval == o.aval for i, o in zip(params['jaxpr'].jaxpr.invars, in_tracers)):
+  #   breakpoint()
   return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers),
                         out_avals, primitive, params, effects, source_info,
                         ctx)
@@ -2522,6 +2524,9 @@ def inline_jaxpr_into_trace(
   const_tracers = map(trace.new_const, consts)
   constvars = map(trace.getvar, const_tracers)
   argvars = map(trace.getvar, arg_tracers)
+  # if config.enable_checks.value:
+  #   assert all(arg_v.aval == j_v.aval
+  #              for arg_v, j_v in zip(argvars, jaxpr.invars)), breakpoint()
   env: dict[Var, Var] = dict(zip([*jaxpr.constvars, *jaxpr.invars],
                                  [*constvars, *argvars]))
 
2 changes: 2 additions & 0 deletions jax/_src/lax/lax.py
@@ -4992,6 +4992,8 @@ def _split_on_one_axis(op_shape, new_sizes, name):
     return False, []
   i, j, count, out = 0, 0, 0, []
   while j < len(new_sizes):
+    try: op_shape[i] == new_sizes[j]
+    except IndexError: return False, []
     if op_shape[i] == new_sizes[j]:
       out.append(op_shape[i])
     else:
7 changes: 5 additions & 2 deletions jax/_src/lax/utils.py
@@ -24,6 +24,7 @@
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src.util import safe_zip
+from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
 
 zip, unsafe_zip = safe_zip, zip

@@ -51,7 +52,8 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
   if config.sharding_in_types.value:
     cur_mesh = mesh_lib.get_abstract_mesh()
     if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:
-      return None if num_out is None else [None] * num_out
+      s = NamedSharding(cur_mesh, P())
+      return s if num_out is None else [s] * num_out
     if rule is None:
       raise ValueError(
           f'sharding rule for {prim.name} is not implemented. Please file a'
@@ -68,12 +70,13 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
   weak_type = weak_type_rule(*avals, **kwargs)
   least_specialized = type(max(avals, key=_get_array_abstraction_level))
   if least_specialized is core.ShapedArray:
-    avals = core.cast_from_auto_to_manual(avals)
+    # avals = core.cast_from_auto_to_manual(avals)
     core.check_avals_context_mesh(avals, prim.name)
     out_aval = core.ShapedArray(
         shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
         weak_type=weak_type,
         sharding=call_sharding_rule(prim, sharding_rule, None, *avals, **kwargs))
+    if str(prim) == 'mul': breakpoint()
     core.check_avals_context_mesh([out_aval], prim.name)
     return out_aval
   elif least_specialized is core.DShapedArray:
13 changes: 11 additions & 2 deletions jax/_src/pjit.py
@@ -2078,6 +2078,8 @@ def _pjit_jvp(primals_in, tangents_in,
               jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
               resource_env, donated_invars, name, keep_unused, inline,
               compiler_options_kvs):
+  # if not all(core.get_aval(p).to_tangent_aval() == t.aval == a for p, t, a in zip(primals_in, tangents_in, jaxpr.in_avals)):
+  #   breakpoint()
   if any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
     jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr)
     mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals)
@@ -2115,8 +2117,13 @@ def _filter_zeros(is_nz_l, l):
   primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
   assert len(primals_out) == len(jaxpr.jaxpr.outvars)
   tangents_out_it = iter(tangents_out)
-  return primals_out, [next(tangents_out_it) if nz else ad.Zero(aval)
-                       for nz, aval in zip(is_nz_tangents_out, jaxpr.out_avals)]
+  tangents_out = [next(tangents_out_it) if nz else ad.Zero(aval)
+                  for nz, aval in zip(is_nz_tangents_out, jaxpr.out_avals)]
+  for p, t in zip(primals_out, tangents_out):
+    expected_tangent_aval = core.get_aval(p).strip_weak_type().to_tangent_aval()
+    tangent_aval = core.get_aval(t).strip_weak_type()
+    assert tangent_aval == expected_tangent_aval, breakpoint()
+  return primals_out, tangents_out
 ad.primitive_jvps[pjit_p] = _pjit_jvp


@@ -2175,6 +2182,8 @@ def _pjit_partial_eval(trace, *in_tracers,
                        jaxpr, in_shardings, out_shardings,
                        in_layouts, out_layouts, resource_env, donated_invars,
                        name, keep_unused, inline, compiler_options_kvs):
+  # if not all(i.aval == o.aval for i, o in zip(jaxpr.jaxpr.invars, in_tracers)):
+  #   breakpoint()
   in_pvals = [t.pval for t in in_tracers]
 
   known_ins = tuple(pv.is_known() for pv in in_pvals)
10 changes: 10 additions & 0 deletions jax/experimental/shard_map.py
@@ -1026,14 +1026,24 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics):

 def _psum2_transpose_rule(cts, *args, axes, axis_index_groups):
   del args
+  # cts = [cast_if_necessary(x) for x in cts]
   return pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)
 ad.deflinear2(psum2_p, _psum2_transpose_rule)
 
+# def cast_if_necessary(x):
+#   aval = core.get_aval(x)
+#   cur_mesh = get_abstract_mesh()
+#   print('#here', cur_mesh, aval)
+#   if cur_mesh._are_all_axes_manual and aval.sharding.mesh._are_all_axes_auto:
+#     return pjit.mesh_cast(x, NamedSharding(cur_mesh, P(*[None] * aval.ndim)))
+#   return x
 
 # pbroadcast_p is exactly the transpose of psum2_p
 def pbroadcast(x, axis_name):
   axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
   if not axis_name: return x
   xs, treedef = tree_flatten(x)
+  # xs = [cast_if_necessary(x) for x in xs]
   ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None)
   return tree_unflatten(treedef, ys)
 pbroadcast_p = core.Primitive('pbroadcast')
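For context on why pbroadcast_p is registered as the transpose of psum2_p: reverse-mode differentiation of a psum inside shard_map broadcasts the cotangent back to every shard. A minimal sketch, assuming a machine with at least two devices (the mesh and the axis name 'i' are illustrative, not part of this diff):

from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices()[:2], ('i',))  # assumes >= 2 devices are available

@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
def f(x):
  return jax.lax.psum(x, 'i')  # forward: sum over the 'i' mesh axis

x = jnp.arange(4.)
# Transposing psum yields a pbroadcast of the cotangent, so the gradient of
# sum(f(x)) is simply ones, replicated across the shards of x.
print(jax.grad(lambda x: f(x).sum())(x))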
27 changes: 19 additions & 8 deletions tests/shard_map_test.py
@@ -43,6 +43,7 @@
 from jax._src.interpreters import partial_eval as pe
 from jax._src import linear_util as lu
 from jax._src import tree_util
+from jax._src import pjit
 import jax.numpy as jnp
 
 from jax.experimental.custom_partitioning import custom_partitioning
@@ -2762,24 +2763,34 @@ def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep):
       f = jax.jit(f)
     jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)
 
-  @parameterized.named_parameters(
-      sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
+  # @parameterized.named_parameters(
+  #     sample(jtu.NUM_GENERATED_CASES.value, sample_shmap))
   @jax.default_matmul_precision("float32")
-  def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
-    mesh = self.make_mesh(mesh)
+  def test_grads_closure(self):
+    mesh = jtu.create_mesh((1, 1), ('i', 'j'))
+    fun = jnp.dot
+    in_specs = (P(None, None), P(None, None, ('i',)))
+    out_specs = P(('i',), None, None)
+    args = [np.arange(math.prod(s), dtype=np.float32).reshape(s)
+            for s in [(2, 3), (2, 3, 1)]]
+
     no_sharding = [all(elt is None for elt in spec) for spec in in_specs]
     args, closed_over_args = partition_list(no_sharding, args)
     in_specs, _ = partition_list(no_sharding, in_specs)
+
+    def mesh_cast(x):
+      from jax._src import mesh as mesh_lib
+      mesh = mesh_lib.get_abstract_mesh()
+      return pjit.mesh_cast(x, NamedSharding(mesh, P()))
+
     def f(x, *closed_over_args):
       @partial(shard_map, mesh=mesh, in_specs=(*in_specs,), out_specs=out_specs)
       def g(*args):
-        args = [x * arg for arg in args]
+        args = [mesh_cast(x) * arg for arg in args]
         args = merge_lists(no_sharding, args, closed_over_args)
         return fun(*args)
-      if jit:
-        g = jax.jit(g)
       return g(*args)
-    jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)
+    jtu.check_grads(f, (0.2, *closed_over_args), modes=('rev',), order=2, atol=1e-2, rtol=1e-2)

@parameterized.named_parameters(
sample(jtu.NUM_GENERATED_CASES.value,