Skip to content

Commit

Permalink
Create JAX wheel build target.
Browse files Browse the repository at this point in the history
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for jaxlib, CUDA and PJRT plugins in #25126)

Previously JAX wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the JAX wheel with `python3 -m build` command.

Bazel command example for building nightly JAX wheel:

```
bazel build :jax_wheel \
  --config=ci_linux_x86_64 \
  --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
  --repo_env=ML_WHEEL_TYPE=custom \
  --repo_env=ML_WHEEL_BUILD_DATE=20250211 \
  --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)
```

Resulting wheel:
```
bazel-bin/dist/jax-0.5.1.dev20250211+d4f1f2278-py3-none-any.whl
```

PiperOrigin-RevId: 724102315
  • Loading branch information
Google-ML-Automation committed Feb 14, 2025
1 parent 6addf02 commit 740adb5
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 9 deletions.
72 changes: 72 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@tsl//third_party/py:python_wheel.bzl", "transitive_py_deps")
load(
"//jaxlib:jax.bzl",
"jax_wheel",
)

transitive_py_deps(
name = "transitive_py_deps",
deps = [
"//jax",
"//jax:compilation_cache",
"//jax:experimental",
"//jax:experimental_colocated_python",
"//jax:experimental_sparse",
"//jax:lax_reference",
"//jax:pallas_gpu_ops",
"//jax:pallas_mosaic_gpu",
"//jax:pallas_tpu_ops",
"//jax:pallas_triton",
"//jax:source_mapper",
"//jax/_src/lib",
"//jax/_src/pallas/mosaic_gpu",
"//jax/experimental/jax2tf",
"//jax/extend",
"//jax/extend:ifrt_programs",
"//jax/tools:jax_to_ir",
],
)

py_binary(
name = "build_wheel",
srcs = ["build_wheel.py"],
deps = [
"//jaxlib/tools:build_utils",
"@pypi_build//:pkg",
"@pypi_setuptools//:pkg",
"@pypi_wheel//:pkg",
],
)

jax_wheel(
name = "jax_wheel",
no_abi = True,
no_platform = True,
source_files = [
":transitive_py_deps",
"//jax:py.typed",
"//jax:numpy/__init__.pyi",
"//jax:_src/basearray.pyi",
"AUTHORS",
"LICENSE",
"README.md",
"pyproject.toml",
"setup.py",
],
wheel_binary = ":build_wheel",
wheel_name = "jax",
)
100 changes: 100 additions & 0 deletions build_wheel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Script that builds a JAX wheel, intended to be run via bazel run as part
# of the JAX build process.

import argparse
import os
import pathlib
import shutil
import tempfile

from jaxlib.tools import build_utils

parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument(
"--sources_path",
default=None,
help=(
"Path in which the wheel's sources should be prepared. Optional. If "
"omitted, a temporary directory will be used."
),
)
parser.add_argument(
"--output_path",
default=None,
required=True,
help="Path to which the output wheel should be written. Required.",
)
parser.add_argument(
"--jaxlib_git_hash",
default="",
required=True,
help="Git hash. Empty if unknown. Optional.",
)
parser.add_argument(
"--srcs", help="source files for the wheel", action="append"
)
args = parser.parse_args()


def copy_file(
src_file: str,
dst_dir: str,
) -> None:
"""Copy a file to the destination directory.
Args:
src_file: file to be copied
dst_dir: destination directory
"""

dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file))
os.makedirs(dest_dir_path, exist_ok=True)
shutil.copy(src_file, dest_dir_path)
os.chmod(os.path.join(dst_dir, src_file), 0o644)


def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
"""Filter the sources and copy them to the destination directory.
Args:
deps: a list of paths to files.
srcs_dir: target directory where files are copied to.
"""

for file in deps:
if not (file.startswith("bazel-out") or file.startswith("external")):
copy_file(file, srcs_dir)


tmpdir = None
sources_path = args.sources_path
if sources_path is None:
tmpdir = tempfile.TemporaryDirectory(prefix="jax")
sources_path = tmpdir.name

