Skip to content

Commit

Permalink
Create JAX wheel build target.
Browse files Browse the repository at this point in the history
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for jaxlib, CUDA and PJRT plugins in #25126)

Previously JAX wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the JAX wheel with `python3 -m build` command.

Bazel command example for building nightly JAX wheel:

```
bazel build :jax_wheel \
  --config=ci_linux_x86_64 \
  --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
  --repo_env=ML_WHEEL_TYPE=custom \
  --repo_env=ML_WHEEL_BUILD_DATE=20250211 \
  --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)
```

Resulting wheel:
```
bazel-bin/dist/jax-0.5.1.dev20250211+d4f1f2278-py3-none-any.whl
```

PiperOrigin-RevId: 724102315
  • Loading branch information
Google-ML-Automation committed Feb 13, 2025
1 parent c6c38fb commit c091349
Show file tree
Hide file tree
Showing 9 changed files with 292 additions and 12 deletions.
95 changes: 95 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load(
"@tsl//third_party/py:py_manylinux_compliance_test.bzl",
"verify_manylinux_compliance_test",
)
load("//jax:py_deps.bzl", "transitive_py_deps")
load(
"//jaxlib:jax.bzl",
"AARCH64_MANYLINUX_TAG",
"PPC64LE_MANYLINUX_TAG",
"X86_64_MANYLINUX_TAG",
"jax_wheel",
)

transitive_py_deps(
name = "transitive_py_deps",
deps = [
"//jax",
"//jax:experimental",
"//jax:experimental_colocated_python",
"//jax:experimental_sparse",
"//jax:lax_reference",
"//jax:pallas_gpu_ops",
"//jax:pallas_mosaic_gpu",
"//jax:pallas_tpu_ops",
"//jax:pallas_triton",
"//jax:source_mapper",
"//jax:tpu_custom_call",
"//jax/_src/lib",
"//jax/_src/pallas/mosaic_gpu",
"//jax/experimental/jax2tf",
"//jax/extend",
"//jax/extend:ifrt_programs",
"//jax/tools:jax_to_ir",
"//jax/tools:jax_to_ir_with_tensorflow",
],
)

py_binary(
name = "build_wheel",
srcs = ["build_wheel.py"],
deps = [
"//jaxlib/tools:build_utils",
"@pypi_build//:pkg",
"@pypi_setuptools//:pkg",
"@pypi_wheel//:pkg",
],
)

jax_wheel(
name = "jax_wheel",
no_abi = True,
no_platform = True,
source_files = [
":transitive_py_deps",
"//jax:py.typed",
"//jax:numpy/__init__.pyi",
"//jax:_src/basearray.pyi",
"//jax/_src/pallas/triton:__init__.py",
"//jax/_src/pallas/mosaic:__init__.py",
"//jax:experimental/compilation_cache/__init__.py",
"//jax/tools:__init__.py",
"AUTHORS",
"LICENSE",
"README.md",
"pyproject.toml",
"setup.py",
],
wheel_binary = ":build_wheel",
wheel_name = "jax",
)

verify_manylinux_compliance_test(
name = "jax_manylinux_compliance_test",
aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
test_tags = [
"manual",
],
wheel = ":jax_wheel",
x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)
100 changes: 100 additions & 0 deletions build_wheel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Script that builds a JAX wheel, intended to be run via bazel run as part
# of the JAX build process.

import argparse
import os
import pathlib
import shutil
import tempfile

from jaxlib.tools import build_utils

parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument(
"--sources_path",
default=None,
help=(
"Path in which the wheel's sources should be prepared. Optional. If "
"omitted, a temporary directory will be used."
),
)
parser.add_argument(
"--output_path",
default=None,
required=True,
help="Path to which the output wheel should be written. Required.",
)
parser.add_argument(
"--jaxlib_git_hash",
default="",
required=True,
help="Git hash. Empty if unknown. Optional.",
)
parser.add_argument(
"--srcs", help="source files for the wheel", action="append"
)
args = parser.parse_args()


def copy_file(
src_file: str,
dst_dir: str,
) -> None:
"""Copy a file to the destination directory.
Args:
src_file: file to be copied
dst_dir: destination directory
"""

dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file))
os.makedirs(dest_dir_path, exist_ok=True)
shutil.copy(src_file, dest_dir_path)
os.chmod(os.path.join(dst_dir, src_file), 0o644)


def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
"""Rearrange source files in target the target directory.
Args:
deps: a list of paths to files.
srcs_dir: target directory where files are copied to.
"""

for file in deps:
if not (file.startswith("bazel-out") or file.startswith("external")):
copy_file(file, srcs_dir)


tmpdir = None
sources_path = args.sources_path
if sources_path is None:
tmpdir = tempfile.TemporaryDirectory(prefix="jax")
sources_path = tmpdir.name

try:
os.makedirs(args.output_path, exist_ok=True)
prepare_srcs(args.srcs, pathlib.Path(sources_path))
build_utils.build_wheel(
sources_path,
args.output_path,
package_name="jax",
git_hash=args.jaxlib_git_hash,
)
finally:
if tmpdir:
tmpdir.cleanup()
4 changes: 4 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ config_setting(
exports_files([
"LICENSE",
"version.py",
"py.typed",
"numpy/__init__.pyi",
"_src/basearray.pyi",
"experimental/compilation_cache/__init__.py",
])

exports_files(
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/pallas/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ package(
],
)

exports_files(
["__init__.py"],
)

py_library(
name = "core",
srcs = ["core.py"],
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/pallas/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ package(
],
)

exports_files(
["__init__.py"],
)

pytype_strict_library(
name = "core",
srcs = ["core.py"],
Expand Down
40 changes: 40 additions & 0 deletions jax/py_deps.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Rule for collecting python files that a target depends on.
It traverses dependencies of provided targets, collect their direct and transitive python deps and
then return a list of paths to files.
"""

def _transitive_py_deps_impl(ctx):
outputs = depset(
[],
transitive = [dep[PyInfo].transitive_sources for dep in ctx.attr.deps],
)
return DefaultInfo(files = outputs)

_transitive_py_deps = rule(
attrs = {
"deps": attr.label_list(
allow_files = True,
providers = [PyInfo],
),
},
implementation = _transitive_py_deps_impl,
)

def transitive_py_deps(name, deps = []):
_transitive_py_deps(name = name + "_gather", deps = deps)
native.filegroup(name = name, srcs = [":" + name + "_gather"])
4 changes: 4 additions & 0 deletions jax/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ package(
default_visibility = ["//visibility:public"],
)

exports_files(
["__init__.py"],
)

py_library(
name = "jax_to_ir",
srcs = ["jax_to_ir.py"],
Expand Down
Loading

0 comments on commit c091349

Please sign in to comment.