From a73456d54d94df974002dba9c04f7f5a62be16b0 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 13 Feb 2025 18:05:27 +0000 Subject: [PATCH] Removed unused ``# type: ignore`` comments For future reference, this can be done via python -m mypy jax --warn-unused-ignores > /tmp/unused.txt while IFS=: read file line rest; do echo "$file:$line"; gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file" done < /tmp/unused.txt --- jax/_src/api.py | 2 +- jax/_src/array.py | 6 ++--- jax/_src/checkify.py | 2 +- jax/_src/cloud_tpu_init.py | 2 +- jax/_src/core.py | 12 +++++----- jax/_src/interpreters/ad.py | 6 ++--- jax/_src/interpreters/partial_eval.py | 8 +++---- jax/_src/interpreters/pxla.py | 16 ++++++------- jax/_src/lax/control_flow/for_loop.py | 2 +- jax/_src/lax/lax.py | 8 +++---- jax/_src/lax/utils.py | 4 ++-- jax/_src/mesh_utils.py | 4 ++-- jax/_src/numpy/array_methods.py | 2 +- jax/_src/numpy/lax_numpy.py | 6 ++--- jax/_src/numpy/util.py | 6 ++--- jax/_src/op_shardings.py | 2 +- jax/_src/pallas/core.py | 12 +++++----- jax/_src/pallas/hlo_interpreter.py | 2 +- jax/_src/pallas/mosaic/interpret.py | 2 +- jax/_src/pallas/pallas_call.py | 6 ++--- .../pallas/triton/pallas_call_registration.py | 2 +- jax/_src/pjit.py | 24 +++++++++---------- jax/_src/sharding_impls.py | 4 ++-- jax/_src/sharding_specs.py | 2 +- jax/_src/state/discharge.py | 2 +- jax/_src/tree_util.py | 20 ++++++++-------- .../jax2tf/tests/model_harness.py | 8 +++---- jax/experimental/shard_map.py | 2 +- 28 files changed, 87 insertions(+), 87 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 21eba48c6a0b..8d47a0e968d6 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2357,7 +2357,7 @@ def device_put( assert not m and not d copy_semantics.append(dispatch.CopySemantics.COPY) - for xf, d in zip(x_flat, device_flat): # type: ignore + for xf, d in zip(x_flat, device_flat): _check_sharding(shaped_abstractify(xf), d) out_flat = dispatch.device_put_p.bind( *x_flat, devices=device_flat, srcs=src_flat, diff --git a/jax/_src/array.py b/jax/_src/array.py index b2de0da5d96e..4ab48ab5ae4a 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -209,7 +209,7 @@ def __init__(self, aval: core.ShapedArray, sharding: Sharding, # (like pjit, etc). if not _skip_checks or config.enable_checks.value: arrays = self._check_and_rearrange(arrays, self._sharding, self.aval) - self._arrays = arrays # type: ignore + self._arrays = arrays if xla_extension_version >= 310: def _check_and_rearrange(self, arrays, sharding, aval): @@ -654,7 +654,7 @@ def block_until_ready(self): if xla_extension_version >= 314: @use_cpp_method() def _single_device_array_to_np_array_did_copy(self) -> tuple[np.ndarray, bool]: # type: ignore - ... # type: ignore + ... else: @use_cpp_method() @@ -782,7 +782,7 @@ def make_array_from_callback( raise TypeError( "`DeviceLocalLayout.AUTO` cannot be used in place of a device-local" f" layout when calling `jax.make_array_from_callback`. Got {sharding}") - sharding = sharding.sharding if isinstance(sharding, Layout) else sharding # type: ignore + sharding = sharding.sharding if isinstance(sharding, Layout) else sharding if not isinstance(sharding, Sharding): raise TypeError( f"sharding should be an instance of `jax.sharding`. 
Got {sharding} of" diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 831c4488fc46..a77a6456c3b7 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -260,7 +260,7 @@ def _get_batched_exception(self) -> BatchedError | None: cur_effect = None for error_effect, code in self._code.items(): if self._pred[error_effect][idx]: # type: ignore - if min_code is None or code[idx] < min_code: # type: ignore + if min_code is None or code[idx] < min_code: min_code = code[idx] # type: ignore cur_effect = error_effect diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 826d14b5b6e3..2d0303358be7 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -105,7 +105,7 @@ def cloud_tpu_init() -> None: def is_cloud_tpu_older_than(year: int, month: int, day: int): # We import locally because the functions above must run before the runtime # modules are imported. - from jax._src import xla_bridge # type: ignore + from jax._src import xla_bridge date = datetime.date(year, month, day) if not running_in_cloud_tpu_vm: return False diff --git a/jax/_src/core.py b/jax/_src/core.py index 3757e3d52e8b..9f03ab4632d1 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1767,11 +1767,11 @@ def canonicalize_value(val): return val cur_mesh = mesh_lib.get_abstract_mesh() - if cur_mesh == aval.sharding.mesh: # type: ignore + if cur_mesh == aval.sharding.mesh: return val - if cur_mesh._are_all_axes_manual and aval.sharding.mesh._are_all_axes_auto: # type: ignore - from jax._src.pjit import mesh_cast # type: ignore - return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) # type: ignore + if cur_mesh._are_all_axes_manual and aval.sharding.mesh._are_all_axes_auto: + from jax._src.pjit import mesh_cast + return mesh_cast(val, NamedSharding(cur_mesh, P(*[None] * aval.ndim))) return val @@ -1785,7 +1785,7 @@ def get_cur_mesh_sharding(spec=None): # TODO(yashkatariya): Only works with User/Auto. Generalize it to work with # Collective too. 
def modify_spec_for_auto_manual(spec, mesh) -> P: - new_spec = [] # type: ignore + new_spec = [] for s in spec: if not s: new_spec.append(s) @@ -1888,7 +1888,7 @@ def str_short(self, short_dtypes=False): self.dtype.name) dt_str = dt_str.replace('void', 'float0') if hasattr(self, 'sharding') and self.sharding is not None: - shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) # type: ignore + shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) return f'{dt_str}[{shapestr}]' else: shapestr = ','.join(map(str, self.shape)) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index ac3c03933795..ec33fb02e92c 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -180,7 +180,7 @@ def new_arg(trace, primal_aval, nz): if attrs_tracked: raise NotImplementedError("TODO: attrs") residuals_and_primals = (*tangent_consts, *out_primals) - residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) # type: ignore[assignment] + residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals) primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info) primal_trace.invalidate() num_residuals = len(tangent_consts) @@ -213,7 +213,7 @@ def direct_linearize(traceable: lu.WrappedFun, del linearize_trace, ans, tracers out_nzs = [type(t) is not Zero for t in out_tangents] out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz] - out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) # type: ignore + out_nz_tangents = map(tangent_trace.to_jaxpr_tracer, out_nz_tangents) jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_nz_tangents, traceable.debug_info) tangent_trace.invalidate() out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) if nz else @@ -343,7 +343,7 @@ def write_primal(v, val): ct_out = core.freeze(ref) write_cotangent(eqn.primitive, val_var, ct_out) elif eqn.primitive is core.freeze_p: - val_var, = eqn.outvars # type: ignore + val_var, = eqn.outvars ref_var, = eqn.invars # type: ignore ct_in = instantiate_zeros(read_cotangent(val_var)) write_primal(ref_var, core.mutable_array(ct_in)) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index ee3df03d5ed3..9fa186308f0f 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1690,7 +1690,7 @@ def to_jaxpr(self, trace: DynamicJaxprTrace, jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects, debug_info) jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore + jaxpr, constvals = _inline_literals(jaxpr, constvals) init_trees = [tree_structure(init_val) for init_val in self.attrs_inits] set_states(self.attrs_tracked, self.attrs_inits) return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked) @@ -1707,7 +1707,7 @@ def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], jaxpr_effects, debug_info) # We can't run check_jaxpr until after we normalize. 
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) - jaxpr, constvals = _inline_literals(jaxpr, constvals) # type: ignore[assignment] + jaxpr, constvals = _inline_literals(jaxpr, constvals) jaxpr, out_type = _add_implicit_outputs(jaxpr) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, out_type, constvals @@ -2065,7 +2065,7 @@ def fwd_jaxpr_from_zeros(*zeros): constvars = map(self.getvar, map(self.to_jaxpr_tracer, consts)) outvars = map(self.makevar, out_tracers) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, - prim.initial_style, # type: ignore[attribute-error] + prim.initial_style, dict(fun_jaxpr=closed_fun_jaxpr, fwd_jaxpr_thunk=fwd_jaxpr_from_zeros, num_consts=len(consts), @@ -2201,7 +2201,7 @@ def _check_no_returned_refs( origin_info = ('\n\nThe returned mutable array was created on line ' f'{source_info_util.summarize(eqn.source_info)}.') elif v in frame.invars: - arg_name = dbg.safe_arg_names(len(frame.invars))[frame.invars.index(v)] # type: ignore + arg_name = dbg.safe_arg_names(len(frame.invars))[frame.invars.index(v)] origin_info = ('\n\nThe returned mutable array was passed in as the ' f'argument {arg_name}.') else: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index e408e4c0ed91..181995222883 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2275,7 +2275,7 @@ def lower_sharding_computation( # I refactor, this will also work well with mesh being provided at # compile time. # Sets device_assignment to None if only abstractMesh and unspecified exists. - num_devices, device_assignment = _get_num_devices( # type: ignore + num_devices, device_assignment = _get_num_devices( it.chain(unique_in_shardings, unique_out_shardings, unique_intermediate_shardings), device_assignment) @@ -2330,12 +2330,12 @@ def lower_sharding_computation( "mesh should be the same across the entire program. 
Got mesh" f" shape for one sharding {abstract_mesh} and" f" {sharding.mesh.abstract_mesh} for another") - abstract_mesh = sharding.mesh.abstract_mesh # type: ignore + abstract_mesh = sharding.mesh.abstract_mesh semantic_in_shardings = SemanticallyEqualShardings( - in_shardings, global_in_avals) # type: ignore + in_shardings, global_in_avals) semantic_out_shardings = SemanticallyEqualShardings( - out_shardings, global_out_avals) # type: ignore + out_shardings, global_out_avals) (module, keepalive, host_callbacks, unordered_effects, ordered_effects, nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo( @@ -2540,7 +2540,7 @@ def _gspmd_to_named_sharding( assert isinstance(orig_in_s, NamedSharding) assert isinstance(orig_in_s.mesh, Mesh) return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh) -_orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding # type: ignore +_orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding def _gspmd_to_positional_sharding( out_s: GSPMDSharding, orig_in_s: PositionalSharding) -> PositionalSharding: @@ -2832,7 +2832,7 @@ def _maybe_get_and_check_out_shardings( dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval, xla_s) try: - new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig)) # type: ignore + new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig)) except: new_out_shardings.append(xla_s) else: @@ -3015,9 +3015,9 @@ def from_hlo(name: str, device_assignment=da, backend=backend, input_avals=global_in_avals, - input_shardings=in_shardings, # type: ignore + input_shardings=in_shardings, output_avals=global_out_avals, - output_shardings=out_shardings, # type: ignore # arg-type + output_shardings=out_shardings, # arg-type committed=committed, name=name, unordered_effects=unordered_effects, diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 52c6fd97a1da..b6966cf18c29 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -481,7 +481,7 @@ def _for_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer, # Since not all inputs are used in jaxpr_unknown, we filter the input tracers # down using the output of `dce_jaxpr`. 
used_and_known = map(operator.and_, used_refs, map(operator.not_, in_unknowns)) - tracers = [trace.instantiate_const(t) if u_and_k else t for t, u_and_k # type: ignore + tracers = [trace.instantiate_const(t) if u_and_k else t for t, u_and_k in zip(tracers, used_and_known)] _, known_used = partition_list(used_refs, used_and_known) _, used_tracers = partition_list(used_refs, tracers) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 0c6f873b35fb..19f91899f741 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -138,9 +138,9 @@ def asarray(x: ArrayLike) -> Array: if isinstance(x, Array): return x elif isinstance(x, (bool, np.ndarray, np.generic)): - return _convert_element_type(x, weak_type=False) # type: ignore[bad-return-type] + return _convert_element_type(x, weak_type=False) elif isinstance(x, (int, float, builtins.complex)): - return _convert_element_type(dtypes.coerce_to_array(x), weak_type=True) # type: ignore[bad-return-type] + return _convert_element_type(dtypes.coerce_to_array(x), weak_type=True) else: raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.") @@ -3212,7 +3212,7 @@ def _nary_lower_hlo(op: Callable, ctx, """ del params avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out - args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape) # type: ignore + args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape) if config.sharding_in_types.value: args = multi_sharding_in_dim(ctx, args, avals_in, aval_out) @@ -7423,7 +7423,7 @@ def _zero(x): if config.sharding_in_types.value: x_aval = core.get_aval(x) return full_like(x, shape=(), fill_value=0, - sharding=x_aval.sharding.with_spec(P())) # type: ignore + sharding=x_aval.sharding.with_spec(P())) return full_like(x, shape=(), fill_value=0) _ones: Callable = partial(full_like, fill_value=1) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 35ae9d49a463..44760cffd58d 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -56,7 +56,7 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh: for a in in_avals: if a is core.abstract_token: continue - if a.sharding.mesh.empty: # type: ignore + if a.sharding.mesh.empty: continue if m is not None and m != a.sharding.mesh: if m._are_all_axes_auto and a.sharding.mesh._are_all_axes_auto: @@ -64,7 +64,7 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh: raise ValueError( f'Mesh for all inputs should be equal. Got one mesh: {m} and' f' another mesh: {a.sharding.mesh}') - m = a.sharding.mesh # type: ignore + m = a.sharding.mesh return mesh_lib.empty_abstract_mesh if m is None else m diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index 89ae8b505558..ec0319a83753 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -597,10 +597,10 @@ def _generate_logical_mesh( zip(logical_indices, physical_indices, range(len(logical_indices))) ) ) - logical_mesh = np.transpose(logical_mesh, transpose_axes) # type: ignore # numpy 2.2 + logical_mesh = np.transpose(logical_mesh, transpose_axes) # numpy 2.2 # Reshape to add the trivial dimensions back. 
- logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) # type: ignore # numpy 2.2 + logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) # numpy 2.2 return logical_mesh diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 63a7557a5553..409e67194f0a 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -784,7 +784,7 @@ def get(self, *, indices_are_sorted=False, unique_indices=False, if out_sharding is not None: assert isinstance(out_sharding, (NamedSharding, PartitionSpec)) out_sharding = canonicalize_sharding(out_sharding) - take = auto_axes(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore + take = auto_axes(take, axes=mesh_lib.get_abstract_mesh().axis_names, out_shardings=out_sharding.spec) return take(self.array, self.index) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0d82481974c2..a79a652cb2e7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3195,7 +3195,7 @@ def _split(op: str, ary: ArrayLike, sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r) else: raise ValueError(f"array split does not result in an equal division: rest is {r}") - sizes = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc] + sizes = [i if core.is_symbolic_dim(i) else np.int64(i) for i in sizes] return list(lax.split(ary, sizes, axis=axis)) @@ -5498,7 +5498,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, sharding = object.aval.sharding sharding = None if sharding.mesh.empty else sharding else: - sharding = canonicalize_device_to_sharding(device) # type: ignore + sharding = canonicalize_device_to_sharding(device) # Use device_put to avoid a copy for ndarray inputs. 
if (not copy and isinstance(object, np.ndarray) and @@ -6388,7 +6388,7 @@ def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, return lax.iota(dtype, start) # type: ignore[arg-type] else: if step is None and start == 0 and stop is not None: - return lax.iota(dtype, np.ceil(stop).astype(int)) # type: ignore[arg-type] + return lax.iota(dtype, np.ceil(stop).astype(int)) return array(np.arange(start, stop=stop, step=step, dtype=dtype)) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e902950e76dc..82170ae1e74f 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -84,7 +84,7 @@ def promote_dtypes(*args: ArrayLike) -> list[Array]: return [lax.asarray(arg) for arg in args] else: to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment] + to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) return [lax._convert_element_type(x, to_dtype, weak_type) for x in args] @@ -93,7 +93,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]: Promotes arguments to an inexact type.""" to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment] + to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype) # type: ignore[arg-type] return [lax._convert_element_type(x, to_dtype_inexact, weak_type) for x in args] @@ -250,7 +250,7 @@ def _broadcast_arrays(*args: ArrayLike) -> list[Array]: if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes): return [lax.asarray(arg) for arg in args] result_shape = lax.broadcast_shapes(*shapes) - result_sharding = (lax.broadcast_shardings(*avals) # type: ignore + result_sharding = (lax.broadcast_shardings(*avals) if config.sharding_in_types.value else None) return [_broadcast_to(arg, result_shape, result_sharding) for arg in args] diff --git a/jax/_src/op_shardings.py b/jax/_src/op_shardings.py index 51c1cb35a500..b559ab9e7023 100644 --- a/jax/_src/op_shardings.py +++ b/jax/_src/op_shardings.py @@ -100,7 +100,7 @@ def op_sharding_to_numpy_indices( for i, idxs in enumerate(itertools.product(*axis_indices)): for _ in range(num_replicas): - indices[next(device_it)] = idxs # type: ignore # numpy 2.2 + indices[next(device_it)] = idxs # numpy 2.2 return indices diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index f77c46ebe648..5858b041e360 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -419,7 +419,7 @@ def to_block_mapping( flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun( lu.wrap_init(index_map_func, debug_info=debug), index_map_tree) index_map_src_info = NameAndSrcInfo.from_pallas_call( - None, debug and debug.func_src_info # type: ignore + None, debug and debug.func_src_info ) with tracing_grid_env(grid, mapped_dims): jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( @@ -883,7 +883,7 @@ def get_grid_mapping( ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]: if dynamic_shapes_export_enabled(): - dim_check : Any = jax_core.is_dim # type: ignore[no-redef] + dim_check : Any = jax_core.is_dim else: dim_check : Any = jax_core.is_constant_dim # type: ignore[no-redef] assert all(i is None or dim_check(i) for i in grid_spec.grid) @@ -978,7 +978,7 @@ def get_grid_mapping( grid=grid_mapping_grid, # type: ignore[arg-type] 
grid_names=grid_spec.grid_names, block_mappings=(*in_block_mappings, *out_block_mappings), - index_map_avals=index_map_avals, # type: ignore[arg-type] + index_map_avals=index_map_avals, index_map_tree=index_map_tree, vmapped_dims=(), num_index_operands=num_flat_scalar_prefetch, @@ -1002,14 +1002,14 @@ def get_grid_mapping( def unzip_dynamic_grid_bounds( grid_spec: GridSpec) -> tuple[GridSpec, tuple[Any, ...]]: if dynamic_shapes_export_enabled(): - new_grid : Any = grid_spec.grid # type: ignore[no-redef] + new_grid : Any = grid_spec.grid else: new_grid : Any = tuple(d if isinstance(d, int) else None for d in grid_spec.grid) # type: ignore[no-redef] dynamic_bounds = tuple(d for d in grid_spec.grid if not isinstance(d, int)) # We can't use dataclasses.replace, because our fields are incompatible # with __init__'s signature. static_self = copy.copy(grid_spec) - static_self.grid = new_grid # type: ignore + static_self.grid = new_grid return static_self, dynamic_bounds @@ -1188,6 +1188,6 @@ def lower_as_mlir( ) -> mlir.ir.Module: with pallas_export_experimental(dynamic_shapes): lowered = jax.jit(f, device=device).lower(*args, **kwargs) - stablehlo = lowered.compiler_ir(dialect="stablehlo") # type: ignore[return-value] + stablehlo = lowered.compiler_ir(dialect="stablehlo") return stablehlo # type: ignore[return-value] diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index a21fc1fb06cf..fc78d6044172 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -347,7 +347,7 @@ def pallas_call_hlo_interpret( del compiler_params, cost_estimate, out_avals # If we're in interpret mode, we *scan* over the grid and eval the # discharged jaxpr. - dynamic_grid_args, args = split_list( # type: ignore + dynamic_grid_args, args = split_list( args, [grid_mapping.num_dynamic_grid_bounds] ) dynamic_grid_args_iter = iter(dynamic_grid_args) diff --git a/jax/_src/pallas/mosaic/interpret.py b/jax/_src/pallas/mosaic/interpret.py index 06e3705e4bc6..33fe0ae60c88 100644 --- a/jax/_src/pallas/mosaic/interpret.py +++ b/jax/_src/pallas/mosaic/interpret.py @@ -685,7 +685,7 @@ def interpret_pallas_call( del debug, cost_estimate, out_avals # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?) 
- dynamic_grid_args, scalars, input_args = split_list( # type: ignore + dynamic_grid_args, scalars, input_args = split_list( args, [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands], ) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 7e0bf58300aa..a59f088db384 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -566,7 +566,7 @@ def get_size(i, x, d): ragged_axis_values.append(None) # type: ignore[arg-type] all_dims = list(dims) + [0] * grid_mapping.num_outputs - ragged_axis_values = ragged_axis_values + [None] * grid_mapping.num_outputs # type: ignore[list-item] + ragged_axis_values = ragged_axis_values + [None] * grid_mapping.num_outputs num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands @@ -895,7 +895,7 @@ def index_rewrite_kernel(*indexer_args): if config.sharding_in_types.value else None) shape = tuple_insert(aval.shape, 0, axis_size) batched_out_avals.append(aval.update(shape=shape, sharding=sharding)) - batched_out_avals = tuple(batched_out_avals) # type: ignore + batched_out_avals = tuple(batched_out_avals) out = pallas_call_p.bind( *dynamic_grid_args, @@ -1037,7 +1037,7 @@ def pallas_call_checkify_rule(error: checkify.Error, # returning them, since pallas kernels do not return outputs. # 4) Create block specs for the error state and call pallas_call with # the new kernel. - dynamic_grid_bounds, scalars, args = split_list( # type: ignore + dynamic_grid_bounds, scalars, args = split_list( args, [grid_mapping.num_dynamic_grid_bounds, grid_mapping.num_index_operands] ) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 8e141dfde046..7759fb4744ae 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -177,7 +177,7 @@ def pallas_call_lowering( # TODO(b/392558289): Migrate to ``jax.ffi``. return mlir.custom_call( call_target_name="triton_kernel_call", - result_types=[*map(mlir.aval_to_ir_type, ctx.avals_out)], # type: ignore[list-item] + result_types=[*map(mlir.aval_to_ir_type, ctx.avals_out)], operands=in_nodes, backend_config=zlib.compress( kernel_call.to_proto( diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index aca6fff3c1ab..1443c5d7c8a6 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -738,18 +738,18 @@ def _infer_input_type(fun: Callable, dbg: core.DebugInfo | None, for i, x in enumerate(explicit_args): avals.append(core.shaped_abstractify(x)) except OverflowError: - arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg # type: ignore - else f"flattened argument number is {i}") # type: ignore + arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg + else f"flattened argument number is {i}") raise OverflowError( "An overflow was encountered while parsing an argument to a jitted " f"computation, whose {arg_path}." ) from None except TypeError: - arg_description = (f"path {dbg.arg_names[i]}" if dbg # type: ignore - else f"flattened argument number {i}") # type: ignore + arg_description = (f"path {dbg.arg_names[i]}" if dbg + else f"flattened argument number {i}") raise TypeError( f"Error interpreting argument to {fun} as an abstract array." 
- f" The problematic value is of type {type(x)} and was passed to" # type: ignore + f" The problematic value is of type {type(x)} and was passed to" f" the function at {arg_description}.\n" "This typically means that a jit-wrapped function was called with a non-array" " argument, and this argument was not marked as static using the" @@ -1340,11 +1340,11 @@ def _check_and_canonicalize_out_shardings( if not config.dynamic_shapes.value: pjit_check_aval_sharding( out_shardings_flat, out_avals, - None if debug_info is None else debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] + None if debug_info is None else debug_info.safe_result_paths(len(out_avals)), "pjit outputs", allow_uneven_sharding=False) check_aval_layout_compatibility( out_layouts_flat, out_avals, - None if debug_info is None else debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] + None if debug_info is None else debug_info.safe_result_paths(len(out_avals)), "jit outputs") return out_shardings_flat, out_layouts_flat @@ -2238,7 +2238,7 @@ def keep_where(l, should_keep): unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()] unknown_out_avals = unknown_jaxpr.out_avals unknown_tracers_out = [ - pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) # type: ignore + pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None) for aval in unknown_out_avals ] eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers), @@ -2778,7 +2778,7 @@ def reshard(xs, out_shardings): out_flat = [] for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat): ds = canonicalize_sharding(s) - ds = ds.with_spec(ds.spec._normalized_spec_for_aval(x_aval.ndim)) # type: ignore + ds = ds.with_spec(ds.spec._normalized_spec_for_aval(x_aval.ndim)) out_flat.append(reshard_p.bind(x, dst_sharding=ds)) return tree_unflatten(treedef, out_flat) @@ -2828,14 +2828,14 @@ def _get_new_mesh(axes: str | tuple[str, ...] | None, axis_type: mesh_lib.AxisTypes): cur_mesh = mesh_lib.get_abstract_mesh() if axes is None: - axes = cur_mesh.axis_names # type: ignore + axes = cur_mesh.axis_names if not isinstance(axes, tuple): axes = (axes,) for a in axes: - if cur_mesh._name_to_type[a] == axis_type: # type: ignore + if cur_mesh._name_to_type[a] == axis_type: raise ValueError(f'Axes {a} cannot be casted to type {axis_type} since ' f'it already is of type {axis_type}.') - new_mesh = cur_mesh.update_axis_types({axis_type: axes}) # type: ignore + new_mesh = cur_mesh.update_axis_types({axis_type: axes}) return new_mesh def auto_axes(fun, *, axes: str | tuple[str, ...] | None = None, diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 9b2bffb69ca6..f48d9314e7d6 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -451,7 +451,7 @@ def __repr__(self) -> str: ids = self._ids.copy() platform_name = self._devices[0].platform.upper() for idx, x in np.ndenumerate(ids): - ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x)) # type: ignore # numpy 2.2 + ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x)) # numpy 2.2 body = np.array2string(ids, prefix=cls_name + '(', suffix=')', max_line_width=100) mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}' @@ -1269,7 +1269,7 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None, ' `jax.sharding.use_mesh` is not allowed. Please pass a' ' NamedSharding instance or enter into a mesh context via' f' `jax.sharding.use_mesh`. 
Got {sharding}') - sharding = NamedSharding(cur_mesh, sharding) # type: ignore + sharding = NamedSharding(cur_mesh, sharding) else: if (check_mesh_consistency and not cur_mesh.empty and sharding.mesh.abstract_mesh != cur_mesh): diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py index 5b88b3b1ec99..ac5322438707 100644 --- a/jax/_src/sharding_specs.py +++ b/jax/_src/sharding_specs.py @@ -97,7 +97,7 @@ def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray: # is used to extract the corresponding shard of the logical array. shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_) for i, idxs in enumerate(itertools.product(*axis_indices)): - shard_indices[i] = idxs # type: ignore # numpy 2.2 + shard_indices[i] = idxs # numpy 2.2 shard_indices = shard_indices.reshape(shard_indices_shape) # Ensure that each sharded axis is used exactly once in the mesh mapping diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index a15d9002f78c..dbfe46af4772 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -809,7 +809,7 @@ def _run_state_partial_eval_custom( jaxpr_known, res_avals = _convert_outputs_to_writes(jaxpr_known_resout) # In a stateful partial_eval, the residuals should be `Ref`s. - res_avals = map(AbstractRef, res_avals) # type: ignore + res_avals = map(AbstractRef, res_avals) known_invars, staged_invars = partition_list(in_unknowns, eqn.invars) known_outvars, staged_outvars = partition_list(in_unknowns, eqn.outvars) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 1ead7be9adb2..6c7e15a042e5 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -287,14 +287,14 @@ def register_pytree_node( >>> jax.jit(f)(m) Array([1., 2., 3., 4., 5.], dtype=float32) """ - default_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] + default_registry.register_node( + nodetype, flatten_func, unflatten_func, flatten_with_keys_func ) - none_leaf_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] + none_leaf_registry.register_node( + nodetype, flatten_func, unflatten_func, flatten_with_keys_func ) - dispatch_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] + dispatch_registry.register_node( + nodetype, flatten_func, unflatten_func, flatten_with_keys_func ) _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func) @@ -715,10 +715,10 @@ def _equality_errors(path, t1, t2, is_leaf): yield from _equality_errors((*path, k), c1, c2, is_leaf) -SequenceKey: Any = pytree.SequenceKey # type: ignore -DictKey: Any = pytree.DictKey # type: ignore -GetAttrKey: Any = pytree.GetAttrKey # type: ignore -FlattenedIndexKey: Any = pytree.FlattenedIndexKey # type: ignore +SequenceKey: Any = pytree.SequenceKey +DictKey: Any = pytree.DictKey +GetAttrKey: Any = pytree.GetAttrKey +FlattenedIndexKey: Any = pytree.FlattenedIndexKey @export diff --git a/jax/experimental/jax2tf/tests/model_harness.py b/jax/experimental/jax2tf/tests/model_harness.py index 91aacf2f596f..ec038dd4effe 100644 --- a/jax/experimental/jax2tf/tests/model_harness.py +++ b/jax/experimental/jax2tf/tests/model_harness.py @@ -293,7 +293,7 @@ def _vae_harness(name, **kwargs): tensor_specs=tensor_specs) # bilstm input specs: [((2, 3), np.int32), ((2,), np.int32)] = [inputs, lengths] -for 
poly_shapes, tensor_specs in [ # type: ignore +for poly_shapes, tensor_specs in [ (None, None), # batch polymorphism (["(b, _)", "(_,)"], [((None, 3), tf.int32), ((2,), tf.int32)]), @@ -347,7 +347,7 @@ def _vae_harness(name, **kwargs): # ((1, 2, 4), np.float32), # encoder inp: [batch, max_input_len, vocab_size] # ((1, 3, 4), np.float32), # decoder_inp: [batch, max_output_len, vocab_size] # ] -for poly_shapes, tensor_specs in [ # type: ignore +for poly_shapes, tensor_specs in [ (None, None), # batch polymorphism ( @@ -372,7 +372,7 @@ def _vae_harness(name, **kwargs): tensor_specs=tensor_specs) # lm1b/nlp_seq input spec: [((2, 1), np.float32)] [batch, seq_len] -for poly_shapes, tensor_specs in [ # type: ignore +for poly_shapes, tensor_specs in [ (None, None), # batch polymorphism. (["(b, _)"], [((None, 1), tf.float32)]), @@ -392,7 +392,7 @@ def _vae_harness(name, **kwargs): # ((1, 2), np.float32), # inputs: [batch, max_target_len] # ((1, 2), np.float32), # targets: [batch, max_target_len] # ] -for poly_shapes, tensor_specs in [ # type: ignore +for poly_shapes, tensor_specs in [ (None, None), # batch polymorphism. (["(b, _)"] * 2, [((None, 1), tf.float32)] * 2), diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 3247e7ba7bac..3606a69d419e 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -977,7 +977,7 @@ def aval(self): out = core.mapped_aval(self._trace.mesh.size, 0, aval) if config.sharding_in_types.value: new_sharding = NamedSharding(_as_manual_mesh(self._trace.mesh), - out.sharding.spec) # type: ignore + out.sharding.spec) else: new_sharding = None return out.update(sharding=new_sharding)
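For reference, the cleanup described in the commit message (run mypy with --warn-unused-ignores, then strip the flagged ``# type: ignore`` comments in place with gsed) can also be scripted in Python. The sketch below is illustrative only: the helper name `strip_unused_ignores` and the exact "Unused \"type: ignore\" comment" message format are assumptions, not part of this patch.

# Minimal sketch of the cleanup loop from the commit message, assuming mypy's
# default "path:line: error: Unused "type: ignore" comment" output format.
import re
import subprocess
from collections import defaultdict

# Matches the comment plus an optional error-code bracket, e.g. "# type: ignore[arg-type]".
IGNORE_RE = re.compile(r"\s*# type: ignore(\[[^\]]*\])?")

def strip_unused_ignores(package: str = "jax") -> None:
    """Run mypy and remove the ``# type: ignore`` comments it reports as unused."""
    result = subprocess.run(
        ["python", "-m", "mypy", package, "--warn-unused-ignores"],
        capture_output=True, text=True, check=False,
    )
    # Collect file -> line numbers flagged as having an unused ignore.
    flagged: defaultdict[str, list[int]] = defaultdict(list)
    for report in result.stdout.splitlines():
        if 'Unused "type: ignore" comment' not in report:
            continue
        path, lineno, _rest = report.split(":", 2)
        flagged[path].append(int(lineno))

    for path, linenos in flagged.items():
        with open(path, encoding="utf-8") as f:
            lines = f.readlines()
        for lineno in linenos:
            # Strip only the ignore comment; any trailing comment text is kept,
            # mirroring the gsed substitution in the commit message.
            lines[lineno - 1] = IGNORE_RE.sub("", lines[lineno - 1], count=1)
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(lines)

if __name__ == "__main__":
    strip_unused_ignores()

Unlike the gsed one-liner, this variant does not depend on GNU sed being installed, which may be convenient on macOS where only BSD sed ships by default.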