Create JAX wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for jaxlib, CUDA and PJRT plugins in jax-ml/jax#25126).

Previously, the JAX wheel was built via the `python3 -m build` command. The resulting wheel contained the Python package files in the `jax` folder (e.g. the files in the subdirectories that have an `__init__.py` file).

You can still build the JAX wheel with the `python3 -m build` command.

Example Bazel command for building a JAX wheel (this one sets `ML_WHEEL_TYPE=custom`, which adds the build date and git hash to the version):

```
bazel build :jax_wheel \
  --config=ci_linux_x86_64 \
  --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
  --repo_env=ML_WHEEL_TYPE=custom \
  --repo_env=ML_WHEEL_BUILD_DATE=20250211 \
  --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)
```

Resulting wheel:
```
bazel-bin/dist/jax-0.5.1.dev20250211+d4f1f2278-py3-none-any.whl
```
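
To make the structure of the resulting filename explicit, here is a minimal Python sketch that splits a wheel filename into its standard components (distribution name, version, optional build tag, Python tag, ABI tag, platform tag). The helper name `parse_wheel_filename` is hypothetical and is not part of this change; it only illustrates where the version suffix ends up.

```python
import os


def parse_wheel_filename(path):
    """Split a wheel filename into its dash-separated components.

    Hypothetical helper for illustration only. It assumes the standard layout
    name-version(-build)-python-abi-platform.whl.
    """
    filename = os.path.basename(path)  # drop directories such as bazel-bin/dist/
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel file: " + filename)
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        name, version, python_tag, abi_tag, platform_tag = parts
        build_tag = None
    elif len(parts) == 6:
        name, version, build_tag, python_tag, abi_tag, platform_tag = parts
    else:
        raise ValueError("unexpected wheel filename: " + filename)
    return {
        "name": name,
        "version": version,
        "build_tag": build_tag,
        "python_tag": python_tag,
        "abi_tag": abi_tag,
        "platform_tag": platform_tag,
    }


info = parse_wheel_filename(
    "bazel-bin/dist/jax-0.5.1.dev20250211+d4f1f2278-py3-none-any.whl"
)
print(info["version"])
```

For the wheel above, the version component carries both the `.dev` build date and the `+` git hash segment.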

PiperOrigin-RevId: 724102315
tensorflower-gardener authored and copybara-github committed Feb 14, 2025
1 parent e422cba commit 8caf13a
Showing 1 changed file with 123 additions and 33 deletions.
156 changes: 123 additions & 33 deletions third_party/py/python_wheel.bzl
@@ -1,36 +1,4 @@
"""
Python wheel repository rules.
The calculated wheel version suffix depends on the wheel type:
- nightly: .dev{build_date}
- release: ({custom_version_suffix})?
- custom: .dev{build_date}(+{git_hash})?({custom_version_suffix})?
- snapshot (default): -0
The following environment variables can be set:
{wheel_type}: ML_WHEEL_TYPE
{build_date}: ML_WHEEL_BUILD_DATE (should be YYYYMMDD or YYYY-MM-DD)
{git_hash}: ML_WHEEL_GIT_HASH
{custom_version_suffix}: ML_WHEEL_VERSION_SUFFIX
Examples:
1. nightly wheel version: 2.19.0.dev20250107
   Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=nightly
                                     --repo_env=ML_WHEEL_BUILD_DATE=20250107
2. release wheel version: 2.19.0
   Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=release
3. release candidate wheel version: 2.19.0-rc1
   Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=release
                                     --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1
4. custom wheel version: 2.19.0.dev20250107+cbe478fc5-custom
   Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=custom
                                     --repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD)
                                     --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)
                                     --repo_env=ML_WHEEL_VERSION_SUFFIX=-custom
5. snapshot wheel version: 2.19.0-0
   Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=snapshot
"""
""" Repository and build rules for Python wheels packaging utilities. """

def _get_host_environ(repository_ctx, name, default_value = None):
    """Returns the value of an environment variable on the host platform.
@@ -127,3 +95,125 @@ python_wheel_version_suffix_repository = repository_rule(
    implementation = _python_wheel_version_suffix_repository_impl,
    environ = _ENVIRONS,
)

""" Repository rule for storing Python wheel filename version suffix.
The calculated wheel version suffix depends on the wheel type:
- nightly: .dev{build_date}
- release: ({custom_version_suffix})?
- custom: .dev{build_date}(+{git_hash})?({custom_version_suffix})?
- snapshot (default): -0
The following environment variables can be set:
{wheel_type}: ML_WHEEL_TYPE
{build_date}: ML_WHEEL_BUILD_DATE (should be YYYYMMDD or YYYY-MM-DD)
{git_hash}: ML_WHEEL_GIT_HASH
{custom_version_suffix}: ML_WHEEL_VERSION_SUFFIX
Examples:
1. nightly wheel version: 2.19.0.dev20250107
   Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=nightly
                                     --repo_env=ML_WHEEL_BUILD_DATE=20250107
2. release wheel version: 2.19.0
   Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=release
3. release candidate wheel version: 2.19.0-rc1
   Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=release
                                     --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1
4. custom wheel version: 2.19.0.dev20250107+cbe478fc5-custom
   Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=custom
                                     --repo_env=ML_WHEEL_BUILD_DATE=$(git show -s --format=%as HEAD)
                                     --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)
                                     --repo_env=ML_WHEEL_VERSION_SUFFIX=-custom
5. snapshot wheel version: 2.19.0-0
   Env vars passed to Bazel command: --repo_env=ML_WHEEL_TYPE=snapshot
""" # buildifier: disable=no-effect
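
The suffix rules in the docstring above can be modeled in a few lines of plain Python. This is only a sketch of the documented behavior, not the repository rule's implementation: the function name `wheel_version_suffix` and its parameters are hypothetical, and two details are assumptions inferred from the docstring examples (dashes are stripped from the build date, and the git hash is shortened to 9 characters).

```python
def wheel_version_suffix(wheel_type="snapshot", build_date="", git_hash="", custom_suffix=""):
    """Model of the documented suffix rules (hypothetical helper, not the real rule)."""
    # Assumption: YYYY-MM-DD is normalized to YYYYMMDD, as example 4 suggests.
    date = build_date.replace("-", "")
    if wheel_type == "nightly":
        return ".dev" + date
    if wheel_type == "release":
        return custom_suffix
    if wheel_type == "custom":
        suffix = ".dev" + date
        if git_hash:
            # Assumption: the hash is cut to 9 characters, as in the examples.
            suffix += "+" + git_hash[:9]
        return suffix + custom_suffix
    if wheel_type == "snapshot":
        return "-0"
    raise ValueError("unknown wheel type: " + wheel_type)


base = "2.19.0"
full_hash = "cbe478fc5" + "0" * 31  # stand-in for a 40-character `git rev-parse HEAD`
print(base + wheel_version_suffix("nightly", "20250107"))
print(base + wheel_version_suffix("release", custom_suffix="-rc1"))
print(base + wheel_version_suffix("custom", "2025-01-07", full_hash, "-custom"))
print(base + wheel_version_suffix())
```

Under those assumptions, the four calls reproduce examples 1, 3, 4 and 5 from the docstring.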

def _transitive_py_deps_impl(ctx):
    outputs = depset(
        [],
        transitive = [dep[PyInfo].transitive_sources for dep in ctx.attr.deps],
    )

    return DefaultInfo(files = outputs)

_transitive_py_deps = rule(
    attrs = {
        "deps": attr.label_list(
            allow_files = True,
            providers = [PyInfo],
        ),
    },
    implementation = _transitive_py_deps_impl,
)

def transitive_py_deps(name, deps = []):
    _transitive_py_deps(name = name + "_gather", deps = deps)
    native.filegroup(name = name, srcs = [":" + name + "_gather"])

"""Collects python files that a target depends on.
It traverses dependencies of provided targets, collects their direct and
transitive python deps and then returns a list of paths to files.
""" # buildifier: disable=no-effect
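
As a rough mental model of what `transitive_py_deps` gathers, the following plain-Python sketch walks a toy dependency graph and collects each target's sources together with those of everything it depends on. The graph layout, the target names and the helpers `transitive_sources` and `gather` are all hypothetical; in Bazel the same information comes from `PyInfo.transitive_sources` rather than from an explicit traversal.

```python
def transitive_sources(target, graph, seen):
    """Return the sources of `target` and of every target it depends on."""
    if target in seen:  # each target contributes its sources only once
        return []
    seen.add(target)
    sources = list(graph[target]["srcs"])
    for dep in graph[target]["deps"]:
        sources.extend(transitive_sources(dep, graph, seen))
    return sources


def gather(deps, graph):
    """Collect the transitive sources of all targets listed in `deps`."""
    seen = set()
    files = []
    for dep in deps:
        files.extend(transitive_sources(dep, graph, seen))
    return files


# Hypothetical targets: ":a" depends on ":b" and ":c", and ":b" depends on ":c".
graph = {
    ":a": {"srcs": ["a.py"], "deps": [":b", ":c"]},
    ":b": {"srcs": ["b.py"], "deps": [":c"]},
    ":c": {"srcs": ["c.py"], "deps": []},
}
print(sorted(gather([":a"], graph)))
```

The shared `seen` set plays the role that depset deduplication plays in the real rule: `c.py` is reported once even though two targets depend on `:c`.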

FilePathInfo = provider(
    "Returns path of selected files.",
    fields = {
        "files": "requested files from data attribute",
    },
)

def _collect_data_aspect_impl(_, ctx):
    files = {}
    extensions = ctx.attr._extensions
    if hasattr(ctx.rule.attr, "data"):
        for data in ctx.rule.attr.data:
            for f in data.files.to_list():
                if not any([f.path.endswith(ext) for ext in extensions]):
                    continue
                if "pypi" in f.path:
                    continue
                files[f] = True

    if hasattr(ctx.rule.attr, "deps"):
        for dep in ctx.rule.attr.deps:
            if dep[FilePathInfo].files:
                for file in dep[FilePathInfo].files.to_list():
                    files[file] = True

    return [FilePathInfo(files = depset(files.keys()))]

collect_data_aspect = aspect(
    implementation = _collect_data_aspect_impl,
    attr_aspects = ["deps"],
    attrs = {
        "_extensions": attr.string_list(
            default = [".so", ".pyd", ".pyi", ".dll", ".dylib", ".lib", ".pd"],
        ),
    },
)

def _collect_data_files_impl(ctx):
    files = []
    for dep in ctx.attr.deps:
        files.extend((dep[FilePathInfo].files.to_list()))
    return [DefaultInfo(files = depset(
        files,
    ))]

collect_data_files = rule(
    implementation = _collect_data_files_impl,
    attrs = {
        "deps": attr.label_list(
            aspects = [collect_data_aspect],
        ),
    },
)

"""Rule to collect data files.
It recursively traverses the `deps` attribute of the target and collects paths to
files that are in the `data` attribute. Then it filters out all files that do not match
the provided extensions.
""" # buildifier: disable=no-effect
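
The per-file filter applied by `collect_data_aspect` can be illustrated in plain Python. This sketch mirrors the two checks in `_collect_data_aspect_impl` (keep only paths ending in one of the listed extensions, and skip any path containing `pypi`); the function name `select_data_files` and the sample paths are hypothetical.

```python
# Same extension list as the aspect's `_extensions` default.
EXTENSIONS = [".so", ".pyd", ".pyi", ".dll", ".dylib", ".lib", ".pd"]


def select_data_files(paths, extensions=EXTENSIONS):
    """Keep paths with a matching extension, skipping anything under a pypi path."""
    selected = {}  # dict used as an ordered set, like `files` in the aspect
    for path in paths:
        if not any([path.endswith(ext) for ext in extensions]):
            continue
        if "pypi" in path:
            continue
        selected[path] = True
    return list(selected.keys())


# Hypothetical input paths, for illustration only.
paths = [
    "pkg/_ext.so",
    "pkg/module.py",
    "external/pypi_numpy/site-packages/numpy/core.so",
    "pkg/stubs.pyi",
    "pkg/_ext.so",
]
print(select_data_files(paths))
```

The `.py` file is dropped because its extension is not in the list, the `pypi` path is skipped even though it ends in `.so`, and the duplicate entry is collapsed.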
