diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index e0882fece29d..df545cb72072 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -33,6 +33,7 @@ from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir import numpy as np @@ -190,6 +191,13 @@ def get_compile_options( build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value build_options.memory_fitting_effort = config.memory_fitting_effort.value + if xla_extension_version >= 316: + build_options.optimization_level = config.EffortLevel( + config.optimization_level.value + ).value + build_options.memory_fitting_level = config.EffortLevel( + config.memory_fitting_level.value + ).value # This is a temporary workaround to simplify the AutoPGLE usage. # TODO(b/376647494): Remove once the bug is fixed. @@ -203,7 +211,12 @@ def get_compile_options( if env_options_overrides is not None: # Some overrides are passed directly on build_options. overrides_on_build_options = [ - 'exec_time_optimization_effort', 'memory_fitting_effort'] + "exec_time_optimization_effort", "memory_fitting_effort"] + if xla_extension_version >= 316: + overrides_on_build_options.extend( + ["optimization_level", "memory_fitting_level"] + ) + env_options_overrides = dict(env_options_overrides) for name in overrides_on_build_options: if name in env_options_overrides: diff --git a/jax/_src/config.py b/jax/_src/config.py index 7a9d84d23a17..73ec4c81556c 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -16,6 +16,7 @@ from collections.abc import Callable, Iterator, Sequence import contextlib +import enum import functools import itertools import logging @@ -35,6 +36,29 @@ _T = TypeVar('_T') +class EffortLevel(enum.Enum): + """Effort level enum, mirroring the XLA effort options.""" + + UNKNOWN = 0 + O0 = 9 + O1 = 19 + O2 = 29 + O3 = 39 + + @classmethod + def _missing_(cls, value: object) -> EffortLevel | None: + return _effort_from_string.get(value) + + +_effort_from_string: dict[Any, EffortLevel] = { + 'UNKNOWN': EffortLevel.UNKNOWN, + 'O0': EffortLevel.O0, + 'O1': EffortLevel.O1, + 'O2': EffortLevel.O2, + 'O3': EffortLevel.O3, +} + + def bool_env(varname: str, default: bool) -> bool: """Read an environment variable and interpret it as a boolean. @@ -1727,6 +1751,37 @@ def _update_garbage_collection_guard(state, key, val): help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].' ) +optimization_level = enum_state( + name='jax_optimization_level', + enum_values=[ + 'UNKNOWN', + 'O0', + 'O1', + 'O2', + 'O3', + ], + default='UNKNOWN', + help='The degree to which the compiler should optimize for execution time', + include_in_jit_key=True +) + +memory_fitting_level = enum_state( + name='jax_memory_fitting_level', + enum_values=[ + 'UNKNOWN', + 'O0', + 'O1', + 'O2', + 'O3', + ], + default='UNKNOWN', + help=( + 'The degree to which the compiler should attempt to make the program' + ' fit in memory' + ), + include_in_jit_key=True +) + cpu_collectives_implementation = optional_enum_state( name='jax_cpu_collectives_implementation', enum_values=["gloo", "mpi", "megascale"], diff --git a/tests/api_test.py b/tests/api_test.py index 1687713ec441..536fbf372911 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -61,6 +61,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.custom_batching @@ -1366,6 +1367,36 @@ def f(x): "exec_time_compilation_effort": 0.0, })(1.0) + def test_optimization_level_compiler_option(self): + def f(x): + return jnp.sqrt(x**2) + 1.0 + + if xla_extension_version < 316: + self.skipTest("Requires XLA extension version >= 316") + f_jit = jit( + f, + compiler_options={ + "optimization_level": config.EffortLevel.O1.value, + }, + )( + 1.0 + ) # doesn't crash. + + def test_memory_fitting_level_compiler_option(self): + def f(x): + return jnp.sqrt(x**2) + 1.0 + + if xla_extension_version < 316: + self.skipTest("Requires XLA extension version >= 316") + f_jit = jit( + f, + compiler_options={ + "memory_fitting_level": config.EffortLevel.O0.value, + }, + )( + 1.0 + ) # doesn't crash. + def test_jit_lower_compile_with_compiler_options_invalid(self): def f(x): return jnp.sqrt(x ** 2) + 1.