Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support optimization_level and memory_fitting_level XLA compilation options. #26466

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion jax/_src/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
import numpy as np

Expand Down Expand Up @@ -190,6 +191,13 @@ def get_compile_options(

build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value
if xla_extension_version >= 316:
build_options.optimization_level = config.EffortLevel(
config.optimization_level.value
).value
build_options.memory_fitting_level = config.EffortLevel(
config.memory_fitting_level.value
).value

# This is a temporary workaround to simplify the AutoPGLE usage.
# TODO(b/376647494): Remove once the bug is fixed.
Expand All @@ -203,7 +211,12 @@ def get_compile_options(
if env_options_overrides is not None:
# Some overrides are passed directly on build_options.
overrides_on_build_options = [
'exec_time_optimization_effort', 'memory_fitting_effort']
"exec_time_optimization_effort", "memory_fitting_effort"]
if xla_extension_version >= 316:
overrides_on_build_options.extend(
["optimization_level", "memory_fitting_level"]
)

env_options_overrides = dict(env_options_overrides)
for name in overrides_on_build_options:
if name in env_options_overrides:
Expand Down
55 changes: 55 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from collections.abc import Callable, Iterator, Sequence
import contextlib
import enum
import functools
import itertools
import logging
Expand All @@ -35,6 +36,29 @@
_T = TypeVar('_T')


class EffortLevel(enum.Enum):
"""Effort level enum, mirroring the XLA effort options."""

UNKNOWN = 0
O0 = 9
O1 = 19
O2 = 29
O3 = 39

@classmethod
def _missing_(cls, value: object) -> EffortLevel | None:
return _effort_from_string.get(value)


_effort_from_string: dict[Any, EffortLevel] = {
'UNKNOWN': EffortLevel.UNKNOWN,
'O0': EffortLevel.O0,
'O1': EffortLevel.O1,
'O2': EffortLevel.O2,
'O3': EffortLevel.O3,
}


def bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.

Expand Down Expand Up @@ -1727,6 +1751,37 @@ def _update_garbage_collection_guard(state, key, val):
help='Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].'
)

optimization_level = enum_state(
name='jax_optimization_level',
enum_values=[
'UNKNOWN',
'O0',
'O1',
'O2',
'O3',
],
default='UNKNOWN',
help='The degree to which the compiler should optimize for execution time',
include_in_jit_key=True
)

memory_fitting_level = enum_state(
name='jax_memory_fitting_level',
enum_values=[
'UNKNOWN',
'O0',
'O1',
'O2',
'O3',
],
default='UNKNOWN',
help=(
'The degree to which the compiler should attempt to make the program'
' fit in memory'
),
include_in_jit_key=True
)

cpu_collectives_implementation = optional_enum_state(
name='jax_cpu_collectives_implementation',
enum_values=["gloo", "mpi", "megascale"],
Expand Down
31 changes: 31 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.compilation_cache import is_persistent_cache_enabled
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching
Expand Down Expand Up @@ -1366,6 +1367,36 @@ def f(x):
"exec_time_compilation_effort": 0.0,
})(1.0)

def test_optimization_level_compiler_option(self):
def f(x):
return jnp.sqrt(x**2) + 1.0

if xla_extension_version < 316:
self.skipTest("Requires XLA extension version >= 316")
f_jit = jit(
f,
compiler_options={
"optimization_level": config.EffortLevel.O1.value,
},
)(
1.0
) # doesn't crash.

def test_memory_fitting_level_compiler_option(self):
def f(x):
return jnp.sqrt(x**2) + 1.0

if xla_extension_version < 316:
self.skipTest("Requires XLA extension version >= 316")
f_jit = jit(
f,
compiler_options={
"memory_fitting_level": config.EffortLevel.O0.value,
},
)(
1.0
) # doesn't crash.

def test_jit_lower_compile_with_compiler_options_invalid(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
Expand Down
Loading