Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor JAX wheel build rules to control the wheel filename and maintain reproducible wheel content and filename results. #25126

Merged
merged 1 commit into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ xla_workspace0()
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()

load("//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository")
jax_python_wheel_repository(
name = "jax_wheel",
version_key = "_version",
version_source = "//jax:version.py",
)

load(
"@tsl//third_party/py:python_wheel.bzl",
"python_wheel_version_suffix_repository",
)
python_wheel_version_suffix_repository(
name = "jax_wheel_version_suffix",
)

load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
Expand Down
6 changes: 0 additions & 6 deletions jax/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
"py_deps",
"pytype_strict_library",
)

licenses(["notice"])
Expand Down Expand Up @@ -46,8 +45,3 @@ py_library(
"//jax/experimental/jax2tf",
] + py_deps("tensorflow_core"),
)

pytype_strict_library(
name = "build_utils",
srcs = ["build_utils.py"],
)
15 changes: 12 additions & 3 deletions jax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def _get_version_string() -> str:
# In this case we return it directly.
if _release_version is not None:
return _release_version
if os.getenv("WHEEL_VERSION_SUFFIX"):
return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
return _version_from_git_tree(_version) or _version_from_todays_date(_version)


Expand Down Expand Up @@ -71,16 +73,23 @@ def _get_version_for_build() -> str:
"""Determine the version at build time.

The returned version string depends on which environment variables are set:
- if WHEEL_VERSION_SUFFIX is set: version looks like "0.5.1.dev20230906+ge58560fdc"
Here the WHEEL_VERSION_SUFFIX value is ".dev20230906+ge58560fdc".
Please note that the WHEEL_VERSION_SUFFIX value is not the same as the
JAX_CUSTOM_VERSION_SUFFIX value, and WHEEL_VERSION_SUFFIX is set by Bazel
wheel build rule.
- if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16"
- if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906"
- if none are set: version looks like "0.4.16.dev20230906+ge58560fdc
"""
if _release_version is not None:
return _release_version
if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'):
return _version_from_todays_date(_version)
if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'):
if os.getenv("WHEEL_VERSION_SUFFIX"):
return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
if os.getenv("JAX_RELEASE") or os.getenv("JAXLIB_RELEASE"):
return _version
if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"):
return _version_from_todays_date(_version)
return _version_from_git_tree(_version) or _version_from_todays_date(_version)


Expand Down
109 changes: 89 additions & 20 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

"""Bazel macros used by the JAX build."""

load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION")
load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "BUILD_TAG", "WHEEL_VERSION_SUFFIX")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
Expand Down Expand Up @@ -50,6 +53,15 @@ jax_internal_test_harnesses_visibility = []
jax_test_util_visibility = []
loops_visibility = []

PLATFORM_TAGS_DICT = {
("Linux", "x86_64"): ("manylinux2014", "x86_64"),
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
("Windows", "AMD64"): ("win", "amd64"),
}

# TODO(vam): remove this once zstandard builds against Python 3.13
def get_zstandard():
if HERMETIC_PYTHON_VERSION == "3.13":
Expand Down Expand Up @@ -268,7 +280,7 @@ def jax_multiplatform_test(
]
test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]):
test_tags += ["manual"]
test_tags.append("manual")
if backend == "gpu":
test_tags += tf_cuda_tests_tags()
native.py_test(
Expand Down Expand Up @@ -309,15 +321,60 @@ def jax_generate_backend_suites(backends = []):
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
)

def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version):
if no_abi:
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
else:
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
python_version = HERMETIC_PYTHON_VERSION.replace(".", "")
return wheel_name_template.format(
package_name = package_name,
python_version = python_version,
major_python_version = python_version[0],
wheel_version = wheel_version,
wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]),
)

def _jax_wheel_impl(ctx):
include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value
override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value
output_path = ctx.attr.output_path[BuildSettingInfo].value
git_hash = ctx.attr.git_hash[BuildSettingInfo].value
executable = ctx.executable.wheel_binary

output = ctx.actions.declare_directory(ctx.label.name)
if include_cuda_libs and not override_include_cuda_libs:
fail("JAX wheel shouldn't be built directly against the CUDA libraries." +
" Please provide `--config=cuda_libraries_from_stubs` for bazel build command." +
" If you absolutely need to build links directly against the CUDA libraries, provide" +
" `--@local_config_cuda//cuda:override_include_cuda_libs=true`.")

env = {}
args = ctx.actions.args()
args.add("--output_path", output.path) # required argument
args.add("--cpu", ctx.attr.platform_tag) # required argument
jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path
args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument

