Skip to content

Commit

Permalink
Create JAX wheel build target.
Browse files Browse the repository at this point in the history
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for jaxlib, CUDA and PJRT plugins in #25126)

Previously JAX wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the JAX wheel with `python3 -m build` command.

Bazel command example for building nightly JAX wheel:

```
bazel build :jax_wheel \
  --config=ci_linux_x86_64 \
  --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
  --repo_env=ML_WHEEL_TYPE=custom \
  --repo_env=ML_WHEEL_BUILD_DATE=20250211 \
  --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)
```

Resulting wheel:
```
bazel-bin/dist/jax-0.5.1.dev20250211+d4f1f2278-py3-none-any.whl
```

PiperOrigin-RevId: 724102315
  • Loading branch information
Google-ML-Automation committed Feb 13, 2025
1 parent 876668f commit 2eff3e1
Show file tree
Hide file tree
Showing 13 changed files with 383 additions and 16 deletions.
93 changes: 93 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load(
"@tsl//third_party/py:py_manylinux_compliance_test.bzl",
"verify_manylinux_compliance_test",
)
load("//jax:py_deps.bzl", "transitive_py_deps")
load(
"//jaxlib:jax.bzl",
"AARCH64_MANYLINUX_TAG",
"PPC64LE_MANYLINUX_TAG",
"X86_64_MANYLINUX_TAG",
"jax_wheel",
)

transitive_py_deps(
name = "transitive_py_deps",
deps = [
"//jax",
"//jax:experimental",
"//jax:experimental_libs",
"//jax:lax_reference",
"//jax:pallas_mosaic_gpu",
"//jax:tpu_custom_call",
"//jax/_src/lib",
"//jax/_src/pallas/mosaic:libs",
"//jax/_src/pallas/mosaic_gpu",
"//jax/_src/pallas/triton:libs",
"//jax/experimental/jax2tf",
"//jax/experimental/jax2tf:jax2tf_libs",
"//jax/experimental/pallas/ops/gpu:libs",
"//jax/extend",
"//jax/extend:ifrt_programs",
"//jax/extend/mlir:libs",
"//jax/extend/mlir/dialects:libs",
"//jax/tools:libs",
],
)

py_binary(
name = "build_wheel",
srcs = ["build_wheel.py"],
deps = [
"//jaxlib/tools:build_utils",
"@pypi_build//:pkg",
"@pypi_setuptools//:pkg",
"@pypi_wheel//:pkg",
],
)

jax_wheel(
name = "jax_wheel",
no_abi = True,
no_platform = True,
source_files = [
":transitive_py_deps",
"//jax:py.typed",
"//jax:numpy/__init__.pyi",
"//jax:_src/basearray.pyi",
"//jax:test_util_sources",
"//jax:internal_test_util_sources",
"AUTHORS",
"LICENSE",
"README.md",
"pyproject.toml",
"setup.py",
],
wheel_binary = ":build_wheel",
wheel_name = "jax",
)

verify_manylinux_compliance_test(
name = "jax_manylinux_compliance_test",
aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
test_tags = [
"manual",
],
wheel = ":jax_wheel",
x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)
100 changes: 100 additions & 0 deletions build_wheel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Script that builds a JAX wheel, intended to be run via bazel run as part
# of the JAX build process.

import argparse
import os
import pathlib
import shutil
import tempfile

from jaxlib.tools import build_utils

parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument(
"--sources_path",
default=None,
help=(
"Path in which the wheel's sources should be prepared. Optional. If "
"omitted, a temporary directory will be used."
),
)
parser.add_argument(
"--output_path",
default=None,
required=True,
help="Path to which the output wheel should be written. Required.",
)
parser.add_argument(
"--jaxlib_git_hash",
default="",
required=True,
help="Git hash. Empty if unknown. Optional.",
)
parser.add_argument(
"--srcs", help="source files for the wheel", action="append"
)
args = parser.parse_args()


def copy_file(
src_file: str,
dst_dir: str,
) -> None:
"""Copy a file to the destination directory.
Args:
src_file: file to be copied
dst_dir: destination directory
"""

dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file))
os.makedirs(dest_dir_path, exist_ok=True)
shutil.copy(src_file, dest_dir_path)
os.chmod(os.path.join(dst_dir, src_file), 0o644)


def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
"""Rearrange source files in target the target directory.
Args:
deps: a list of paths to files.
srcs_dir: target directory where files are copied to.
"""

for file in deps:
if not (file.startswith("bazel-out") or file.startswith("external")):
copy_file(file, srcs_dir)


