Skip to content

Commit

Permalink
Refactor JAX build wheel rule and add py_import targets.
Browse files Browse the repository at this point in the history
This change is a part of the initiative to test the JAX wheels in the presubmit properly.

The current setup is designed for postsubmit only, it consists of running two commands for producing the wheels (`bazel build` and `bazel run`), then launching docker, installing the wheels in venv, and then running bazel tests with disabled `build_jaxlib` flag.

The new JAX wheel build rule produces the wheel in the Build phase using `bazel build` command only. That means that the JAX wheel targets can be added as dependencies in other targets in Build phase.

This is a pre-requisite for running bazel tests with disabled `build_jaxlib` flag using one command only, without the need to build the wheels separately.

The list of the changes:
1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.

2. The version of the wheel in the build rule output depends on the rule attribute values.

3. Flag combinations for creating wheels with different versions:
  * snapshot: default build rule behavior (`--wheel_type=snapshot`)
  * release: `--wheel_type=release`
  * nightly build with date as version suffix: `--wheel_type=nightly --build_date=<YYYYmmdd>`
  * build with git data as version suffix: `--build_date=$(git show -s --format=%as HEAD) --git_hash=$(git rev-parse HEAD)`
  * build with git data and additional custom version suffix: `--build_date=$(git show -s --format=%as HEAD) --git_hash=$(git rev-parse HEAD) --custom_version_suffix=<custom suffix>`

PiperOrigin-RevId: 699315679
  • Loading branch information
Google-ML-Automation committed Jan 17, 2025
1 parent 318764b commit 677cd8e
Show file tree
Hide file tree
Showing 11 changed files with 251 additions and 93 deletions.
5 changes: 3 additions & 2 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,10 @@ build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
# Default hermetic CUDA and CUDNN versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# This flag is needed to include CUDA libraries for bazel tests.
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true
# This configuration is used for building the wheels.
build:cuda_wheel --@local_config_cuda//cuda:include_cuda_libs=false

# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
Expand Down
15 changes: 15 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ xla_workspace0()
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()

load("//jaxlib:jax_python_wheel_version.bzl", "jax_python_wheel_version_repository")
jax_python_wheel_version_repository(
name = "jax_wheel_version",
file_with_version = "//jax:version.py",
version_key = "_version",
)

load(
"@tsl//third_party/py:python_wheel_version_suffix.bzl",
"python_wheel_version_suffix_repository",
)
python_wheel_version_suffix_repository(
name = "jax_wheel_version_suffix",
)

load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
Expand Down
1 change: 1 addition & 0 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ async def main():

if "cuda" in args.wheels:
wheel_build_command_base.append("--config=cuda")
wheel_build_command_base.append("--config=cuda_wheel")
if args.use_clang:
wheel_build_command_base.append(
f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
Expand Down
5 changes: 0 additions & 5 deletions jax/tools/build_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,6 @@ def platform_tag(cpu: str) -> str:
}[(platform.system(), cpu)]
return f"{platform_name}_{cpu_name}"

def get_githash(jaxlib_git_hash):
if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash):
with open(jaxlib_git_hash, "r") as f:
return f.readline().strip()
return jaxlib_git_hash

