diff --git a/.bazelrc b/.bazelrc index 6ef7d4493937..84d1bf4a45ba 100644 --- a/.bazelrc +++ b/.bazelrc @@ -118,9 +118,10 @@ build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true # Default hermetic CUDA and CUDNN versions. build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +build:cuda --@local_config_cuda//cuda:include_cuda_libs=true -# This flag is needed to include CUDA libraries for bazel tests. -test:cuda --@local_config_cuda//cuda:include_cuda_libs=true +# This configuration is used for building the wheels. +build:cuda_wheel --@local_config_cuda//cuda:include_cuda_libs=false # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to diff --git a/WORKSPACE b/WORKSPACE index ed284acadf81..247d3d4aaabe 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -61,6 +61,13 @@ xla_workspace0() load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() +load("@tsl//third_party/py:python_wheel_version.bzl", "python_wheel_version_repository") +python_wheel_version_repository( + name = "jax_wheel_version", + file_with_version = "//jax:version.py", + version_key = "_version", +) + load( "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", "cuda_json_init_repository", diff --git a/build/build.py b/build/build.py index 25a873d89e24..f01536b78452 100755 --- a/build/build.py +++ b/build/build.py @@ -503,6 +503,7 @@ async def main(): if "cuda" in wheel: wheel_build_command.append("--config=cuda") + wheel_build_command.append("--config=cuda_wheel") wheel_build_command.append( f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" ) diff --git a/jax/tools/build_utils.py b/jax/tools/build_utils.py index 83d0b4b25923..ae74b9985ad6 100644 --- a/jax/tools/build_utils.py +++ b/jax/tools/build_utils.py @@ -62,6 +62,11 @@ def platform_tag(cpu: str) -> str: }[(platform.system(), cpu)] return f"{platform_name}_{cpu_name}" + +def build_tag() -> str: + return os.getenv("WHEEL_BUILD_TAG") + + def get_githash(jaxlib_git_hash): if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash): with open(jaxlib_git_hash, "r") as f: diff --git a/jax/version.py b/jax/version.py index 3e8a8291ec8d..aee0eabeb414 100644 --- a/jax/version.py +++ b/jax/version.py @@ -35,6 +35,10 @@ def _get_version_string() -> str: # In this case we return it directly. if _release_version is not None: return _release_version + if os.environ.get("WHEEL_BUILD_TAG"): + return _version + if os.environ.get("WHEEL_VERSION_SUFFIX"): + return _version + os.environ.get("WHEEL_VERSION_SUFFIX") return _version_from_git_tree(_version) or _version_from_todays_date(_version) @@ -77,10 +81,16 @@ def _get_version_for_build() -> str: """ if _release_version is not None: return _release_version - if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'): - return _version_from_todays_date(_version) - if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'): + if ( + os.environ.get("JAX_RELEASE") + or os.environ.get("JAXLIB_RELEASE") + or os.environ.get("WHEEL_BUILD_TAG") + ): return _version + if os.environ.get("WHEEL_VERSION_SUFFIX"): + return _version + os.environ.get("WHEEL_VERSION_SUFFIX") + if os.environ.get("JAX_NIGHTLY") or os.environ.get("JAXLIB_NIGHTLY"): + return _version_from_todays_date(_version) return _version_from_git_tree(_version) or _version_from_todays_date(_version) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 976e5f26cb4b..8e3aa50fce63 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -14,7 +14,9 @@ """Bazel macros used by the JAX build.""" +load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library") +load("@jax_wheel_version//:wheel_version.bzl", "WHEEL_VERSION") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") @@ -308,38 +310,163 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) +def _get_wheel_platform_name(platform_name, cpu_name): + platform = "" + cpu = "" + if platform_name == "linux": + platform = "manylinux2014" + cpu = cpu_name + elif platform_name == "macosx": + if cpu_name == "arm64": + cpu = "arm64" + platform = "macosx_11_0" + else: + cpu = "x86_64" + platform = "macosx_10_14" + elif platform_name == "win": + platform = "win" + cpu = "amd64" + return "{platform}_{cpu}".format( + platform = platform, + cpu = cpu, + ) + +def _get_cpu(platform_name, platform_tag): + # Following the convention in jax/tools/build_utils.py. + if platform_name == "macosx" and platform_tag == "arm64": + return "arm64" + if platform_name == "win" and platform_tag == "x86_64": + return "AMD64" + return "aarch64" if platform_tag == "arm64" else platform_tag + +def _get_full_wheel_name(rule_name, platform_name, cpu_name, major_cuda_version, wheel_version): + if "pjrt" in rule_name: + wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl" + else: + wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl" + python_version = HERMETIC_PYTHON_VERSION.replace(".", "") + package_name = rule_name.replace("_wheel", "").replace( + "cuda", + "cuda{}".format(major_cuda_version), + ) + return wheel_name_template.format( + package_name = package_name, + python_version = python_version, + major_python_version = python_version[0], + wheel_version = wheel_version, + wheel_platform_tag = _get_wheel_platform_name(platform_name, cpu_name), + ) + def _jax_wheel_impl(ctx): + include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value + override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value + wheel_type = ctx.attr.wheel_type[BuildSettingInfo].value + git_hash = ctx.attr.git_hash[BuildSettingInfo].value + custom_version_suffix = ctx.attr.custom_version_suffix[BuildSettingInfo].value + build_date = ctx.attr.build_date[BuildSettingInfo].value + output_path = ctx.attr.output_path[BuildSettingInfo].value + verify_manylinux = ctx.attr.verify_manylinux[BuildSettingInfo].value + executable = ctx.executable.wheel_binary + full_wheel_version = ctx.attr.wheel_version + + if include_cuda_libs and not override_include_cuda_libs: + fail("JAX wheel shouldn't be built with CUDA dependencies." + + " Please provide `--config=cuda_wheel` for bazel build command." + + " If you absolutely need to add CUDA dependencies, provide" + + " `--@local_config_cuda//cuda:override_include_cuda_libs=true`.") + + env = {} + if wheel_type == "nightly": + if not build_date: + fail("--//jaxlib/tools:build_date is required for nightly builds!") + env["JAXLIB_NIGHTLY"] = "1" + formatted_date = build_date.replace("-", "") + env["WHEEL_VERSION_SUFFIX"] = ".dev{}".format(formatted_date) + full_wheel_version += env["WHEEL_VERSION_SUFFIX"] + elif wheel_type == "release": + env["JAXLIB_RELEASE"] = "1" + elif build_date: + formatted_date = build_date.replace("-", "") + formatted_hash = git_hash[:9] + if git_hash: + env["WHEEL_VERSION_SUFFIX"] = ".dev{date}+{hash}".format( + date = formatted_date, + hash = formatted_hash, + ) + full_wheel_version += env["WHEEL_VERSION_SUFFIX"] + else: + env["WHEEL_VERSION_SUFFIX"] = ".dev{}".format(formatted_date) + full_wheel_version += env["WHEEL_VERSION_SUFFIX"] + if custom_version_suffix: + env["WHEEL_VERSION_SUFFIX"] = "{version_suffix}.{custom_version_suffix}".format( + version_suffix = env["WHEEL_VERSION_SUFFIX"], + custom_version_suffix = custom_version_suffix, + ) + full_wheel_version += ".{}".format(custom_version_suffix) + else: + env["WHEEL_BUILD_TAG"] = "0" + full_wheel_version += "-0" + + cpu = _get_cpu(ctx.attr.platform_name, ctx.attr.platform_tag) + wheel_name = _get_full_wheel_name( + ctx.label.name, + ctx.attr.platform_name, + cpu, + ctx.attr.platform_version, + full_wheel_version, + ) + output_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + wheel_dir = output_file.path[:output_file.path.rfind("/")] - output = ctx.actions.declare_directory(ctx.label.name) - args = ctx.actions.args() - args.add("--output_path", output.path) # required argument - args.add("--cpu", ctx.attr.platform_tag) # required argument - jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path - args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument + args = { + "--output_path": wheel_dir, # required argument + "--cpu": cpu, # required argument + } + args["--jaxlib_git_hash"] = "\"{}\"".format(git_hash) # required argument if ctx.attr.enable_cuda: - args.add("--enable-cuda", "True") + args["--enable-cuda"] = "True" if ctx.attr.platform_version == "": fail("platform_version must be set to a valid cuda version for cuda wheels") - args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels + args["--platform_version"] = ctx.attr.platform_version # required for gpu wheels if ctx.attr.enable_rocm: - args.add("--enable-rocm", "True") + args["--enable-rocm"] = "True" if ctx.attr.platform_version == "": fail("platform_version must be set to a valid rocm version for rocm wheels") - args.add("--platform_version", ctx.attr.platform_version) # required for gpu wheels + args["--platform_version"] = ctx.attr.platform_version # required for gpu wheels if ctx.attr.skip_gpu_kernels: - args.add("--skip_gpu_kernels") - - args.set_param_file_format("flag_per_line") - args.use_param_file("@%s", use_always = False) - ctx.actions.run( - arguments = [args], - inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [], - outputs = [output], - executable = executable, + args["--skip_gpu_kernels"] = "True" + args_as_string = "" + for arg in args: + args_as_string += "{arg} {val} ".format(arg = arg, val = args[arg]) + + ctx.actions.run_shell( + inputs = [], + command = executable.path + " " + args_as_string, + outputs = [output_file], + tools = [executable], + env = env, ) - return [DefaultInfo(files = depset(direct = [output]))] + + auditwheel_show_log = None + if ctx.attr.platform_name == "linux": + auditwheel_show_log = ctx.actions.declare_file("auditwheel_show.log") + args = ctx.actions.args() + args.add("--wheel_path", output_file.path) + if verify_manylinux: + args.add("--compliance-tag", ctx.attr.manylinux_compliance_tag) + args.add("--auditwheel-show-log-path", auditwheel_show_log.path) + ctx.actions.run( + arguments = [args], + inputs = [output_file], + outputs = [auditwheel_show_log], + executable = ctx.executable.verify_manylinux_compliance_binary, + ) + + auditwheel_show_output = [auditwheel_show_log] if auditwheel_show_log else [] + return [DefaultInfo(files = depset(direct = [output_file] + auditwheel_show_output))] _jax_wheel = rule( attrs = { @@ -350,12 +477,27 @@ _jax_wheel = rule( cfg = "target", ), "platform_tag": attr.string(mandatory = True), - "git_hash": attr.label(allow_single_file = True), + "platform_name": attr.string(mandatory = True), + "git_hash": attr.label(default = Label("//jaxlib/tools:git_hash")), + "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), + "wheel_type": attr.label(default = Label("//jaxlib/tools:wheel_type")), + "custom_version_suffix": attr.label(default = Label("//jaxlib/tools:custom_version_suffix")), + "build_date": attr.label(default = Label("//jaxlib/tools:build_date")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. "platform_version": attr.string(mandatory = True, default = ""), "skip_gpu_kernels": attr.bool(default = False), "enable_rocm": attr.bool(default = False), + "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), + "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), + "wheel_version": attr.string(default = WHEEL_VERSION), + "verify_manylinux_compliance_binary": attr.label( + default = Label("@tsl//third_party/py:verify_manylinux_compliance"), + executable = True, + cfg = "exec", + ), + "verify_manylinux": attr.label(default = Label("@tsl//third_party/py:verify_manylinux")), + "manylinux_compliance_tag": attr.string(mandatory = True), }, implementation = _jax_wheel_impl, executable = False, @@ -380,21 +522,24 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""): wheel_binary = wheel_binary, enable_cuda = enable_cuda, platform_version = platform_version, - # Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to - # pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to - # the git hash file needs to be created first. - git_hash = select({ - "//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink", - "//conditions:default": None, + platform_name = select({ + "@platforms//os:osx": "macosx", + "@platforms//os:macos": "macosx", + "@platforms//os:windows": "win", + "@platforms//os:linux": "linux", }), - # Following the convention in jax/tools/build_utils.py. # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. platform_tag = select({ - "//jaxlib/tools:macos_arm64": "arm64", - "//jaxlib/tools:win_amd64": "AMD64", - "//jaxlib/tools:arm64": "aarch64", + "@platforms//cpu:aarch64": "arm64", + "@platforms//cpu:arm64": "arm64", "@platforms//cpu:x86_64": "x86_64", }), + manylinux_compliance_tag = select({ + "@platforms//cpu:aarch64": "manylinux_2_17_aarch64", + "@platforms//cpu:arm64": "manylinux_2_17_aarch64", + "@platforms//cpu:x86_64": "manylinux_2_17_x86_64", + "//conditions:default": "", + }), ) jax_test_file_visibility = [] diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 48dc03cfb7d6..4f7c44c49727 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -14,10 +14,15 @@ # JAX is Autograd and XLA -load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "string_flag") +load("@jax_wheel_version//:wheel_version.bzl", "WHEEL_VERSION") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load( + "@tsl//third_party/py:py_import.bzl", + "py_import", +) +load("@xla//xla/tsl:tsl.bzl", "if_cuda_libs") load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel") licenses(["notice"]) # Apache 2 @@ -141,48 +146,32 @@ py_binary( ], ) -selects.config_setting_group( - name = "macos", - match_any = [ - "@platforms//os:osx", - "@platforms//os:macos", - ], -) - -selects.config_setting_group( - name = "arm64", - match_any = [ - "@platforms//cpu:aarch64", - "@platforms//cpu:arm64", - ], +# Empty by default. Use `--//jaxlib/tools:git_hash=$(git rev-parse HEAD)` flag in +# bazel command to pass the git hash for nightly or release builds. +string_flag( + name = "git_hash", + build_setting_default = "", ) -selects.config_setting_group( - name = "macos_arm64", - match_all = [ - ":arm64", - ":macos", - ], +string_flag( + name = "output_path", + build_setting_default = "dist", ) -selects.config_setting_group( - name = "win_amd64", - match_all = [ - "@platforms//cpu:x86_64", - "@platforms//os:windows", - ], +string_flag( + name = "wheel_type", + build_setting_default = "snapshot", + values = ["snapshot", "nightly", "release"], ) string_flag( - name = "jaxlib_git_hash", + name = "custom_version_suffix", build_setting_default = "", ) -config_setting( - name = "jaxlib_git_hash_nightly_or_release", - flag_values = { - ":jaxlib_git_hash": "nightly", - }, +string_flag( + name = "build_date", + build_setting_default = "", ) jax_wheel( @@ -205,3 +194,99 @@ jax_wheel( platform_version = "12", wheel_binary = ":build_gpu_plugin_wheel", ) + +# py_binary( +# name = "rename_jaxlib_wheel", +# srcs = [ +# "rename_wheel.py", +# "//jaxlib:version", +# ], +# data = [ +# ":jaxlib_wheel", +# ], +# env = { +# "WHEEL_DIR": "jaxlib_wheel", +# "WHEEL_VERSION": WHEEL_VERSION, +# }, +# main = "rename_wheel.py", +# ) + +# py_binary( +# name = "rename_jax_cuda_plugin_wheel", +# srcs = [ +# "rename_wheel.py", +# "//jaxlib:version", +# ], +# data = [ +# ":jax_cuda_plugin_wheel", +# ], +# env = { +# "WHEEL_DIR": "jax_cuda_plugin_wheel", +# "WHEEL_VERSION": WHEEL_VERSION, +# }, +# main = "rename_wheel.py", +# ) + +# py_binary( +# name = "rename_jax_cuda_pjrt_wheel", +# srcs = [ +# "rename_wheel.py", +# "//jaxlib:version", +# ], +# data = [ +# ":jax_cuda_pjrt_wheel", +# ], +# env = { +# "WHEEL_DIR": "jax_cuda_pjrt_wheel", +# "WHEEL_VERSION": WHEEL_VERSION, +# }, +# main = "rename_wheel.py", +# ) + +py_import( + name = "jaxlib_py_import", + wheel = ":jaxlib_wheel", + deps = [ + "@pypi_ml_dtypes//:pkg", + "@pypi_numpy//:pkg", + "@pypi_scipy//:pkg", + ], +) + +py_import( + name = "jax_cuda_plugin_py_import", + cc_deps = if_cuda_libs([ + "@cuda_cublas//:cublas", + "@cuda_cublas//:cublasLt", + "@cuda_cudart//:cudart", + "@cuda_cudnn//:cudnn", + "@cuda_cufft//:cufft", + "@cuda_cupti//:cupti", + "@cuda_curand//:curand", + "@cuda_cusolver//:cusolver", + "@cuda_cusparse//:cusparse", + "@cuda_nccl//:nccl", + "@cuda_nvjitlink//:nvjitlink", + "@cuda_nvrtc//:nvrtc", + ]), + wheel = ":jax_cuda_plugin_wheel", +) + +py_import( + name = "jax_cuda_pjrt_py_import", + cc_deps = if_cuda_libs([ + "@cuda_cublas//:cublas", + "@cuda_cublas//:cublasLt", + "@cuda_cudart//:cudart", + "@cuda_cudnn//:cudnn", + "@cuda_cufft//:cufft", + "@cuda_cupti//:cupti", + "@cuda_curand//:curand", + "@cuda_cusolver//:cusolver", + "@cuda_cusparse//:cusparse", + "@cuda_nccl//:nccl", + "@cuda_nvjitlink//:nvjitlink", + "@cuda_nvrtc//:nvrtc", + ]), + wheel = ":jax_cuda_pjrt_wheel", +) diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 9a47c6ad5409..65400fbb0fa4 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -68,14 +68,15 @@ def write_setup_cfg(sources_path, cpu): - tag = build_utils.platform_tag(cpu) + plat_tag = build_utils.platform_tag(cpu) + build_tag = build_utils.build_tag() with open(sources_path / "setup.cfg", "w") as f: f.write(f"""[metadata] license_files = LICENSE.txt [bdist_wheel] -plat_name={tag} -""") +plat_name={plat_tag} +""" + (f"build_number={build_tag}\n" if build_tag else "")) def prepare_wheel_cuda( diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 08c2389c292a..40703cfc298a 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -73,17 +73,16 @@ def write_setup_cfg(sources_path, cpu): - tag = build_utils.platform_tag(cpu) + plat_tag = build_utils.platform_tag(cpu) + build_tag = build_utils.build_tag() with open(sources_path / "setup.cfg", "w") as f: - f.write( - f"""[metadata] + f.write(f"""[metadata] license_files = LICENSE.txt [bdist_wheel] -plat_name={tag} +plat_name={plat_tag} python-tag=py3 -""" - ) +""" + (f"build_number={build_tag}\n" if build_tag else "")) def prepare_cuda_plugin_wheel(sources_path: pathlib.Path, *, cpu, cuda_version): diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 4db36fa0ea97..4b0ed2af4865 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -157,16 +157,15 @@ def verify_mac_libraries_dont_reference_chkstack(): def write_setup_cfg(sources_path, cpu): - tag = build_utils.platform_tag(cpu) + plat_tag = build_utils.platform_tag(cpu) + build_tag = build_utils.build_tag() with open(sources_path / "setup.cfg", "w") as f: - f.write( - f"""[metadata] + f.write(f"""[metadata] license_files = LICENSE.txt [bdist_wheel] -plat_name={tag} -""" - ) +plat_name={plat_tag} +""" + (f"build_number={build_tag}\n" if build_tag else "")) def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): diff --git a/jaxlib/tools/rename_wheel.py b/jaxlib/tools/rename_wheel.py new file mode 100644 index 000000000000..26c2682ebc0d --- /dev/null +++ b/jaxlib/tools/rename_wheel.py @@ -0,0 +1,47 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script for renaming wheels.""" + +import argparse +import os +import shutil +import jaxlib.version + +parser = argparse.ArgumentParser(description="Rename wheel script arguments") +parser.add_argument("--output-path", required=False, default="dist") +args = parser.parse_args() + +new_version = jaxlib.version._get_version_for_build() +git_hash_ind = new_version.find("+") if "+" in new_version else len(new_version) + +wheel_dir = os.path.join("jaxlib", "tools", os.getenv("WHEEL_DIR")) +old_file_name = "" +for f in os.listdir(wheel_dir): + if f.endswith(".whl"): + old_file_name = f + break +new_file_name = old_file_name.replace( + "{}-0".format(os.getenv("WHEEL_VERSION")), new_version[:git_hash_ind] +) + +workspace_dir = os.path.realpath(os.getenv("BUILD_WORKSPACE_DIRECTORY")) +new_dir = os.path.join(workspace_dir, args.output_path) +if not os.path.exists(new_dir): + os.mkdir(new_dir) +old_file = os.path.join(wheel_dir, old_file_name) +new_file = os.path.join(new_dir, new_file_name) +shutil.copyfile(old_file, new_file) + +print("Renamed wheel path: %s" % new_file)