Don't write atime file if JAX_COMPILATION_CACHE_MAX_SIZE == -1
The atime file is only needed to implement the LRU eviction policy,
which is only needed if a maximum persistent compilation cache size is
set. Writing this file can cause performance and other issues on network
filesystems, so only write it if users have opted in.
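The gating described in the commit message can be sketched in isolation. This is a minimal illustration, not the JAX implementation: `TinyFileCache`, `atime_files`, and the two suffix constants are hypothetical names chosen for this example, and the locking and the eviction itself are left out.

```python
import tempfile
import time
from pathlib import Path
from typing import Optional

CACHE_SUFFIX = "-cache"  # assumed suffix, for illustration only
ATIME_SUFFIX = "-atime"  # assumed suffix, for illustration only


class TinyFileCache:
  """Hypothetical stand-in for a file-backed cache; not the JAX class."""

  def __init__(self, path: str, max_size: int = -1):
    self.path = Path(path)
    self.path.mkdir(parents=True, exist_ok=True)
    # A max_size of -1 means "no size limit", so LRU eviction is off.
    self.eviction_enabled = max_size != -1

  def _touch(self, key: str) -> None:
    # Access times only matter for LRU eviction, so skip the write otherwise.
    if self.eviction_enabled:
      timestamp = time.time_ns().to_bytes(8, "little")
      (self.path / f"{key}{ATIME_SUFFIX}").write_bytes(timestamp)

  def put(self, key: str, val: bytes) -> None:
    (self.path / f"{key}{CACHE_SUFFIX}").write_bytes(val)
    self._touch(key)

  def get(self, key: str) -> Optional[bytes]:
    cache_path = self.path / f"{key}{CACHE_SUFFIX}"
    if not cache_path.exists():
      return None
    val = cache_path.read_bytes()
    self._touch(key)
    return val


def atime_files(cache: TinyFileCache) -> list:
  """Returns the names of all access-time files in the cache directory."""
  return sorted(p.name for p in cache.path.glob(f"*{ATIME_SUFFIX}"))
```

With `max_size=-1` the cache directory holds only the cache entries; with any other `max_size`, each `put` or `get` also writes one access-time file per key.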
skye committed Feb 14, 2025
1 parent 60dcded commit b81d7af
Showing 3 changed files with 28 additions and 7 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -22,7 +22,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
     details.
   * Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`,
-    {func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
+    {func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
     {func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
   * {func}`jax.lax.linalg.qr`, and {func}`jax.scipy.linalg.qr`, now support
     column-pivoting on CPU and GPU. See {jax-issue}`#20282` and
@@ -37,6 +37,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     This package may safely be removed if it is present on your machine; JAX now
     uses `libtpu` instead.

+* Bug fixes
+  * Persistent compilation cache no longer writes access time file if
+    JAX_COMPILATION_CACHE_MAX_SIZE is unset or set to -1, i.e. if the LRU
+    eviction policy isn't enabled. This should improve performance when using
+    the cache with large-scale network storage.
+
 ## jax 0.5.0 (Jan 17, 2025)

 As of this release, JAX now uses
14 changes: 8 additions & 6 deletions jax/_src/lru_cache.py
@@ -94,7 +94,6 @@ def get(self, key: str) -> bytes | None:
       raise ValueError("key cannot be empty")

     cache_path = self.path / f"{key}{_CACHE_SUFFIX}"
-    atime_path = self.path / f"{key}{_ATIME_SUFFIX}"

     if self.eviction_enabled:
       self.lock.acquire(timeout=self.lock_timeout_secs)
@@ -108,8 +107,10 @@ def get(self, key: str) -> bytes | None:

       val = cache_path.read_bytes()

-      timestamp = time.time_ns().to_bytes(8, "little")
-      atime_path.write_bytes(timestamp)
+      if self.eviction_enabled:
+        timestamp = time.time_ns().to_bytes(8, "little")
+        atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
+        atime_path.write_bytes(timestamp)

       return val

@@ -138,7 +139,6 @@ def put(self, key: str, val: bytes) -> None:
       return

     cache_path = self.path / f"{key}{_CACHE_SUFFIX}"
-    atime_path = self.path / f"{key}{_ATIME_SUFFIX}"

     if self.eviction_enabled:
       self.lock.acquire(timeout=self.lock_timeout_secs)
@@ -151,8 +151,10 @@ def put(self, key: str, val: bytes) -> None:

       cache_path.write_bytes(val)

-      timestamp = time.time_ns().to_bytes(8, "little")
-      atime_path.write_bytes(timestamp)
+      if self.eviction_enabled:
+        timestamp = time.time_ns().to_bytes(8, "little")
+        atime_path = self.path / f"{key}{_ATIME_SUFFIX}"
+        atime_path.write_bytes(timestamp)

     finally:
       if self.eviction_enabled:
13 changes: 13 additions & 0 deletions tests/lru_cache_test.py
@@ -14,6 +14,7 @@

 from __future__ import annotations

+import glob
 import importlib.util
 import tempfile
 import time
@@ -153,6 +154,18 @@ def test_max_size(self):
     self.assertIsNone(cache.get("a"))
     self.assertEqual(set(self.path.glob(f"*{_CACHE_SUFFIX}")), set())

+  # Check that we don't write access time file when the eviction policy is
+  # disabled. Writing this file can be extremely unperformant and cause
+  # problems on large-scale network storage.
+  def test_no_atime_file(self):
+    cache = LRUCache(self.name, max_size=-1)
+
+    cache.put("a", b"a")
+    self.assertEmpty(glob.glob(self.name + "/*atime*"))
+
+    cache.get("a")
+    self.assertEmpty(glob.glob(self.name + "/*atime*"))
+

 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
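
The access-time files touched by this change store `time.time_ns()` as 8 little-endian bytes. The sketch below shows how such timestamps could be decoded and compared to find the least recently used key. The eviction logic itself is not part of this diff, so the "oldest timestamp first" selection and the helper names `encode_atime`, `decode_atime`, and `least_recently_used` are assumptions made for illustration.

```python
import time


def encode_atime(ns: int) -> bytes:
  """Encodes a nanosecond timestamp the same way the diff does."""
  return ns.to_bytes(8, "little")


def decode_atime(raw: bytes) -> int:
  """Decodes an 8-byte little-endian timestamp back into nanoseconds."""
  return int.from_bytes(raw, "little")


def least_recently_used(atimes: dict) -> str:
  """Returns the key with the oldest timestamp (assumed eviction order)."""
  return min(atimes, key=lambda k: decode_atime(atimes[k]))


# Round trip of a current timestamp through the on-disk encoding.
now = time.time_ns()
round_tripped = decode_atime(encode_atime(now))
```

Without a size limit there is nothing to evict, so these timestamps are never read, which is why the commit skips writing them.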