try:
os.makedirs(args.output_path, exist_ok=True)
prepare_srcs(args.srcs, pathlib.Path(sources_path))
build_utils.build_wheel(
sources_path,
args.output_path,
package_name="jax",
git_hash=args.jaxlib_git_hash,
)
finally:
if tmpdir:
tmpdir.cleanup()
4 changes: 4 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ config_setting(
exports_files([
"LICENSE",
"version.py",
"py.typed",
"numpy/__init__.pyi",
"_src/basearray.pyi",
])

exports_files(
Expand Down Expand Up @@ -1182,6 +1185,7 @@ pytype_library(
pytype_library(
name = "compilation_cache",
srcs = [
"experimental/compilation_cache/__init__.py",
"experimental/compilation_cache/compilation_cache.py",
],
visibility = ["//visibility:public"],
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/pallas/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ package(

py_library(
name = "core",
srcs = ["core.py"],
srcs = [
"__init__.py",
"core.py",
],
deps = [
"//jax",
"//jax/_src/pallas",
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/pallas/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ package(

pytype_strict_library(
name = "core",
srcs = ["core.py"],
srcs = [
"__init__.py",
"core.py",
],
deps = ["//jax/_src/pallas"],
)

Expand Down
5 changes: 4 additions & 1 deletion jax/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ package(

py_library(
name = "jax_to_ir",
srcs = ["jax_to_ir.py"],
srcs = [
"__init__.py",
"jax_to_ir.py",
],
tags = [
"ignore_for_dep=third_party.py.jax.experimental.jax2tf",
"ignore_for_dep=third_party.py.tensorflow",
Expand Down
38 changes: 32 additions & 6 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def jax_generate_backend_suites(backends = []):
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
)

def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version):
def _get_full_wheel_name(package_name, no_abi, no_platform, platform_name, cpu_name, wheel_version):
if no_abi:
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
else:
Expand All @@ -332,7 +332,9 @@ def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_ve
python_version = python_version,
major_python_version = python_version[0],
wheel_version = wheel_version,
wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]),
wheel_platform_tag = "any" if no_platform else "_".join(
PLATFORM_TAGS_DICT[platform_name, cpu_name],
),
)

def _jax_wheel_impl(ctx):
Expand Down Expand Up @@ -360,10 +362,13 @@ def _jax_wheel_impl(ctx):
env["JAX_RELEASE"] = "1"

cpu = ctx.attr.cpu
no_abi = ctx.attr.no_abi
no_platform = ctx.attr.no_platform
platform_name = ctx.attr.platform_name
wheel_name = _get_full_wheel_name(
package_name = ctx.attr.wheel_name,
no_abi = ctx.attr.no_abi,
no_abi = no_abi,
no_platform = no_platform,
platform_name = platform_name,
cpu_name = cpu,
wheel_version = full_wheel_version,
Expand All @@ -373,7 +378,8 @@ def _jax_wheel_impl(ctx):
wheel_dir = output_file.path[:output_file.path.rfind("/")]

args.add("--output_path", wheel_dir) # required argument
args.add("--cpu", cpu) # required argument
if not no_platform:
args.add("--cpu", cpu)
args.add("--jaxlib_git_hash", git_hash) # required argument

if ctx.attr.enable_cuda:
Expand All @@ -389,11 +395,17 @@ def _jax_wheel_impl(ctx):
if ctx.attr.skip_gpu_kernels:
args.add("--skip_gpu_kernels")

srcs = []
for src in ctx.attr.source_files:
for f in src.files.to_list():
srcs.append(f)
args.add("--srcs=%s" % (f.path))

args.set_param_file_format("flag_per_line")
args.use_param_file("@%s", use_always = False)
ctx.actions.run(
arguments = [args],
inputs = [],
inputs = srcs,
outputs = [output_file],
executable = executable,
env = env,
Expand All @@ -411,9 +423,11 @@ _jax_wheel = rule(
),
"wheel_name": attr.string(mandatory = True),
"no_abi": attr.bool(default = False),
"no_platform": attr.bool(default = False),
"cpu": attr.string(mandatory = True),
"platform_name": attr.string(mandatory = True),
"git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
"source_files": attr.label_list(allow_files = True),
"output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
"enable_cuda": attr.bool(default = False),
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
Expand All @@ -427,7 +441,15 @@ _jax_wheel = rule(
executable = False,
)

def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""):
def jax_wheel(
name,
wheel_binary,
wheel_name,
no_abi = False,
no_platform = False,
enable_cuda = False,
platform_version = "",
source_files = []):
"""Create jax artifact wheels.
Common artifact attributes are grouped within a single macro.
Expand All @@ -437,8 +459,10 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
wheel_binary: the binary to use to build the wheel
wheel_name: the name of the wheel
no_abi: whether to build a wheel without ABI
no_platform: whether to build a wheel without platform tag
enable_cuda: whether to build a cuda wheel
platform_version: the cuda version to use for the wheel
source_files: the source files to include in the wheel
Returns:
A directory containing the wheel
Expand All @@ -448,6 +472,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
wheel_binary = wheel_binary,
wheel_name = wheel_name,
no_abi = no_abi,
no_platform = no_platform,
enable_cuda = enable_cuda,
platform_version = platform_version,
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
Expand All @@ -465,6 +490,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
"//jaxlib/tools:arm64": "aarch64",
"@platforms//cpu:x86_64": "x86_64",
}),
source_files = source_files,
)

jax_test_file_visibility = []
Expand Down

0 comments on commit 740adb5

Please sign in to comment.