From 04bfda854e153742e07b24282e31cb9caaa80ce4 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 22 Nov 2024 16:07:42 -0800 Subject: [PATCH] Refactor JAX wheel build rules to control the wheel filename and maintain reproducible wheel content and filename results. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change is a part of the initiative to test the JAX wheels in the presubmit properly. The list of the changes: 1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`. 2. The JAX version number (which is also used in the wheel filenames) is stored in `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in Bazel build phase. 3. The version suffix of the wheel in the build rule output depends on the environment variables. The version suffix chunks that are not reproducible shouldn’t be calculated as a part of the wheel binary: for example, the current date changes every day, thus the wheels built today and tomorrow on the same code version will be technically different. To maintain reproducible wheel content, we need to pass suffix chunks in a form of environment variables. 4. Environment variables combinations for creating wheels with different versions: * `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot` * `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release` * `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1` * `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)` PiperOrigin-RevId: 699315679 --- WORKSPACE | 15 ++++ jax/tools/BUILD | 6 -- jax/version.py | 15 +++- jaxlib/jax.bzl | 109 +++++++++++++++++++----- jaxlib/jax_python_wheel.bzl | 43 ++++++++++ jaxlib/tools/BUILD.bazel | 87 +++++++++++++++++-- jaxlib/tools/build_gpu_kernels_wheel.py | 5 +- jaxlib/tools/build_gpu_plugin_wheel.py | 5 +- {jax => jaxlib}/tools/build_utils.py | 17 +--- jaxlib/tools/build_wheel.py | 10 ++- tests/version_test.py | 13 +++ 11 files changed, 265 insertions(+), 60 deletions(-) create mode 100644 jaxlib/jax_python_wheel.bzl rename {jax => jaxlib}/tools/build_utils.py (86%) diff --git a/WORKSPACE b/WORKSPACE index 130c9f804c93..8c4f49ecffee 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -62,6 +62,21 @@ xla_workspace0() load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() +load("//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository") +jax_python_wheel_repository( + name = "jax_wheel", + version_key = "_version", + version_source = "//jax:version.py", +) + +load( + "@tsl//third_party/py:python_wheel.bzl", + "python_wheel_version_suffix_repository", +) +python_wheel_version_suffix_repository( + name = "jax_wheel_version_suffix", +) + load( "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", diff --git a/jax/tools/BUILD b/jax/tools/BUILD index 80f757ca421c..3e0a950292a3 100644 --- a/jax/tools/BUILD +++ b/jax/tools/BUILD @@ -16,7 +16,6 @@ load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "py_deps", - "pytype_strict_library", ) licenses(["notice"]) @@ -46,8 +45,3 @@ py_library( "//jax/experimental/jax2tf", ] + py_deps("tensorflow_core"), ) - -pytype_strict_library( - name = "build_utils", - srcs = ["build_utils.py"], -) diff --git a/jax/version.py b/jax/version.py index 484cd96acf41..4c8d1798de05 100644 --- a/jax/version.py +++ b/jax/version.py @@ -35,6 +35,8 @@ def _get_version_string() -> str: # In this case we return it directly. if _release_version is not None: return _release_version + if os.getenv("WHEEL_VERSION_SUFFIX"): + return _version + os.getenv("WHEEL_VERSION_SUFFIX", "") return _version_from_git_tree(_version) or _version_from_todays_date(_version) @@ -71,16 +73,23 @@ def _get_version_for_build() -> str: """Determine the version at build time. The returned version string depends on which environment variables are set: + - if WHEEL_VERSION_SUFFIX is set: version looks like "0.5.1.dev20230906+ge58560fdc" + Here the WHEEL_VERSION_SUFFIX value is ".dev20230906+ge58560fdc". + Please note that the WHEEL_VERSION_SUFFIX value is not the same as the + JAX_CUSTOM_VERSION_SUFFIX value, and WHEEL_VERSION_SUFFIX is set by Bazel + wheel build rule. - if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16" - if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906" - if none are set: version looks like "0.4.16.dev20230906+ge58560fdc """ if _release_version is not None: return _release_version - if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'): - return _version_from_todays_date(_version) - if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'): + if os.getenv("WHEEL_VERSION_SUFFIX"): + return _version + os.getenv("WHEEL_VERSION_SUFFIX", "") + if os.getenv("JAX_RELEASE") or os.getenv("JAXLIB_RELEASE"): return _version + if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"): + return _version_from_todays_date(_version) return _version_from_git_tree(_version) or _version_from_todays_date(_version) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index e85a43883899..394f9caefbff 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -14,7 +14,10 @@ """Bazel macros used by the JAX build.""" +load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library") +load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION") +load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "BUILD_TAG", "WHEEL_VERSION_SUFFIX") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") @@ -50,6 +53,15 @@ jax_internal_test_harnesses_visibility = [] jax_test_util_visibility = [] loops_visibility = [] +PLATFORM_TAGS_DICT = { + ("Linux", "x86_64"): ("manylinux2014", "x86_64"), + ("Linux", "aarch64"): ("manylinux2014", "aarch64"), + ("Linux", "ppc64le"): ("manylinux2014", "ppc64le"), + ("Darwin", "x86_64"): ("macosx_10_14", "x86_64"), + ("Darwin", "arm64"): ("macosx_11_0", "arm64"), + ("Windows", "AMD64"): ("win", "amd64"), +} + # TODO(vam): remove this once zstandard builds against Python 3.13 def get_zstandard(): if HERMETIC_PYTHON_VERSION == "3.13": @@ -268,7 +280,7 @@ def jax_multiplatform_test( ] test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, []) if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]): - test_tags += ["manual"] + test_tags.append("manual") if backend == "gpu": test_tags += tf_cuda_tests_tags() native.py_test( @@ -309,15 +321,60 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) +def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version): + if no_abi: + wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl" + else: + wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl" + python_version = HERMETIC_PYTHON_VERSION.replace(".", "") + return wheel_name_template.format( + package_name = package_name, + python_version = python_version, + major_python_version = python_version[0], + wheel_version = wheel_version, + wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]), + ) + def _jax_wheel_impl(ctx): + include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value + override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value + output_path = ctx.attr.output_path[BuildSettingInfo].value + git_hash = ctx.attr.git_hash[BuildSettingInfo].value executable = ctx.executable.wheel_binary - output = ctx.actions.declare_directory(ctx.label.name) + if include_cuda_libs and not override_include_cuda_libs: + fail("JAX wheel shouldn't be built directly against the CUDA libraries." + + " Please provide `--config=cuda_libraries_from_stubs` for bazel build command." + + " If you absolutely need to build links directly against the CUDA libraries, provide" + + " `--@local_config_cuda//cuda:override_include_cuda_libs=true`.") + + env = {} args = ctx.actions.args() - args.add("--output_path", output.path) # required argument - args.add("--cpu", ctx.attr.platform_tag) # required argument - jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path - args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument + + full_wheel_version = (WHEEL_VERSION + WHEEL_VERSION_SUFFIX) + env["WHEEL_VERSION_SUFFIX"] = WHEEL_VERSION_SUFFIX + if BUILD_TAG: + env["WHEEL_VERSION_SUFFIX"] = ".dev{}+selfbuilt".format(BUILD_TAG) + full_wheel_version += env["WHEEL_VERSION_SUFFIX"] + if not WHEEL_VERSION_SUFFIX and not BUILD_TAG: + env["JAX_RELEASE"] = "1" + + cpu = ctx.attr.cpu + platform_name = ctx.attr.platform_name + wheel_name = _get_full_wheel_name( + package_name = ctx.attr.wheel_name, + no_abi = ctx.attr.no_abi, + platform_name = platform_name, + cpu_name = cpu, + wheel_version = full_wheel_version, + ) + output_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + wheel_dir = output_file.path[:output_file.path.rfind("/")] + + args.add("--output_path", wheel_dir) # required argument + args.add("--cpu", cpu) # required argument + args.add("--jaxlib_git_hash", git_hash) # required argument if ctx.attr.enable_cuda: args.add("--enable-cuda", "True") @@ -336,11 +393,13 @@ def _jax_wheel_impl(ctx): args.use_param_file("@%s", use_always = False) ctx.actions.run( arguments = [args], - inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [], - outputs = [output], + inputs = [], + outputs = [output_file], executable = executable, + env = env, ) - return [DefaultInfo(files = depset(direct = [output]))] + + return [DefaultInfo(files = depset(direct = [output_file]))] _jax_wheel = rule( attrs = { @@ -350,19 +409,25 @@ _jax_wheel = rule( # b/365588895 Investigate cfg = "exec" for multi platform builds cfg = "target", ), - "platform_tag": attr.string(mandatory = True), - "git_hash": attr.label(allow_single_file = True), + "wheel_name": attr.string(mandatory = True), + "no_abi": attr.bool(default = False), + "cpu": attr.string(mandatory = True), + "platform_name": attr.string(mandatory = True), + "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")), + "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. "platform_version": attr.string(mandatory = True, default = ""), "skip_gpu_kernels": attr.bool(default = False), "enable_rocm": attr.bool(default = False), + "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), + "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), }, implementation = _jax_wheel_impl, executable = False, ) -def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): +def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""): """Create jax artifact wheels. Common artifact attributes are grouped within a single macro. @@ -370,6 +435,8 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): Args: name: the name of the wheel wheel_binary: the binary to use to build the wheel + wheel_name: the name of the wheel + no_abi: whether to build a wheel without ABI enable_cuda: whether to build a cuda wheel platform_version: the cuda version to use for the wheel @@ -379,18 +446,20 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): _jax_wheel( name = name, wheel_binary = wheel_binary, + wheel_name = wheel_name, + no_abi = no_abi, enable_cuda = enable_cuda, platform_version = platform_version, - # Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to - # pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to - # the git hash file needs to be created first. - git_hash = select({ - "//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink", - "//conditions:default": None, + # git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)` + # flag in bazel command to pass the git hash for nightly or release builds. + platform_name = select({ + "@platforms//os:osx": "Darwin", + "@platforms//os:macos": "Darwin", + "@platforms//os:windows": "Windows", + "@platforms//os:linux": "Linux", }), - # Following the convention in jax/tools/build_utils.py. # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. - platform_tag = select({ + cpu = select({ "//jaxlib/tools:macos_arm64": "arm64", "//jaxlib/tools:win_amd64": "AMD64", "//jaxlib/tools:arm64": "aarch64", diff --git a/jaxlib/jax_python_wheel.bzl b/jaxlib/jax_python_wheel.bzl new file mode 100644 index 000000000000..d5b5444fef69 --- /dev/null +++ b/jaxlib/jax_python_wheel.bzl @@ -0,0 +1,43 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Repository rule to generate a file with JAX wheel version. """ + +def _jax_python_wheel_repository_impl(repository_ctx): + version_source = repository_ctx.attr.version_source + version_key = repository_ctx.attr.version_key + + version_file_content = repository_ctx.read( + repository_ctx.path(version_source), + ) + version_start_index = version_file_content.find(version_key) + version_end_index = version_start_index + version_file_content[version_start_index:].find("\n") + + wheel_version = version_file_content[version_start_index:version_end_index].replace( + version_key, + "WHEEL_VERSION", + ) + repository_ctx.file( + "wheel.bzl", + wheel_version, + ) + repository_ctx.file("BUILD", "") + +jax_python_wheel_repository = repository_rule( + implementation = _jax_python_wheel_repository_impl, + attrs = { + "version_source": attr.label(mandatory = True, allow_single_file = True), + "version_key": attr.string(mandatory = True), + }, +) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 63f2643fe230..3188463817c8 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -18,12 +18,38 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel") +load( + "@tsl//third_party/py:py_manylinux_compliance_test.bzl", + "verify_manylinux_compliance_test", +) +load( + "//jaxlib:jax.bzl", + "PLATFORM_TAGS_DICT", + "if_windows", + "jax_py_test", + "jax_wheel", + "pytype_strict_library", +) licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) +genrule( + name = "platform_tags_py", + srcs = [], + outs = ["platform_tags.py"], + cmd = "echo 'PLATFORM_TAGS_DICT = %s' > $@;" % PLATFORM_TAGS_DICT, +) + +pytype_strict_library( + name = "build_utils", + srcs = [ + "build_utils.py", + ":platform_tags_py", + ], +) + py_binary( name = "build_wheel", srcs = ["build_wheel.py"], @@ -41,7 +67,7 @@ py_binary( "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", ]), deps = [ - "//jax/tools:build_utils", + ":build_utils", "@bazel_tools//tools/python/runfiles", "@pypi_build//:pkg", "@pypi_setuptools//:pkg", @@ -99,7 +125,7 @@ py_binary( "//jax_plugins/rocm:__init__.py", ]), deps = [ - "//jax/tools:build_utils", + ":build_utils", "@bazel_tools//tools/python/runfiles", "@pypi_build//:pkg", "@pypi_setuptools//:pkg", @@ -128,7 +154,7 @@ py_binary( "//jax_plugins/rocm:plugin_setup.py", ]), deps = [ - "//jax/tools:build_utils", + ":build_utils", "@bazel_tools//tools/python/runfiles", "@pypi_build//:pkg", "@pypi_setuptools//:pkg", @@ -173,30 +199,73 @@ string_flag( build_setting_default = "", ) -config_setting( - name = "jaxlib_git_hash_nightly_or_release", - flag_values = { - ":jaxlib_git_hash": "nightly", - }, +string_flag( + name = "output_path", + build_setting_default = "dist", ) jax_wheel( name = "jaxlib_wheel", + no_abi = False, wheel_binary = ":build_wheel", + wheel_name = "jaxlib", ) jax_wheel( name = "jax_cuda_plugin_wheel", enable_cuda = True, + no_abi = False, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", wheel_binary = ":build_gpu_kernels_wheel", + wheel_name = "jax_cuda12_plugin", ) jax_wheel( name = "jax_cuda_pjrt_wheel", enable_cuda = True, + no_abi = True, # TODO(b/371217563) May use hermetic cuda version here. platform_version = "12", wheel_binary = ":build_gpu_plugin_wheel", + wheel_name = "jax_cuda12_pjrt", +) + +AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")]) + +PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")]) + +X86_64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "x86_64")]) + +verify_manylinux_compliance_test( + name = "jaxlib_manylinux_compliance_test", + aarch64_compliance_tag = AARCH64_MANYLINUX_TAG, + ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG, + test_tags = [ + "manual", + ], + wheel = ":jaxlib_wheel", + x86_64_compliance_tag = X86_64_MANYLINUX_TAG, +) + +verify_manylinux_compliance_test( + name = "jax_cuda_plugin_manylinux_compliance_test", + aarch64_compliance_tag = AARCH64_MANYLINUX_TAG, + ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG, + test_tags = [ + "manual", + ], + wheel = ":jax_cuda_plugin_wheel", + x86_64_compliance_tag = X86_64_MANYLINUX_TAG, +) + +verify_manylinux_compliance_test( + name = "jax_cuda_pjrt_manylinux_compliance_test", + aarch64_compliance_tag = AARCH64_MANYLINUX_TAG, + ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG, + test_tags = [ + "manual", + ], + wheel = ":jax_cuda_pjrt_wheel", + x86_64_compliance_tag = X86_64_MANYLINUX_TAG, ) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 65412f0365dc..09a55d3c3352 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -24,7 +24,7 @@ import tempfile from bazel_tools.tools.python.runfiles import runfiles -from jax.tools import build_utils +from jaxlib.tools import build_utils parser = argparse.ArgumentParser() parser.add_argument( @@ -174,12 +174,11 @@ def prepare_wheel_rocm( if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=git_hash, + git_hash=args.jaxlib_git_hash, ) finally: tmpdir.cleanup() diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 08c2389c292a..667807b51197 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -24,7 +24,7 @@ import tempfile from bazel_tools.tools.python.runfiles import runfiles -from jax.tools import build_utils +from jaxlib.tools import build_utils parser = argparse.ArgumentParser() parser.add_argument( @@ -167,12 +167,11 @@ def prepare_rocm_plugin_wheel(sources_path: pathlib.Path, *, cpu, rocm_version): if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) build_utils.build_wheel( sources_path, args.output_path, package_name, - git_hash=git_hash, + git_hash=args.jaxlib_git_hash, ) finally: if tmpdir: diff --git a/jax/tools/build_utils.py b/jaxlib/tools/build_utils.py similarity index 86% rename from jax/tools/build_utils.py rename to jaxlib/tools/build_utils.py index 83d0b4b25923..0db7c7072ab2 100644 --- a/jax/tools/build_utils.py +++ b/jaxlib/tools/build_utils.py @@ -24,6 +24,7 @@ import subprocess import glob from collections.abc import Sequence +from jaxlib.tools import platform_tags def is_windows() -> bool: @@ -52,21 +53,11 @@ def copy_file( def platform_tag(cpu: str) -> str: - platform_name, cpu_name = { - ("Linux", "x86_64"): ("manylinux2014", "x86_64"), - ("Linux", "aarch64"): ("manylinux2014", "aarch64"), - ("Linux", "ppc64le"): ("manylinux2014", "ppc64le"), - ("Darwin", "x86_64"): ("macosx_10_14", "x86_64"), - ("Darwin", "arm64"): ("macosx_11_0", "arm64"), - ("Windows", "AMD64"): ("win", "amd64"), - }[(platform.system(), cpu)] + platform_name, cpu_name = platform_tags.PLATFORM_TAGS_DICT[ + (platform.system(), cpu) + ] return f"{platform_name}_{cpu_name}" -def get_githash(jaxlib_git_hash): - if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash): - with open(jaxlib_git_hash, "r") as f: - return f.readline().strip() - return jaxlib_git_hash def build_wheel( sources_path: str, output_path: str, package_name: str, git_hash: str = "" diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 4b71bd5de2d8..2f4afae5431f 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -27,7 +27,7 @@ import tempfile from bazel_tools.tools.python.runfiles import runfiles -from jax.tools import build_utils +from jaxlib.tools import build_utils parser = argparse.ArgumentParser() parser.add_argument( @@ -387,8 +387,12 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu): if args.editable: build_utils.build_editable(sources_path, args.output_path, package_name) else: - git_hash = build_utils.get_githash(args.jaxlib_git_hash) - build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=git_hash) + build_utils.build_wheel( + sources_path, + args.output_path, + package_name, + git_hash=args.jaxlib_git_hash, + ) finally: if tmpdir: tmpdir.cleanup() diff --git a/tests/version_test.py b/tests/version_test.py index 51297a9716b1..1036d958fc4e 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -104,6 +104,7 @@ def testBuildVersionInRelease(self): self.assertEqual(version, "1.2.3.dev4567") self.assertValidVersion(version) + @jtu.thread_unsafe_test() # Setting environment variables is not thread-safe. @patch_jax_version("1.2.3", None) def testBuildVersionFromEnvironment(self): # This test covers build-time construction of version strings in the @@ -157,6 +158,18 @@ def testBuildVersionFromEnvironment(self): self.assertTrue(version.endswith("test")) self.assertValidVersion(version) + with jtu.set_env( + JAX_RELEASE=None, + JAXLIB_RELEASE=None, + JAX_NIGHTLY=None, + JAXLIB_NIGHTLY="1", + WHEEL_VERSION_SUFFIX=".dev20250101+1c0f1076erc1", + ): + with assert_no_subprocess_call(): + version = jax.version._get_version_for_build() + self.assertEqual(version, f"{base_version}.dev20250101+1c0f1076erc1") + self.assertValidVersion(version) + def testVersions(self): check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3", minimum_jaxlib_version="1.2.3")