From 05ade1bfdd5d38d82872a5bb73af64ba3c4b2a38 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 22 Nov 2024 16:07:42 -0800 Subject: [PATCH] Refactor JAX wheel build rules to control the wheel filename and maintain reproducible wheel content and filename results. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change is a part of the initiative to test the JAX wheels in the presubmit properly. The list of the changes: 1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`. 2. The JAX version number (which is also used in the wheel filenames) is stored in `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in Bazel build phase. 3. The version suffix of the wheel in the build rule output depends on the environment variables. The version suffix chunks that are not reproducible shouldn’t be calculated as a part of the wheel binary: for example, the current date changes every day, thus the wheels built today and tomorrow on the same code version will be technically different. To maintain reproducible wheel content, we need to pass suffix chunks in a form of environment variables. 4. Environment variables combinations for creating wheels with different versions: * `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot` * `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release` * `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1` * `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)` PiperOrigin-RevId: 699315679 --- WORKSPACE | 17 ++++ jax/tools/build_utils.py | 19 ++-- jax/version.py | 11 +- jaxlib/jax.bzl | 128 ++++++++++++++++++++---- jaxlib/jax_python_wheel.bzl | 65 ++++++++++++ jaxlib/tools/BUILD.bazel | 85 +++++++++------- jaxlib/tools/build_gpu_kernels_wheel.py | 3 +- jaxlib/tools/build_gpu_plugin_wheel.py | 3 +- jaxlib/tools/build_wheel.py | 8 +- tests/version_test.py | 13 +++ 10 files changed, 275 insertions(+), 77 deletions(-) create mode 100644 jaxlib/jax_python_wheel.bzl diff --git a/WORKSPACE b/WORKSPACE index 130c9f804c93..72bd6851eb57 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -62,6 +62,23 @@ xla_workspace0() load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() +load("//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository") +jax_python_wheel_repository( + name = "jax_wheel", + manylinux_compliance_tags_dict_key = "MANYLINUX_COMPLIANCE_TAGS_ALIAS_DICT", + manylinux_compliance_tags_dict_source = "//jax/tools:build_utils.py", + version_key = "_version", + version_source = "//jax:version.py", +) + +load( + "@tsl//third_party/py:python_wheel.bzl", + "python_wheel_version_suffix_repository", +) +python_wheel_version_suffix_repository( + name = "jax_wheel_version_suffix", +) + load( "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index 83d0b4b25923..51d490e34b98 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -25,6 +25,12 @@ import glob from collections.abc import Sequence +MANYLINUX_COMPLIANCE_TAGS_ALIAS_DICT = { + "x86_64": "manylinux2014_x86_64", + "aarch64": "manylinux2014_aarch64", + "ppc64le": "manylinux2014_ppc64le", +} + def is_windows() -> bool: return sys.platform.startswith("win32") @@ -52,21 +58,16 @@ def copy_file( def platform_tag(cpu: str) -> str: + system_platform = platform.system() + if system_platform == "Linux": + return MANYLINUX_COMPLIANCE_TAGS_ALIAS_DICT[cpu] platform_name, cpu_name = { - ("Linux", "x86_64"): ("manylinux2014", "x86_64"), - ("Linux", "aarch64"): ("manylinux2014", "aarch64"), - ("Linux", "ppc64le"): ("manylinux2014", "ppc64le"), ("Darwin", "x86_64"): ("macosx_10_14", "x86_64"), ("Darwin", "arm64"): ("macosx_11_0", "arm64"), ("Windows", "AMD64"): ("win", "amd64"), - }[(platform.system(), cpu)] + }[(system_platform, cpu)] return f"{platform_name}_{cpu_name}" -def get_githash(jaxlib_git_hash): - if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash): - with open(jaxlib_git_hash, "r") as f: - return f.readline().strip() - return jaxlib_git_hash def build_wheel( sources_path: str, output_path: str, package_name: str, git_hash: str = "" diff --git a/jax/version.py b/jax/version.py index 484cd96acf41..be7d76170a6f 100644 --- a/jax/version.py +++ b/jax/version.py @@ -35,6 +35,8 @@ def _get_version_string() -> str: # In this case we return it directly. if _release_version is not None: return _release_version + if os.getenv("WHEEL_VERSION_SUFFIX"): + return _version + os.getenv("WHEEL_VERSION_SUFFIX", "") return _version_from_git_tree(_version) or _version_from_todays_date(_version) @@ -71,16 +73,19 @@ def _get_version_for_build() -> str: """Determine the version at build time. The returned version string depends on which environment variables are set: + - if WHEEL_VERSION_SUFFIX is set: version looks like "0.5.1.dev20230906+ge58560fdc" - if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16" - if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906" - if none are set: version looks like "0.4.16.dev20230906+ge58560fdc """ if _release_version is not None: return _release_version - if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'): - return _version_from_todays_date(_version) - if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'): + if os.getenv("WHEEL_VERSION_SUFFIX"): + return _version + os.getenv("WHEEL_VERSION_SUFFIX", "") + if os.getenv("JAX_RELEASE") or os.getenv("JAXLIB_RELEASE"): return _version + if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"): + return _version_from_todays_date(_version) return _version_from_git_tree(_version) or _version_from_todays_date(_version) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 49062c7283fd..7074ae419efd 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -14,7 +14,10 @@ """Bazel macros used by the JAX build.""" +load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library") +load("@jax_wheel//:wheel.bzl", "MANYLINUX_COMPLIANCE_TAGS_DICT", "WHEEL_VERSION") +load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "BUILD_TAG", "WHEEL_VERSION_SUFFIX") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") @@ -267,7 +270,7 @@ def jax_multiplatform_test( ] test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, []) if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]): - test_tags += ["manual"] + test_tags.append("manual") if backend == "gpu": test_tags += tf_cuda_tests_tags() native.py_test( @@ -308,15 +311,87 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) +def _get_wheel_platform_name(platform_name, cpu_name): + if platform_name == "linux": + return MANYLINUX_COMPLIANCE_TAGS_DICT[cpu_name] + platform = "" + cpu = "" + if platform_name == "macosx": + if cpu_name == "arm64": + cpu = "arm64" + platform = "macosx_11_0" + else: + cpu = "x86_64" + platform = "macosx_10_14" + elif platform_name == "win": + platform = "win" + cpu = "amd64" + return "{platform}_{cpu}".format( + platform = platform, + cpu = cpu, + ) + +def _get_cpu(platform_name, platform_tag): + # Following the convention in jax/tools/build_utils.py. + if platform_name == "macosx" and platform_tag == "arm64": + return "arm64" + if platform_name == "win" and platform_tag == "x86_64": + return "AMD64" + return "aarch64" if platform_tag == "arm64" else platform_tag + +def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version): + if no_abi: + wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl" + else: + wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl" + python_version = HERMETIC_PYTHON_VERSION.replace(".", "") + return wheel_name_template.format( + package_name = package_name, + python_version = python_version, + major_python_version = python_version[0], + wheel_version = wheel_version, + wheel_platform_tag = _get_wheel_platform_name(platform_name, cpu_name), + ) + def _jax_wheel_impl(ctx): + include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value + override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value + output_path = ctx.attr.output_path[BuildSettingInfo].value + git_hash = ctx.attr.git_hash[BuildSettingInfo].value executable = ctx.executable.wheel_binary - output = ctx.actions.declare_directory(ctx.label.name) + if include_cuda_libs and not override_include_cuda_libs: + fail("JAX wheel shouldn't be built directly against the CUDA libraries." + + " Please provide `--config=cuda_libraries_from_stubs` for bazel build command." + + " If you absolutely need to build links directly against the CUDA libraries, provide" + + " `--@local_config_cuda//cuda:override_include_cuda_libs=true`.") + + env = {} args = ctx.actions.args() - args.add("--output_path", output.path) # required argument - args.add("--cpu", ctx.attr.platform_tag) # required argument - jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path - args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument + + full_wheel_version = (WHEEL_VERSION + WHEEL_VERSION_SUFFIX) + env["WHEEL_VERSION_SUFFIX"] = WHEEL_VERSION_SUFFIX + if BUILD_TAG: + env["WHEEL_VERSION_SUFFIX"] = ".dev{}+selfbuilt".format(BUILD_TAG) + full_wheel_version += env["WHEEL_VERSION_SUFFIX"] + if not WHEEL_VERSION_SUFFIX and not BUILD_TAG: + env["JAX_RELEASE"] = "1" + + cpu = _get_cpu(ctx.attr.platform_name, ctx.attr.platform_tag) + wheel_name = _get_full_wheel_name( + package_name = ctx.attr.wheel_name, + no_abi = ctx.attr.no_abi, + platform_name = ctx.attr.platform_name, + cpu_name = cpu, + wheel_version = full_wheel_version, + ) + output_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + wheel_dir = output_file.path[:output_file.path.rfind("/")] + + args.add("--output_path", wheel_dir) # required argument + args.add("--cpu", cpu) # required argument + args.add("--jaxlib_git_hash", git_hash) # required argument if ctx.attr.enable_cuda: args.add("--enable-cuda", "True") @@ -335,11 +410,13 @@ def _jax_wheel_impl(ctx): args.use_param_file("@%s", use_always = False) ctx.actions.run( arguments = [args], - inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [], - outputs = [output], + inputs = [], + outputs = [output_file], executable = executable, + env = env, ) - return [DefaultInfo(files = depset(direct = [output]))] + + return [DefaultInfo(files = depset(direct = [output_file]))] _jax_wheel = rule( attrs = { @@ -349,19 +426,25 @@ _jax_wheel = rule( # b/365588895 Investigate cfg = "exec" for multi platform builds cfg = "target", ), + "wheel_name": attr.string(mandatory = True), + "no_abi": attr.bool(default = False), "platform_tag": attr.string(mandatory = True), - "git_hash": attr.label(allow_single_file = True), + "platform_name": attr.string(mandatory = True), + "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")), + "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. "platform_version": attr.string(mandatory = True, default = ""), "skip_gpu_kernels": attr.bool(default = False), "enable_rocm": attr.bool(default = False), + "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), + "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), }, implementation = _jax_wheel_impl, executable = False, ) -def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): +def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""): """Create jax artifact wheels. Common artifact attributes are grouped within a single macro. @@ -369,6 +452,8 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): Args: name: the name of the wheel wheel_binary: the binary to use to build the wheel + wheel_name: the name of the wheel + no_abi: whether to build a wheel without ABI enable_cuda: whether to build a cuda wheel platform_version: the cuda version to use for the wheel @@ -378,21 +463,22 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): _jax_wheel( name = name, wheel_binary = wheel_binary, + wheel_name = wheel_name, + no_abi = no_abi, enable_cuda = enable_cuda, platform_version = platform_version, - # Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to - # pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to - # the git hash file needs to be created first. - git_hash = select({ - "//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink", - "//conditions:default": None, + # git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)` + # flag in bazel command to pass the git hash for nightly or release builds. + platform_name = select({ + "@platforms//os:osx": "macosx", + "@platforms//os:macos": "macosx", + "@platforms//os:windows": "win", + "@platforms//os:linux": "linux", }), - # Following the convention in jax/tools/build_utils.py. # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. platform_tag = select({ - "//jaxlib/tools:macos_arm64": "arm64", - "//jaxlib/tools:win_amd64": "AMD64", - "//jaxlib/tools:arm64": "aarch64", + "@platforms//cpu:aarch64": "arm64", + "@platforms//cpu:arm64": "arm64", "@platforms//cpu:x86_64": "x86_64", }), ) diff --git a/jaxlib/jax_python_wheel.bzl b/jaxlib/jax_python_wheel.bzl new file mode 100644 index 000000000000..1f8aef7e85b8 --- /dev/null +++ b/jaxlib/jax_python_wheel.bzl @@ -0,0 +1,65 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Repository rule to generate a file with JAX wheel version and manylinux compliance tags. """ + +def _jax_python_wheel_repository_impl(repository_ctx): + version_source = repository_ctx.attr.version_source + version_key = repository_ctx.attr.version_key + manylinux_compliance_tags_dict_source = repository_ctx.attr.manylinux_compliance_tags_dict_source + manylinux_compliance_tags_dict_key = repository_ctx.attr.manylinux_compliance_tags_dict_key + + version_file_content = repository_ctx.read( + repository_ctx.path(version_source), + ) + version_start_index = version_file_content.find(version_key) + version_end_index = version_start_index + version_file_content[version_start_index:].find("\n") + + tags_file_content = repository_ctx.read( + repository_ctx.path(manylinux_compliance_tags_dict_source), + ) + tags_dict_start_index = tags_file_content.find(manylinux_compliance_tags_dict_key) + tags_dict_end_index = (tags_dict_start_index + + tags_file_content[tags_dict_start_index:].find("}") + 1) + + wheel_version = version_file_content[version_start_index:version_end_index].replace( + version_key, + "WHEEL_VERSION", + ) + tags_dict = tags_file_content[tags_dict_start_index:tags_dict_end_index].replace( + manylinux_compliance_tags_dict_key, + "MANYLINUX_COMPLIANCE_TAGS_DICT", + ) + wheel_bzl_content = "{version}\n{tags_dict}".format( + version = wheel_version, + tags_dict = tags_dict, + ) + repository_ctx.file( + "wheel.bzl", + wheel_bzl_content, + ) + repository_ctx.file("BUILD", "") + +jax_python_wheel_repository = repository_rule( + implementation = _jax_python_wheel_repository_impl, + attrs = { + "version_source": attr.label(mandatory = True, allow_single_file = True), + "version_key": attr.string(mandatory = True), + "manylinux_compliance_tags_dict_source": attr.label( + mandatory = True, + allow_single_file = True, + ), + "manylinux_compliance_tags_dict_key": attr.string(mandatory = True), + }, +) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 63f2643fe230..f3f8ca78be5e 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -14,10 +14,14 @@ # JAX is Autograd and XLA -load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "string_flag") +load("@jax_wheel//:wheel.bzl", "MANYLINUX_COMPLIANCE_TAGS_DICT") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load( + "@tsl//third_party/py:py_manylinux_compliance_test.bzl", + "verify_manylinux_compliance_test", +) load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel") licenses(["notice"]) # Apache 2 @@ -136,67 +140,72 @@ py_binary( ], ) -selects.config_setting_group( - name = "macos", - match_any = [ - "@platforms//os:osx", - "@platforms//os:macos", - ], -) - -selects.config_setting_group( - name = "arm64", - match_any = [ - "@platforms//cpu:aarch64", - "@platforms//cpu:arm64", - ], -) - -selects.config_setting_group( - name = "macos_arm64", - match_all = [ - ":arm64", - ":macos", - ], -) - -selects.config_setting_group( - name = "win_amd64", - match_all = [ - "@platforms//cpu:x86_64", - "@platforms//os:windows", - ], -) - string_flag( name = "jaxlib_git_hash", build_setting_default = "", ) -config_setting( - name = "jaxlib_git_hash_nightly_or_release", - flag_values = { - ":jaxlib_git_hash": "nightly", - }, +string_flag( + name = "output_path", + build_setting_default = "dist", ) jax_wheel( name = "jaxlib_wheel", + no_abi = False, wheel_binary = ":build_wheel", + wheel_name = "jaxlib", ) jax_wheel( name = "jax_cuda_plugin_wheel", enable_cuda = True, + no_abi = False, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", wheel_binary = ":build_gpu_kernels_wheel", + wheel_name = "jax_cuda12_plugin", ) jax_wheel( name = "jax_cuda_pjrt_wheel", enable_cuda = True, + no_abi = True, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", wheel_binary = ":build_gpu_plugin_wheel", + wheel_name = "jax_cuda12_pjrt", +) + +verify_manylinux_compliance_test( + name = "jaxlib_manylinux_compliance_test", + aarch64_compliance_tag = MANYLINUX_COMPLIANCE_TAGS_DICT["aarch64"], + ppc64le_compliance_tag = MANYLINUX_COMPLIANCE_TAGS_DICT["ppc64le"], + test_tags = [ + "manual", + ], + wheel = ":jaxlib_wheel", + x86_64_compliance_tag = MANYLINUX_COMPLIANCE_TAGS_DICT["x86_64"], +) + +verify_manylinux_compliance_test( + name = "jax_cuda_plugin_manylinux_compliance_test", + aarch64_compliance_tag = MANYLINUX_COMPLIANCE_TAGS_DICT["aarch64"], + ppc64le_compliance_tag = MANYLINUX_COMPLIANCE_TAGS_DICT["ppc64le"], + test_tags = [ + "manual", + ], + wheel = ":jax_cuda_plugin_wheel", + x86_64_compliance_tag = MANYLINUX_COMPLIANCE_TAGS_DICT["x86_64"], +) + +verify_manylinux_compliance_test( + name = "jax_cuda_pjrt_manylinux_compliance_test", + aarch64_compliance_tag = MANYLINUX_COMPLIANCE_TAGS_DICT["aarch64"], + ppc64le_compliance_tag = MANYLINUX_COMPLIANCE_TAGS_DICT["ppc64le"], + test_tags = [ + "manual", + ], + wheel = ":jax_cuda_pjrt_wheel", + x86_64_compliance_tag = MANYLINUX_COMPLIANCE_TAGS_DICT["x86_64"], ) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 65412f0365dc..4c4c1e8aada6 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -174,12 +174,11 @@ def prepare_wheel_rocm( if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=git_hash, + git_hash=args.jaxlib_git_hash, ) finally: tmpdir.cleanup() diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 08c2389c292a..0e2bba0c74d0 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -167,12 +167,11 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=git_hash, + git_hash=args.jaxlib_git_hash, ) finally: if tmpdir: diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 4b71bd5de2d8..7005122acdfa 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -387,8 +387,12 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) - build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=git_hash) + build_utils.build_wheel( + sources_path, + args.output_path, + package_name, + git_hash=args.jaxlib_git_hash, + ) finally: if tmpdir: tmpdir.cleanup() diff --git a/tests/version_test.py b/tests/version_test.py index 51297a9716b1..1036d958fc4e 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -104,6 +104,7 @@ def testBuildVersionInRelease(self): self.assertEqual(version, "1.2.3.dev4567") self.assertValidVersion(version) + @jtu.thread_unsafe_test() # Setting environment variables is not thread-safe. @patch_jax_version("1.2.3", None) def testBuildVersionFromEnvironment(self): # This test covers build-time construction of version strings in the @@ -157,6 +158,18 @@ def testBuildVersionFromEnvironment(self): self.assertTrue(version.endswith("test")) self.assertValidVersion(version) + with jtu.set_env( + JAX_RELEASE=None, + JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, + JAXLIB_NIGHTLY="1", + WHEEL_VERSION_SUFFIX=".dev20250101+1c0f1076erc1", + ): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertEqual(version, f"{base_version}.dev20250101+1c0f1076erc1") + self.assertValidVersion(version) + def testVersions(self): check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3", minimum_jaxlib_version="1.2.3")