Skip to content

Commit

Permalink
Create JAX wheel build target.
Browse files Browse the repository at this point in the history
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for jaxlib, CUDA and PJRT plugins in jax-ml/jax#25126)

Previously JAX wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the JAX wheel with `python3 -m build` command.

Bazel command example for building nightly JAX wheel:

```
bazel build :jax_wheel \
  --config=ci_linux_x86_64 \
  --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
  --repo_env=ML_WHEEL_TYPE=custom \
  --repo_env=ML_WHEEL_BUILD_DATE=20250211 \
  --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)
```

Resulting wheel:
```
bazel-bin/dist/jax-0.5.1.dev20250211+d4f1f2278-py3-none-any.whl
```

PiperOrigin-RevId: 724102315
  • Loading branch information
Google-ML-Automation committed Feb 14, 2025
1 parent 6da089d commit 98b4d66
Showing 1 changed file with 62 additions and 33 deletions.
95 changes: 62 additions & 33 deletions third_party/tsl/third_party/py/python_wheel.bzl
Original file line number Diff line number Diff line change
@@ -1,36 +1,4 @@
"""
Python wheel repository rules.
The calculated wheel version suffix depends on the wheel type:
- nightly: .dev{build_date}
- release: ({custom_version_suffix})?
- custom: .dev{build_date}(+{git_hash})?({custom_version_suffix})?
- snapshot (default): -0
The following environment variables can be set:
{wheel_type}: ML_WHEEL_TYPE
{build_date}: ML_WHEEL_BUILD_DATE (should be YYYYMMDD or YYYY-MM-DD)
{git_hash}: ML_WHEEL_GIT_HASH
{custom_version_suffix}: ML_WHEEL_VERSION_SUFFIX
Examples:
1. nightly wheel version: 2.19.0.dev20250107
Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=nightly
--repo_env=ML_WHEEL_BUILD_DATE=20250107
2. release wheel version: 2.19.0
Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=release
3. release candidate wheel version: 2.19.0-rc1
Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=release
--repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1
4. custom wheel version: 2.19.0.dev20250107+cbe478fc5-custom
Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=custom
--repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD)
--repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)
--repo_env=ML_WHEEL_VERSION_SUFFIX=-custom
5. snapshot wheel version: 2.19.0-0
Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=snapshot
"""
""" Repository and build rules for Python wheels packaging utilities. """

def _get_host_environ(repository_ctx, name, default_value = None):
"""Returns the value of an environment variable on the host platform.
Expand Down Expand Up @@ -127,3 +95,64 @@ python_wheel_version_suffix_repository = repository_rule(
implementation = _python_wheel_version_suffix_repository_impl,
environ = _ENVIRONS,
)

""" Repository rule for storing Python wheel filename version suffix.
The calculated wheel version suffix depends on the wheel type:
- nightly: .dev{build_date}
- release: ({custom_version_suffix})?
- custom: .dev{build_date}(+{git_hash})?({custom_version_suffix})?
- snapshot (default): -0
The following environment variables can be set:
{wheel_type}: ML_WHEEL_TYPE
{build_date}: ML_WHEEL_BUILD_DATE (should be YYYYMMDD or YYYY-MM-DD)
{git_hash}: ML_WHEEL_GIT_HASH
{custom_version_suffix}: ML_WHEEL_VERSION_SUFFIX
Examples:
1. nightly wheel version: 2.19.0.dev20250107
Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=nightly
--repo_env=ML_WHEEL_BUILD_DATE=20250107
2. release wheel version: 2.19.0
Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=release
3. release candidate wheel version: 2.19.0-rc1
Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=release
--repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1
4. custom wheel version: 2.19.0.dev20250107+cbe478fc5-custom
Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=custom
--repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD)
--repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)
--repo_env=ML_WHEEL_VERSION_SUFFIX=-custom
5. snapshot wheel version: 2.19.0-0
Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=snapshot
""" # buildifier: disable=no-effect

def _transitive_py_deps_impl(ctx):
outputs = depset(
[],
transitive = [dep[PyInfo].transitive_sources for dep in ctx.attr.deps],
)

return DefaultInfo(files = outputs)

_transitive_py_deps = rule(
attrs = {
"deps": attr.label_list(
allow_files = True,
providers = [PyInfo],
),
},
implementation = _transitive_py_deps_impl,
)

def transitive_py_deps(name, deps = []):
_transitive_py_deps(name = name + "_gather", deps = deps)
native.filegroup(name = name, srcs = [":" + name + "_gather"])

"""Collects python files that a target depends on.
It traverses dependencies of provided targets, collect their direct and
transitive python deps and then return a list of paths to files.
""" # buildifier: disable=no-effect

0 comments on commit 98b4d66

Please sign in to comment.