full_wheel_version = (WHEEL_VERSION + WHEEL_VERSION_SUFFIX)
env["WHEEL_VERSION_SUFFIX"] = WHEEL_VERSION_SUFFIX
if BUILD_TAG:
env["WHEEL_VERSION_SUFFIX"] = ".dev{}+selfbuilt".format(BUILD_TAG)
full_wheel_version += env["WHEEL_VERSION_SUFFIX"]
if not WHEEL_VERSION_SUFFIX and not BUILD_TAG:
env["JAX_RELEASE"] = "1"

cpu = ctx.attr.cpu
platform_name = ctx.attr.platform_name
wheel_name = _get_full_wheel_name(
package_name = ctx.attr.wheel_name,
no_abi = ctx.attr.no_abi,
platform_name = platform_name,
cpu_name = cpu,
wheel_version = full_wheel_version,
)
output_file = ctx.actions.declare_file(output_path +
"/" + wheel_name)
wheel_dir = output_file.path[:output_file.path.rfind("/")]

args.add("--output_path", wheel_dir) # required argument
args.add("--cpu", cpu) # required argument
args.add("--jaxlib_git_hash", git_hash) # required argument

if ctx.attr.enable_cuda:
args.add("--enable-cuda", "True")
Expand All @@ -336,11 +393,13 @@ def _jax_wheel_impl(ctx):
args.use_param_file("@%s", use_always = False)
ctx.actions.run(
arguments = [args],
inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [],
outputs = [output],
inputs = [],
outputs = [output_file],
executable = executable,
env = env,
)
return [DefaultInfo(files = depset(direct = [output]))]

return [DefaultInfo(files = depset(direct = [output_file]))]

_jax_wheel = rule(
attrs = {
Expand All @@ -350,26 +409,34 @@ _jax_wheel = rule(
# b/365588895 Investigate cfg = "exec" for multi platform builds
cfg = "target",
),
"platform_tag": attr.string(mandatory = True),
"git_hash": attr.label(allow_single_file = True),
"wheel_name": attr.string(mandatory = True),
"no_abi": attr.bool(default = False),
"cpu": attr.string(mandatory = True),
"platform_name": attr.string(mandatory = True),
"git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
"output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
"enable_cuda": attr.bool(default = False),
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
"platform_version": attr.string(mandatory = True, default = ""),
"skip_gpu_kernels": attr.bool(default = False),
"enable_rocm": attr.bool(default = False),
"include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
"override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
},
implementation = _jax_wheel_impl,
executable = False,
)

def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""):
"""Create jax artifact wheels.

Common artifact attributes are grouped within a single macro.

Args:
name: the name of the wheel
wheel_binary: the binary to use to build the wheel
wheel_name: the name of the wheel
no_abi: whether to build a wheel without ABI
enable_cuda: whether to build a cuda wheel
platform_version: the cuda version to use for the wheel

Expand All @@ -379,18 +446,20 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
_jax_wheel(
name = name,
wheel_binary = wheel_binary,
wheel_name = wheel_name,
no_abi = no_abi,
enable_cuda = enable_cuda,
platform_version = platform_version,
# Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to
# pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to
# the git hash file needs to be created first.
git_hash = select({
"//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink",
"//conditions:default": None,
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
# flag in bazel command to pass the git hash for nightly or release builds.
platform_name = select({
"@platforms//os:osx": "Darwin",
"@platforms//os:macos": "Darwin",
"@platforms//os:windows": "Windows",
"@platforms//os:linux": "Linux",
}),
# Following the convention in jax/tools/build_utils.py.
# TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
platform_tag = select({
cpu = select({
"//jaxlib/tools:macos_arm64": "arm64",
"//jaxlib/tools:win_amd64": "AMD64",
"//jaxlib/tools:arm64": "aarch64",
Expand Down
43 changes: 43 additions & 0 deletions jaxlib/jax_python_wheel.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Repository rule to generate a file with JAX wheel version. """

def _jax_python_wheel_repository_impl(repository_ctx):
version_source = repository_ctx.attr.version_source
version_key = repository_ctx.attr.version_key

version_file_content = repository_ctx.read(
repository_ctx.path(version_source),
)
version_start_index = version_file_content.find(version_key)
version_end_index = version_start_index + version_file_content[version_start_index:].find("\n")

wheel_version = version_file_content[version_start_index:version_end_index].replace(
version_key,
"WHEEL_VERSION",
)
repository_ctx.file(
"wheel.bzl",
wheel_version,
)
repository_ctx.file("BUILD", "")

jax_python_wheel_repository = repository_rule(
implementation = _jax_python_wheel_repository_impl,
attrs = {
"version_source": attr.label(mandatory = True, allow_single_file = True),
"version_key": attr.string(mandatory = True),
},
)
Loading
Loading