
Commit

Merge pull request #26518 from superbobry:maint-2
PiperOrigin-RevId: 726663977
Google-ML-Automation committed Feb 13, 2025
2 parents f0cd168 + a73456d commit 60dcded
Showing 28 changed files with 87 additions and 87 deletions.
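The pattern across these files is consistent: stale `# type: ignore` comments are dropped outright, and the suppressions that are still needed are narrowed to pytype-specific directives such as `# pytype: disable=bad-return-type`. The sketch below is not taken from the diff; it is a minimal, hypothetical illustration of how the two comment styles differ in scope.

```python
def parse_port(raw: str) -> int:
    # Blanket suppression: hides every diagnostic a checker would report on
    # this line (mypy, pytype, and other checkers all honor `# type: ignore`).
    return raw  # type: ignore

def parse_port_pytype(raw: str) -> int:
    # Targeted pytype suppression: only the named error class is silenced,
    # so an unrelated error introduced later on this line still surfaces.
    return raw  # pytype: disable=bad-return-type

def parse_port_mypy(raw: str) -> int:
    # mypy's equivalent narrowing uses a bracketed error code instead.
    return raw  # type: ignore[return-value]
```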
2 changes: 1 addition & 1 deletion jax/_src/api.py
@@ -2357,7 +2357,7 @@ def device_put(
assert not m and not d
copy_semantics.append(dispatch.CopySemantics.COPY)

for xf, d in zip(x_flat, device_flat): # type: ignore
for xf, d in zip(x_flat, device_flat):
_check_sharding(shaped_abstractify(xf), d)
out_flat = dispatch.device_put_p.bind(
*x_flat, devices=device_flat, srcs=src_flat,
6 changes: 3 additions & 3 deletions jax/_src/array.py
@@ -209,7 +209,7 @@ def __init__(self, aval: core.ShapedArray, sharding: Sharding,
# (like pjit, etc).
if not _skip_checks or config.enable_checks.value:
arrays = self._check_and_rearrange(arrays, self._sharding, self.aval)
self._arrays = arrays # type: ignore
self._arrays = arrays

if xla_extension_version >= 310:
def _check_and_rearrange(self, arrays, sharding, aval):
@@ -654,7 +654,7 @@ def block_until_ready(self):
if xla_extension_version >= 314:
@use_cpp_method()
def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: # type: ignore
... # type: ignore
... # pytype: disable=bad-return-type

else:
@use_cpp_method()
@@ -782,7 +782,7 @@ def make_array_from_callback(
raise TypeError(
"`DeviceLocalLayout.AUTO` cannot be used in place of a device-local"
f" layout when calling `jax.make_array_from_callback`. Got {sharding}")
sharding = sharding.sharding if isinstance(sharding, Layout) else sharding # type: ignore
sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
if not isinstance(sharding, Sharding):
raise TypeError(
f"sharding should be an instance of `jax.sharding`. Got {sharding} of"
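The `bad-return-type` directive added in this file sits on a method whose Python body is just `...` because a decorator installs the real (C++) implementation; pytype sees a body that implicitly returns None and flags the annotated return type. A minimal sketch of that shape, with a hypothetical stand-in for `use_cpp_method`:

```python
from typing import Callable, TypeVar

F = TypeVar("F", bound=Callable)

def native_method(fn: F) -> F:
    """Hypothetical stand-in for a decorator that swaps in a native method."""
    # In the real code the Python body below is never executed; here we simply
    # return the function unchanged so the example runs on its own.
    return fn

class Buffer:
    @native_method
    def to_list_did_copy(self) -> tuple[list[float], bool]:
        # A bare `...` body implicitly returns None, which conflicts with the
        # annotation, hence the targeted suppression.
        ...  # pytype: disable=bad-return-type
```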
2 changes: 1 addition & 1 deletion jax/_src/checkify.py
@@ -260,7 +260,7 @@ def _get_batched_exception(self) -> BatchedError | None:
cur_effect = None
for error_effect, code in self._code.items():
if self._pred[error_effect][idx]: # type: ignore
if min_code is None or code[idx] < min_code: # type: ignore
if min_code is None or code[idx] < min_code:
min_code = code[idx] # type: ignore
cur_effect = error_effect

2 changes: 1 addition & 1 deletion jax/_src/cloud_tpu_init.py
@@ -105,7 +105,7 @@ def cloud_tpu_init() -> None:
def is_cloud_tpu_older_than(year: int, month: int, day: int):
# We import locally because the functions above must run before the runtime
# modules are imported.
from jax._src import xla_bridge # type: ignore
from jax._src import xla_bridge # pytype: disable=import-error
date = datetime.date(year, month, day)
if not running_in_cloud_tpu_vm:
return False
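The `import-error` directive above covers a deliberately deferred import: the module must not be pulled in at file-import time, and pytype cannot always resolve such local imports when analyzing the file in isolation. A hedged sketch with a made-up module name:

```python
import datetime

def is_runtime_older_than(year: int, month: int, day: int) -> bool:
    # Deferred import: `buildinfo` is a hypothetical module that is only safe
    # to import after initialization; the directive keeps pytype from
    # reporting it as unresolvable.
    from buildinfo import build_date  # pytype: disable=import-error
    return build_date() < datetime.date(year, month, day)
```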
12 changes: 6 additions & 6 deletions jax/_src/core.py
@@ -1767,11 +1767,11 @@ def canonicalize_value(val):
return val

cur_mesh = mesh_lib.get_abstract_mesh()
if cur_mesh == aval.sharding.mesh: # type: ignore
if cur_mesh == aval.sharding.mesh:
return val
if cur_mesh._are_all_axes_manual and aval.sharding.mesh._are_all_axes_auto: # type: ignore
from jax._src.pjit import mesh_cast # type: ignore
return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) # type: ignore
if cur_mesh._are_all_axes_manual and aval.sharding.mesh._are_all_axes_auto:
from jax._src.pjit import mesh_cast # pytype: disable=import-error
return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim)))
return val


@@ -1785,7 +1785,7 @@ def get_cur_mesh_sharding(spec=None):
# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
# Collective too.
def modify_spec_for_auto_manual(spec, mesh) -> P:
new_spec = [] # type: ignore
new_spec = []
for s in spec:
if not s:
new_spec.append(s)
@@ -1888,7 +1888,7 @@ def str_short(self, short_dtypes=False):
self.dtype.name)
dt_str = dt_str.replace('void', 'float0')
if hasattr(self, 'sharding') and self.sharding is not None:
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) # type: ignore
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
return f'{dt_str}[{shapestr}]'
else:
shapestr = ','.join(map(str, self.shape))
6 changes: 3 additions & 3 deletions jax/_src/interpreters/ad.py
@@ -180,7 +180,7 @@ def new_arg(trace, primal_aval, nz):
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
residuals_and_primals = (*tangent_consts, *out_primals)
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) # type: ignore[assignment]
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals)
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info)
primal_trace.invalidate()
num_residuals = len(tangent_consts)
@@ -213,7 +213,7 @@ def direct_linearize(traceable: lu.WrappedFun,
del linearize_trace, ans, tracers
out_nzs = [type(t) is not Zero for t in out_tangents]
out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz]
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) # type: ignore
out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info)
tangent_trace.invalidate()
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else
@@ -343,7 +343,7 @@ def write_primal(v, val):
ct_out = core.freeze(ref)
write_cotangent(eqn.primitive, val_var, ct_out)
elif eqn.primitive is core.freeze_p:
val_var, = eqn.outvars # type: ignore
val_var, = eqn.outvars
ref_var, = eqn.invars # type: ignore
ct_in = instantiate_zeros(read_cotangent(val_var))
write_primal(ref_var, core.mutable_array(ct_in))
8 changes: 4 additions & 4 deletions jax/_src/interpreters/partial_eval.py
@@ -1690,7 +1690,7 @@ def to_jaxpr(self, trace: DynamicJaxprTrace,
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects,
debug_info)
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore
jaxpr, constvals = _inline_literals(jaxpr, constvals)
init_trees = [tree_structure(init_val) for init_val in self.attrs_inits]
set_states(self.attrs_tracked, self.attrs_inits)
return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)
@@ -1707,7 +1707,7 @@ def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
jaxpr_effects, debug_info)
# We can't run check_jaxpr until after we normalize.
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore[assignment]
jaxpr, constvals = _inline_literals(jaxpr, constvals)
jaxpr, out_type = _add_implicit_outputs(jaxpr)
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, out_type, constvals
@@ -2065,7 +2065,7 @@ def fwd_jaxpr_from_zeros(*zeros):
constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts))
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars,
prim.initial_style, # type: ignore[attribute-error]
prim.initial_style, # pytype: disable=attribute-error
dict(fun_jaxpr=closed_fun_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_from_zeros,
num_consts=len(consts),
@@ -2201,7 +2201,7 @@ def _check_no_returned_refs(
origin_info = ('\n\nThe returned mutable array was created on line '
f'{source_info_util.summarize(eqn.source_info)}.')
elif v in frame.invars:
arg_name = dbg.safe_arg_names(len(frame.invars))[frame.invars.index(v)] # type: ignore
arg_name = dbg.safe_arg_names(len(frame.invars))[frame.invars.index(v)]
origin_info = ('\n\nThe returned mutable array was passed in as the '
f'argument {arg_name}.')
else:
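The `attribute-error` directive above (on `prim.initial_style`) covers an attribute that exists only on the particular primitive subclasses reaching that code path. A small hypothetical sketch of the same situation:

```python
class Primitive:
    def __init__(self, name: str):
        self.name = name

class CustomVJPPrimitive(Primitive):
    initial_style = True

def is_initial_style(prim: Primitive) -> bool:
    # Statically, `Primitive` declares no `initial_style`, so pytype reports
    # attribute-error; the callers only ever pass subclasses that define it.
    return prim.initial_style  # pytype: disable=attribute-error
```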
16 changes: 8 additions & 8 deletions jax/_src/interpreters/pxla.py
@@ -2275,7 +2275,7 @@ def lower_sharding_computation(
# I refactor, this will also work well with mesh being provided at
# compile time.
# Sets device_assignment to None if only abstractMesh and unspecified exists.
num_devices, device_assignment = _get_num_devices( # type: ignore
num_devices, device_assignment = _get_num_devices(
it.chain(unique_in_shardings, unique_out_shardings,
unique_intermediate_shardings),
device_assignment)
@@ -2330,12 +2330,12 @@ def lower_sharding_computation(
"mesh should be the same across the entire program. Got mesh"
f" shape for one sharding {abstract_mesh} and"
f" {sharding.mesh.abstract_mesh} for another")
abstract_mesh = sharding.mesh.abstract_mesh # type: ignore
abstract_mesh = sharding.mesh.abstract_mesh

semantic_in_shardings = SemanticallyEqualShardings(
in_shardings, global_in_avals) # type: ignore
in_shardings, global_in_avals)
semantic_out_shardings = SemanticallyEqualShardings(
out_shardings, global_out_avals) # type: ignore
out_shardings, global_out_avals)

(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
@@ -2540,7 +2540,7 @@ def _gspmd_to_named_sharding(
assert isinstance(orig_in_s, NamedSharding)
assert isinstance(orig_in_s.mesh, Mesh)
return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)
_orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding # type: ignore
_orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding

def _gspmd_to_positional_sharding(
out_s: GSPMDSharding, orig_in_s: PositionalSharding) -> PositionalSharding:
@@ -2832,7 +2832,7 @@ def _maybe_get_and_check_out_shardings(
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = sharding_impls.logical_sharding(aval, xla_s)
try:
new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig)) # type: ignore
new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig)) # pytype: disable=wrong-arg-types
except:
new_out_shardings.append(xla_s)
else:
@@ -3015,9 +3015,9 @@ def from_hlo(name: str,
device_assignment=da,
backend=backend,
input_avals=global_in_avals,
input_shardings=in_shardings, # type: ignore
input_shardings=in_shardings,
output_avals=global_out_avals,
output_shardings=out_shardings, # type: ignore # arg-type
output_shardings=out_shardings,
committed=committed,
name=name,
unordered_effects=unordered_effects,
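The `wrong-arg-types` directive above marks a call where the argument's static type is broader than what the helper's annotation accepts, and the surrounding try/except already handles the mismatch at runtime. A hypothetical sketch with toy class names (not the real sharding classes):

```python
class AbstractSpec:
    pass

class NamedSpec(AbstractSpec):
    def spec(self) -> str:
        return "named"

def _describe_named(s: NamedSpec) -> str:
    return s.spec()

def describe(s: AbstractSpec) -> str:
    # Statically `s` is only known to be an AbstractSpec, so the call below is
    # flagged as wrong-arg-types; the except clause covers the cases where the
    # downcast assumption does not hold at runtime.
    try:
        return _describe_named(s)  # pytype: disable=wrong-arg-types
    except AttributeError:
        return "unspecified"
```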
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/for_loop.py
@@ -481,7 +481,7 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
# Since not all inputs are used in jaxpr_unknown, we filter the input tracers
# down using the output of `dce_jaxpr`.
used_and_known = map(operator.and_, used_refs, map(operator.not_, in_unknowns))
tracers = [trace.instantiate_const(t) if u_and_k else t for t, u_and_k # type: ignore
tracers = [trace.instantiate_const(t) if u_and_k else t for t, u_and_k
in zip(tracers, used_and_known)]
_, known_used = partition_list(used_refs, used_and_known)
_, used_tracers = partition_list(used_refs, tracers)
8 changes: 4 additions & 4 deletions jax/_src/lax/lax.py
@@ -138,9 +138,9 @@ def asarray(x: ArrayLike) -> Array:
if isinstance(x, Array):
return x
elif isinstance(x, (bool, np.ndarray, np.generic)):
return _convert_element_type(x, weak_type=False) # type: ignore[bad-return-type]
return _convert_element_type(x, weak_type=False) # pytype: disable=bad-return-type
elif isinstance(x, (int, float, builtins.complex)):
return _convert_element_type(dtypes.coerce_to_array(x), weak_type=True) # type: ignore[bad-return-type]
return _convert_element_type(dtypes.coerce_to_array(x), weak_type=True)
else:
raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.")

@@ -3212,7 +3212,7 @@ def _nary_lower_hlo(op: Callable, ctx,
"""
del params
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape) # type: ignore
args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape)
if config.sharding_in_types.value:
args = multi_sharding_in_dim(ctx, args, avals_in, aval_out)

@@ -7423,7 +7423,7 @@ def _zero(x):
if config.sharding_in_types.value:
x_aval = core.get_aval(x)
return full_like(x, shape=(), fill_value=0,
sharding=x_aval.sharding.with_spec(P())) # type: ignore
sharding=x_aval.sharding.with_spec(P()))
return full_like(x, shape=(), fill_value=0)

_ones: Callable = partial(full_like, fill_value=1)
4 changes: 2 additions & 2 deletions jax/_src/lax/utils.py
@@ -56,15 +56,15 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
for a in in_avals:
if a is core.abstract_token:
continue
if a.sharding.mesh.empty: # type: ignore
if a.sharding.mesh.empty:
continue
if m is not None and m != a.sharding.mesh:
if m._are_all_axes_auto and a.sharding.mesh._are_all_axes_auto:
return mesh_lib.empty_abstract_mesh
raise ValueError(
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
f' another mesh: {a.sharding.mesh}')
m = a.sharding.mesh # type: ignore
m = a.sharding.mesh
return mesh_lib.empty_abstract_mesh if m is None else m


4 changes: 2 additions & 2 deletions jax/_src/mesh_utils.py
@@ -597,10 +597,10 @@ def _generate_logical_mesh(
zip(logical_indices, physical_indices, range(len(logical_indices)))
)
)
logical_mesh = np.transpose(logical_mesh, transpose_axes) # type: ignore # numpy 2.2
logical_mesh = np.transpose(logical_mesh, transpose_axes)

# Reshape to add the trivial dimensions back.
logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) # type: ignore # numpy 2.2
logical_mesh = np.reshape(logical_mesh, logical_mesh_shape)

return logical_mesh

2 changes: 1 addition & 1 deletion jax/_src/numpy/array_methods.py
@@ -784,7 +784,7 @@ def get(self, *, indices_are_sorted=False, unique_indices=False,
if out_sharding is not None:
assert isinstance(out_sharding, (NamedSharding, PartitionSpec))
out_sharding = canonicalize_sharding(out_sharding)
take = auto_axes(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore
take = auto_axes(take, axes=mesh_lib.get_abstract_mesh().axis_names,
out_shardings=out_sharding.spec)
return take(self.array, self.index)

6 changes: 3 additions & 3 deletions jax/_src/numpy/lax_numpy.py
@@ -3195,7 +3195,7 @@ def _split(op: str, ary: ArrayLike,
sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r)
else:
raise ValueError(f"array split does not result in an equal division: rest is {r}")
sizes = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc]
sizes = [i if core.is_symbolic_dim(i) else np.int64(i)
for i in sizes]
return list(lax.split(ary, sizes, axis=axis))

@@ -5498,7 +5498,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
sharding = object.aval.sharding
sharding = None if sharding.mesh.empty else sharding
else:
sharding = canonicalize_device_to_sharding(device) # type: ignore
sharding = canonicalize_device_to_sharding(device)

# Use device_put to avoid a copy for ndarray inputs.
if (not copy and isinstance(object, np.ndarray) and
@@ -6388,7 +6388,7 @@ def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None,
return lax.iota(dtype, start) # type: ignore[arg-type]
else:
if step is None and start == 0 and stop is not None:
return lax.iota(dtype, np.ceil(stop).astype(int)) # type: ignore[arg-type]
return lax.iota(dtype, np.ceil(stop).astype(int))
return array(np.arange(start, stop=stop, step=step, dtype=dtype))


6 changes: 3 additions & 3 deletions jax/_src/numpy/util.py
@@ -84,7 +84,7 @@ def promote_dtypes(*args: ArrayLike) -> list[Array]:
return [lax.asarray(arg) for arg in args]
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment]
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True)
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]


@@ -93,7 +93,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]:
Promotes arguments to an inexact type."""
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment]
to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True)
to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype) # type: ignore[arg-type]
return [lax._convert_element_type(x, to_dtype_inexact, weak_type)
for x in args]
@@ -250,7 +250,7 @@ def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes):
return [lax.asarray(arg) for arg in args]
result_shape = lax.broadcast_shapes(*shapes)
result_sharding = (lax.broadcast_shardings(*avals) # type: ignore
result_sharding = (lax.broadcast_shardings(*avals)
if config.sharding_in_types.value else None)
return [_broadcast_to(arg, result_shape, result_sharding) for arg in args]

2 changes: 1 addition & 1 deletion jax/_src/op_shardings.py
@@ -100,7 +100,7 @@ def op_sharding_to_numpy_indices(

for i, idxs in enumerate(itertools.product(*axis_indices)):
for _ in range(num_replicas):
indices[next(device_it)] = idxs # type: ignore # numpy 2.2
indices[next(device_it)] = idxs
return indices


12 changes: 6 additions & 6 deletions jax/_src/pallas/core.py
@@ -419,7 +419,7 @@ def to_block_mapping(
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(index_map_func, debug_info=debug), index_map_tree)
index_map_src_info = NameAndSrcInfo.from_pallas_call(
None, debug and debug.func_src_info # type: ignore
None, debug and debug.func_src_info
)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
@@ -883,7 +883,7 @@ def get_grid_mapping(
) -> tuple[tuple[jax_core.AbstractValue, ...],
GridMapping]:
if dynamic_shapes_export_enabled():
dim_check : Any = jax_core.is_dim # type: ignore[no-redef]
dim_check : Any = jax_core.is_dim
else:
dim_check : Any = jax_core.is_constant_dim # type: ignore[no-redef]
assert all(i is None or dim_check(i) for i in grid_spec.grid)
@@ -978,7 +978,7 @@ def get_grid_mapping(
grid=grid_mapping_grid, # type: ignore[arg-type]
grid_names=grid_spec.grid_names,
block_mappings=(*in_block_mappings, *out_block_mappings),
index_map_avals=index_map_avals, # type: ignore[arg-type]
index_map_avals=index_map_avals,
index_map_tree=index_map_tree,
vmapped_dims=(),
num_index_operands=num_flat_scalar_prefetch,
@@ -1002,14 +1002,14 @@ def get_grid_mapping(
def unzip_dynamic_grid_bounds(
grid_spec: GridSpec) -> tuple[GridSpec, tuple[Any, ...]]:
if dynamic_shapes_export_enabled():
new_grid : Any = grid_spec.grid # type: ignore[no-redef]
new_grid : Any = grid_spec.grid
else:
new_grid : Any = tuple(d if isinstance(d, int) else None for d in grid_spec.grid) # type: ignore[no-redef]
dynamic_bounds = tuple(d for d in grid_spec.grid if not isinstance(d, int))
# We can't use dataclasses.replace, because our fields are incompatible
# with __init__'s signature.
static_self = copy.copy(grid_spec)
static_self.grid = new_grid # type: ignore
static_self.grid = new_grid
return static_self, dynamic_bounds


@@ -1188,6 +1188,6 @@ def lower_as_mlir(
) -> mlir.ir.Module:
with pallas_export_experimental(dynamic_shapes):
lowered = jax.jit(f, device=device).lower(*args, **kwargs)
stablehlo = lowered.compiler_ir(dialect="stablehlo") # type: ignore[return-value]
stablehlo = lowered.compiler_ir(dialect="stablehlo")

return stablehlo # type: ignore[return-value]