tmpdir = None
sources_path = args.sources_path
if sources_path is None:
tmpdir = tempfile.TemporaryDirectory(prefix="jax")
sources_path = tmpdir.name

try:
os.makedirs(args.output_path, exist_ok=True)
prepare_srcs(args.srcs, pathlib.Path(sources_path))
build_utils.build_wheel(
sources_path,
args.output_path,
package_name="jax",
git_hash=args.jaxlib_git_hash,
)
finally:
if tmpdir:
tmpdir.cleanup()
39 changes: 35 additions & 4 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ config_setting(
exports_files([
"LICENSE",
"version.py",
"py.typed",
"numpy/__init__.pyi",
"_src/basearray.pyi",
])

exports_files(
Expand Down Expand Up @@ -110,6 +113,37 @@ package_group(
packages = mosaic_gpu_internal_users,
)

py_library(
name = "experimental_libs",
srcs = glob([
"experimental/**/*.py",
"experimental/pallas/ops/gpu/**/*.py",
]),
)

filegroup(
name = "test_util_sources",
srcs = [
"_src/test_util.py",
"_src/test_warning_util.py",
],
)

filegroup(
name = "internal_test_util_sources",
srcs = [
"_src/internal_test_util/__init__.py",
"_src/internal_test_util/deprecation_module.py",
"_src/internal_test_util/export_back_compat_test_util.py",
"_src/internal_test_util/lax_test_util.py",
"_src/internal_test_util/test_harnesses.py",
] + glob(
[
"_src/internal_test_util/lazy_loader_module/*.py",
],
),
)

# JAX-private test utilities.
py_library(
# This build target is required in order to use private test utilities in jax._src.test_util,
Expand All @@ -118,10 +152,7 @@ py_library(
# these are available in jax.test_util via the standard :jax target.
name = "test_util",
testonly = 1,
srcs = [
"_src/test_util.py",
"_src/test_warning_util.py",
],
srcs = [":test_util_sources"],
visibility = [
":internal",
] + jax_test_util_visibility,
Expand Down
17 changes: 17 additions & 0 deletions jax/_src/pallas/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,23 @@ package(
],
)

py_library(
name = "libs",
srcs = ["__init__.py"],
deps = [
":core",
":error_handling",
":helpers",
":interpret",
":lowering",
":pallas_call_registration",
":pipeline",
":primitives",
":random",
":verification",
],
)

py_library(
name = "core",
srcs = ["core.py"],
Expand Down
11 changes: 11 additions & 0 deletions jax/_src/pallas/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ package(
],
)

pytype_strict_library(
name = "libs",
srcs = ["__init__.py"],
deps = [
":core",
":lowering",
":pallas_call_registration",
":primitives",
],
)

pytype_strict_library(
name = "core",
srcs = ["core.py"],
Expand Down
9 changes: 9 additions & 0 deletions jax/experimental/jax2tf/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ package(
default_visibility = ["//visibility:private"],
)

py_library(
name = "jax2tf_libs",
srcs = glob([
"examples/**/*.py",
"tests/**/*.py",
]),
visibility = ["//visibility:public"],
)

py_library(
name = "jax2tf",
srcs = ["__init__.py"],
Expand Down
5 changes: 5 additions & 0 deletions jax/experimental/pallas/ops/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ exports_files(
srcs = glob(["*.py"]),
)

py_library(
name = "libs",
srcs = glob(["*.py"]),
)

filegroup(
name = "triton_ops",
srcs = glob(
Expand Down
9 changes: 9 additions & 0 deletions jax/extend/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ package(
default_visibility = ["//jax:jax_extend_users"],
)

pytype_strict_library(
name = "libs",
srcs = ["__init__.py"],
deps = [
":ir",
":pass_manager",
],
)

pytype_strict_library(
name = "ir",
srcs = ["ir.py"],
Expand Down
18 changes: 18 additions & 0 deletions jax/extend/mlir/dialects/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,24 @@ package(
default_visibility = ["//jax:jax_extend_users"],
)

pytype_strict_library(
name = "libs",
srcs = ["__init__.py"],
deps = [
":arithmetic_dialect",
":builtin_dialect",
":chlo_dialect",
":func_dialect",
":math_dialect",
":memref_dialect",
":scf_dialect",
":sdy_dialect",
":sparse_tensor_dialect",
":stablehlo_dialect",
":vector_dialect",
],
)

pytype_strict_library(
name = "arithmetic_dialect",
srcs = ["arith.py"],
Expand Down
Loading

0 comments on commit 2eff3e1

Please sign in to comment.