def build_wheel(
sources_path: str, output_path: str, package_name: str, git_hash: str = ""
Expand Down
16 changes: 13 additions & 3 deletions jax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def _get_version_string() -> str:
# In this case we return it directly.
if _release_version is not None:
return _release_version
if os.getenv("WHEEL_BUILD_TAG"):
return _version
if os.getenv("WHEEL_VERSION_SUFFIX"):
return _version + os.getenv("WHEEL_VERSION_SUFFIX")
return _version_from_git_tree(_version) or _version_from_todays_date(_version)


Expand Down Expand Up @@ -77,10 +81,16 @@ def _get_version_for_build() -> str:
"""
if _release_version is not None:
return _release_version
if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'):
return _version_from_todays_date(_version)
if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'):
if (
os.getenv("JAX_RELEASE")
or os.getenv("JAXLIB_RELEASE")
or os.getenv("WHEEL_BUILD_TAG")
):
return _version
if os.getenv("WHEEL_VERSION_SUFFIX"):
return _version + os.getenv("WHEEL_VERSION_SUFFIX")
if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"):
return _version_from_todays_date(_version)
return _version_from_git_tree(_version) or _version_from_todays_date(_version)


Expand Down
128 changes: 108 additions & 20 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

"""Bazel macros used by the JAX build."""

load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
load("@jax_wheel_version//:wheel_version.bzl", "WHEEL_VERSION")
load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "BUILD_TAG", "WHEEL_VERSION_SUFFIX")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
Expand Down Expand Up @@ -267,7 +270,7 @@ def jax_multiplatform_test(
]
test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]):
test_tags += ["manual"]
test_tags.append("manual")
if backend == "gpu":
test_tags += tf_cuda_tests_tags()
native.py_test(
Expand Down Expand Up @@ -308,15 +311,91 @@ def jax_generate_backend_suites(backends = []):
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
)

def _get_wheel_platform_name(platform_name, cpu_name):
platform = ""
cpu = ""
if platform_name == "linux":
platform = "manylinux2014"
cpu = cpu_name
elif platform_name == "macosx":
if cpu_name == "arm64":
cpu = "arm64"
platform = "macosx_11_0"
else:
cpu = "x86_64"
platform = "macosx_10_14"
elif platform_name == "win":
platform = "win"
cpu = "amd64"
return "{platform}_{cpu}".format(
platform = platform,
cpu = cpu,
)

def _get_cpu(platform_name, platform_tag):
# Following the convention in jax/tools/build_utils.py.
if platform_name == "macosx" and platform_tag == "arm64":
return "arm64"
if platform_name == "win" and platform_tag == "x86_64":
return "AMD64"
return "aarch64" if platform_tag == "arm64" else platform_tag

def _get_full_wheel_name(rule_name, platform_name, cpu_name, major_cuda_version, wheel_version):
if "pjrt" in rule_name:
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
else:
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
python_version = HERMETIC_PYTHON_VERSION.replace(".", "")
package_name = rule_name.replace("_wheel", "").replace(
"cuda",
"cuda{}".format(major_cuda_version),
)
return wheel_name_template.format(
package_name = package_name,
python_version = python_version,
major_python_version = python_version[0],
wheel_version = wheel_version,
wheel_platform_tag = _get_wheel_platform_name(platform_name, cpu_name),
)

def _jax_wheel_impl(ctx):
include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value
override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value
git_hash = ctx.attr.git_hash[BuildSettingInfo].value
output_path = ctx.attr.output_path[BuildSettingInfo].value
executable = ctx.executable.wheel_binary

output = ctx.actions.declare_directory(ctx.label.name)
if include_cuda_libs and not override_include_cuda_libs:
fail("JAX wheel shouldn't be built with CUDA dependencies." +
" Please provide `--config=cuda_wheel` for bazel build command." +
" If you absolutely need to add CUDA dependencies, provide" +
" `--@local_config_cuda//cuda:override_include_cuda_libs=true`.")

env = {}
args = ctx.actions.args()
args.add("--output_path", output.path) # required argument
args.add("--cpu", ctx.attr.platform_tag) # required argument
jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path
args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument

full_wheel_version = (WHEEL_VERSION + WHEEL_VERSION_SUFFIX)
env["WHEEL_VERSION_SUFFIX"] = WHEEL_VERSION_SUFFIX
if BUILD_TAG:
env["WHEEL_BUILD_TAG"] = BUILD_TAG
args.add("--build-tag", BUILD_TAG)
full_wheel_version += "-{}".format(BUILD_TAG)

cpu = _get_cpu(ctx.attr.platform_name, ctx.attr.platform_tag)
wheel_name = _get_full_wheel_name(
ctx.label.name,
ctx.attr.platform_name,
cpu,
ctx.attr.platform_version,
full_wheel_version,
)
output_file = ctx.actions.declare_file(output_path +
"/" + wheel_name)
wheel_dir = output_file.path[:output_file.path.rfind("/")]

args.add("--output_path", wheel_dir) # required argument
args.add("--cpu", cpu) # required argument
args.add("--jaxlib_git_hash", "\"{}\"".format(git_hash)) # required argument

if ctx.attr.enable_cuda:
args.add("--enable-cuda", "True")
Expand All @@ -335,11 +414,13 @@ def _jax_wheel_impl(ctx):
args.use_param_file("@%s", use_always = False)
ctx.actions.run(
arguments = [args],
inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [],
outputs = [output],
inputs = [],
outputs = [output_file],
executable = executable,
env = env,
)
return [DefaultInfo(files = depset(direct = [output]))]

return [DefaultInfo(files = depset(direct = [output_file]))]

_jax_wheel = rule(
attrs = {
Expand All @@ -350,12 +431,16 @@ _jax_wheel = rule(
cfg = "target",
),
"platform_tag": attr.string(mandatory = True),
"git_hash": attr.label(allow_single_file = True),
"platform_name": attr.string(mandatory = True),
"git_hash": attr.label(default = Label("//jaxlib/tools:git_hash")),
"output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
"enable_cuda": attr.bool(default = False),
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
"platform_version": attr.string(mandatory = True, default = ""),
"skip_gpu_kernels": attr.bool(default = False),
"enable_rocm": attr.bool(default = False),
"include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
"override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
},
implementation = _jax_wheel_impl,
executable = False,
Expand All @@ -380,21 +465,24 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
wheel_binary = wheel_binary,
enable_cuda = enable_cuda,
platform_version = platform_version,
# Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to
# pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to
# the git hash file needs to be created first.
git_hash = select({
"//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink",
"//conditions:default": None,
platform_name = select({
"@platforms//os:osx": "macosx",
"@platforms//os:macos": "macosx",
"@platforms//os:windows": "win",
"@platforms//os:linux": "linux",
}),
# Following the convention in jax/tools/build_utils.py.
# TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
platform_tag = select({
"//jaxlib/tools:macos_arm64": "arm64",
"//jaxlib/tools:win_amd64": "AMD64",
"//jaxlib/tools:arm64": "aarch64",
"@platforms//cpu:aarch64": "arm64",
"@platforms//cpu:arm64": "arm64",
"@platforms//cpu:x86_64": "x86_64",
}),
manylinux_compliance_tag = select({
"@platforms//cpu:aarch64": "manylinux_2_17_aarch64",
"@platforms//cpu:arm64": "manylinux_2_17_aarch64",
"@platforms//cpu:x86_64": "manylinux_2_17_x86_64",
"//conditions:default": "",
}),
)

jax_test_file_visibility = []
Expand Down
38 changes: 38 additions & 0 deletions jaxlib/jax_python_wheel_version.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Repository rule to generate a file with JAX python wheel version. """

def _jax_python_wheel_version_repository_impl(repository_ctx):
file_content = repository_ctx.read(
repository_ctx.path(repository_ctx.attr.file_with_version),
)
version_line_start_index = file_content.find(repository_ctx.attr.version_key)
version_line_end_index = version_line_start_index + file_content[version_line_start_index:].find("\n")
repository_ctx.file(
"wheel_version.bzl",
file_content[version_line_start_index:version_line_end_index].replace(
repository_ctx.attr.version_key,
"WHEEL_VERSION",
),
)
repository_ctx.file("BUILD", "")

jax_python_wheel_version_repository = repository_rule(
implementation = _jax_python_wheel_version_repository_impl,
attrs = {
"file_with_version": attr.label(mandatory = True, allow_single_file = True),
"version_key": attr.string(mandatory = True),
},
)
Loading

0 comments on commit 677cd8e

Please sign in to comment.