-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for jaxlib, CUDA and PJRT plugins in #25126) Previously JAX wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file). You can still build the JAX wheel with `python3 -m build` command. Bazel command example for building nightly JAX wheel: ``` bazel build :jax_wheel \ --config=ci_linux_x86_64 \ --repo_env=HERMETIC_PYTHON_VERSION=3.10 \ --repo_env=ML_WHEEL_TYPE=custom \ --repo_env=ML_WHEEL_BUILD_DATE=20250211 \ --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD) ``` Resulting wheel: ``` bazel-bin/dist/jax-0.5.1.dev20250211+d4f1f2278-py3-none-any.whl ``` PiperOrigin-RevId: 724102315
- Loading branch information
1 parent
6addf02
commit d587a16
Showing
7 changed files
with
220 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright 2025 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
load("@tsl//third_party/py:python_wheel.bzl", "transitive_py_deps") | ||
load( | ||
"//jaxlib:jax.bzl", | ||
"jax_wheel", | ||
) | ||
|
||
transitive_py_deps( | ||
name = "transitive_py_deps", | ||
deps = [ | ||
"//jax", | ||
"//jax:compilation_cache", | ||
"//jax:experimental", | ||
"//jax:experimental_colocated_python", | ||
"//jax:experimental_sparse", | ||
"//jax:lax_reference", | ||
"//jax:pallas_gpu_ops", | ||
"//jax:pallas_mosaic_gpu", | ||
"//jax:pallas_tpu_ops", | ||
"//jax:pallas_triton", | ||
"//jax:source_mapper", | ||
"//jax/_src/lib", | ||
"//jax/_src/pallas/mosaic_gpu", | ||
"//jax/experimental/jax2tf", | ||
"//jax/extend", | ||
"//jax/extend:ifrt_programs", | ||
"//jax/tools:jax_to_ir", | ||
], | ||
) | ||
|
||
py_binary( | ||
name = "build_wheel", | ||
srcs = ["build_wheel.py"], | ||
deps = [ | ||
"//jaxlib/tools:build_utils", | ||
"@pypi_build//:pkg", | ||
"@pypi_setuptools//:pkg", | ||
"@pypi_wheel//:pkg", | ||
], | ||
) | ||
|
||
jax_wheel( | ||
name = "jax_wheel", | ||
no_abi = True, | ||
no_platform = True, | ||
source_files = [ | ||
":transitive_py_deps", | ||
"//jax:py.typed", | ||
"//jax:numpy/__init__.pyi", | ||
"//jax:_src/basearray.pyi", | ||
"AUTHORS", | ||
"LICENSE", | ||
"README.md", | ||
"pyproject.toml", | ||
"setup.py", | ||
], | ||
wheel_binary = ":build_wheel", | ||
wheel_name = "jax", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# Copyright 2025 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Script that builds a JAX wheel, intended to be run via bazel run as part | ||
# of the JAX build process. | ||
|
||
import argparse | ||
import os | ||
import pathlib | ||
import shutil | ||
import tempfile | ||
|
||
from jaxlib.tools import build_utils | ||
|
||
parser = argparse.ArgumentParser(fromfile_prefix_chars="@") | ||
parser.add_argument( | ||
"--sources_path", | ||
default=None, | ||
help=( | ||
"Path in which the wheel's sources should be prepared. Optional. If " | ||
"omitted, a temporary directory will be used." | ||
), | ||
) | ||
parser.add_argument( | ||
"--output_path", | ||
default=None, | ||
required=True, | ||
help="Path to which the output wheel should be written. Required.", | ||
) | ||
parser.add_argument( | ||
"--jaxlib_git_hash", | ||
default="", | ||
required=True, | ||
help="Git hash. Empty if unknown. Optional.", | ||
) | ||
parser.add_argument( | ||
"--srcs", help="source files for the wheel", action="append" | ||
) | ||
args = parser.parse_args() | ||
|
||
|
||
def copy_file( | ||
src_file: str, | ||
dst_dir: str, | ||
) -> None: | ||
"""Copy a file to the destination directory. | ||
Args: | ||
src_file: file to be copied | ||
dst_dir: destination directory | ||
""" | ||
|
||
dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file)) | ||
os.makedirs(dest_dir_path, exist_ok=True) | ||
shutil.copy(src_file, dest_dir_path) | ||
os.chmod(os.path.join(dst_dir, src_file), 0o644) | ||
|
||
|
||
def prepare_srcs(deps: list[str], srcs_dir: str) -> None: | ||
"""Filter the sources and copy them to the destination directory. | ||
Args: | ||
deps: a list of paths to files. | ||
srcs_dir: target directory where files are copied to. | ||
""" | ||
|
||
for file in deps: | ||
if not (file.startswith("bazel-out") or file.startswith("external")): | ||
copy_file(file, srcs_dir) | ||
|
||
|
||
tmpdir = None | ||
sources_path = args.sources_path | ||
if sources_path is None: | ||
tmpdir = tempfile.TemporaryDirectory(prefix="jax") | ||
sources_path = tmpdir.name | ||
|
||
try: | ||
os.makedirs(args.output_path, exist_ok=True) | ||
prepare_srcs(args.srcs, pathlib.Path(sources_path)) | ||
build_utils.build_wheel( | ||
sources_path, | ||
args.output_path, | ||
package_name="jax", | ||
git_hash=args.jaxlib_git_hash, | ||
) | ||
finally: | ||
if tmpdir: | ||
tmpdir.cleanup() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters